您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
222 行
8.0 KiB
222 行
8.0 KiB
using System.Collections;
|
|
using UnityEngine;
|
|
using Unity.MLAgents;
|
|
using Unity.MLAgents.Actuators;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
public class HallwayCollabAgent : HallwayAgent
|
|
{
|
|
public GameObject symbolSGoal;
|
|
public GameObject symbolS;
|
|
public HallwayCollabAgent teammate;
|
|
public bool isSpotter = true;
|
|
TextMesh m_MessageText;
|
|
TextMesh m_MessageRec;
|
|
int m_Message = 0;
|
|
|
|
|
|
[HideInInspector]
|
|
public int selection = 0;
|
|
|
|
public override void Initialize()
|
|
{
|
|
base.Initialize();
|
|
if (isSpotter)
|
|
{
|
|
m_MessageText = gameObject.GetComponentInChildren<TextMesh>();
|
|
}
|
|
|
|
}
|
|
|
|
public override void OnEpisodeBegin()
|
|
{
|
|
m_Message = -1;
|
|
var agentOffset = 10f;
|
|
if (isSpotter)
|
|
{
|
|
agentOffset = -15;
|
|
}
|
|
|
|
if (!isSpotter)
|
|
{
|
|
transform.position = new Vector3(0f + Random.Range(-3f, 3f),
|
|
1f, agentOffset + Random.Range(-5f, 5f))
|
|
+ ground.transform.position;
|
|
transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);
|
|
}
|
|
else
|
|
{
|
|
transform.position = new Vector3(0f,
|
|
1f, agentOffset)
|
|
+ ground.transform.position;
|
|
transform.rotation = Quaternion.Euler(0f, 0f, 0f);
|
|
}
|
|
|
|
// Remove the randomness
|
|
|
|
m_AgentRb.velocity *= 0f;
|
|
if (isSpotter)
|
|
{
|
|
var blockOffset = -9f;
|
|
// Only the Spotter has the correct selection
|
|
selection = Random.Range(0, 3);
|
|
if (selection == 0)
|
|
{
|
|
symbolO.transform.position =
|
|
new Vector3(0f, 2f, blockOffset)
|
|
+ ground.transform.position;
|
|
symbolX.transform.position =
|
|
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
|
|
+ ground.transform.position;
|
|
symbolS.transform.position =
|
|
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
|
|
+ ground.transform.position;
|
|
}
|
|
else if (selection == 1)
|
|
{
|
|
symbolO.transform.position =
|
|
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
|
|
+ ground.transform.position;
|
|
symbolX.transform.position =
|
|
new Vector3(0f, 2f, blockOffset)
|
|
+ ground.transform.position;
|
|
symbolS.transform.position =
|
|
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
|
|
+ ground.transform.position;
|
|
}
|
|
else
|
|
{
|
|
symbolO.transform.position =
|
|
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
|
|
+ ground.transform.position;
|
|
symbolX.transform.position =
|
|
new Vector3(0f, -1000f, blockOffset)
|
|
+ ground.transform.position;
|
|
symbolS.transform.position =
|
|
new Vector3(0f, 2f, blockOffset)
|
|
+ ground.transform.position;
|
|
}
|
|
|
|
var goalPos = Random.Range(0, 7);
|
|
if (goalPos == 0)
|
|
{
|
|
symbolOGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolXGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolSGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
|
|
}
|
|
else if (goalPos == 1)
|
|
{
|
|
symbolOGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolXGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolSGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
|
|
}
|
|
else if (goalPos == 2)
|
|
{
|
|
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolSGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
|
|
}
|
|
else if (goalPos == 3)
|
|
{
|
|
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolXGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolSGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
|
|
}
|
|
else if (goalPos == 4)
|
|
{
|
|
symbolOGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolXGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolSGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
|
|
}
|
|
else
|
|
{
|
|
symbolOGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
|
|
symbolSGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
|
|
}
|
|
}
|
|
}
|
|
public override void CollectObservations(VectorSensor sensor)
|
|
{
|
|
//if (useVectorObs)
|
|
//{
|
|
// sensor.AddObservation(StepCount / (float)MaxStep);
|
|
//}
|
|
sensor.AddObservation(toOnehot(m_Message));
|
|
}
|
|
|
|
float[] toOnehot(int message)
|
|
{
|
|
float[] onehot = new float[3];
|
|
if (message < 0 || message >= 3)
|
|
{
|
|
return onehot;
|
|
}
|
|
onehot[message] = 1f;
|
|
return onehot;
|
|
}
|
|
|
|
public void tellAgent(int message)
|
|
{
|
|
m_Message = message;
|
|
}
|
|
|
|
public override void OnActionReceived(ActionBuffers actionBuffers)
|
|
{
|
|
AddReward(-1f / MaxStep);
|
|
if (!isSpotter)
|
|
{
|
|
MoveAgent(actionBuffers.DiscreteActions);
|
|
}
|
|
|
|
int comm_act = actionBuffers.DiscreteActions[1];
|
|
|
|
if (isSpotter)
|
|
{
|
|
m_MessageText.text = "Sent:" + comm_act.ToString();
|
|
}
|
|
teammate.tellAgent(comm_act);
|
|
// if (isSpotter) // Test
|
|
// {
|
|
// teammate.tellAgent(selection);
|
|
// }
|
|
}
|
|
|
|
void OnCollisionEnter(Collision col)
|
|
{
|
|
if (col.gameObject.CompareTag("symbol_O_Goal") || col.gameObject.CompareTag("symbol_X_Goal") || col.gameObject.CompareTag("symbol_S_Goal"))
|
|
{
|
|
if (!isSpotter)
|
|
{
|
|
// Check the ground truth
|
|
if ((teammate.selection == 0 && col.gameObject.CompareTag("symbol_O_Goal")) ||
|
|
(teammate.selection == 1 && col.gameObject.CompareTag("symbol_X_Goal")) ||
|
|
(teammate.selection == 2 && col.gameObject.CompareTag("symbol_S_Goal")))
|
|
{
|
|
SetReward(1f);
|
|
teammate.SetReward(1f);
|
|
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, 0.5f));
|
|
}
|
|
else
|
|
{
|
|
SetReward(-0.1f);
|
|
teammate.SetReward(-1f);
|
|
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
|
|
}
|
|
EndEpisode();
|
|
teammate.EndEpisode();
|
|
}
|
|
}
|
|
}
|
|
|
|
//public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
|
|
//{
|
|
// // Mask the necessary actions if selected by the user.
|
|
// if (!isSpotter)
|
|
// {
|
|
// // Prevents the agent from picking an action that would make it collide with a wall
|
|
// actionMask.WriteMask(1, new[] { 0 });
|
|
// }
|
|
//}
|
|
|
|
}
|