Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题。主题必须以中文或者字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。
 
 
 
 
 

183 行
7.0 KiB

using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class HallwayCollabAgent : HallwayAgent
{
public GameObject symbolSGoal;
public GameObject symbolS;
public HallwayCollabAgent teammate;
public bool isSpotter = true;
int m_Message = 0;
[HideInInspector]
public int selection = 0;
public override void OnEpisodeBegin()
{
m_Message = -1;
var agentOffset = 10f;
if (isSpotter)
{
agentOffset = -15;
}
if (!isSpotter)
{
transform.position = new Vector3(0f + Random.Range(-3f, 3f),
1f, agentOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);
}
else
{
transform.position = new Vector3(0f,
1f, agentOffset)
+ ground.transform.position;
transform.rotation = Quaternion.Euler(0f, 0f, 0f);
}
// Remove the randomness
m_AgentRb.velocity *= 0f;
if (isSpotter)
{
var blockOffset = -9f;
// Only the Spotter has the correct selection
selection = Random.Range(0, 3);
if (selection == 0)
{
symbolO.transform.position =
new Vector3(0f, 2f, blockOffset)
+ ground.transform.position;
symbolX.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
symbolS.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
}
else if (selection == 1)
{
symbolO.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
symbolX.transform.position =
new Vector3(0f, 2f, blockOffset)
+ ground.transform.position;
symbolS.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
}
else
{
symbolO.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
symbolX.transform.position =
new Vector3(0f, -1000f, blockOffset)
+ ground.transform.position;
symbolS.transform.position =
new Vector3(0f, 2f, blockOffset)
+ ground.transform.position;
}
var goalPos = Random.Range(0, 7);
if (goalPos == 0)
{
symbolOGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
symbolSGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
}
else if (goalPos == 1)
{
symbolOGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
symbolSGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
}
else if (goalPos == 2)
{
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolSGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
}
else if (goalPos == 3)
{
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
symbolSGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
}
else if (goalPos == 4)
{
symbolOGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
symbolSGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
}
else
{
symbolOGoal.transform.position = new Vector3(0f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolSGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
}
}
}
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
sensor.AddObservation(StepCount / (float)MaxStep);
}
sensor.AddObservation(m_Message);
}
public void tellAgent(int message)
{
m_Message = message;
}
public override void OnActionReceived(ActionBuffers actionBuffers)
{
AddReward(-1f / MaxStep);
if (!isSpotter)
{
MoveAgent(actionBuffers.DiscreteActions);
}
int comm_act = actionBuffers.DiscreteActions[1];
teammate.tellAgent(comm_act);
// if (isSpotter) // Test
// {
// teammate.tellAgent(selection);
// }
}
void OnCollisionEnter(Collision col)
{
if (col.gameObject.CompareTag("symbol_O_Goal") || col.gameObject.CompareTag("symbol_X_Goal") || col.gameObject.CompareTag("symbol_S_Goal"))
{
if (!isSpotter)
{
// Check the ground truth
if ((teammate.selection == 0 && col.gameObject.CompareTag("symbol_O_Goal")) ||
(teammate.selection == 1 && col.gameObject.CompareTag("symbol_X_Goal")) ||
(teammate.selection == 2 && col.gameObject.CompareTag("symbol_S_Goal")))
{
SetReward(1f);
teammate.SetReward(1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, 0.5f));
}
else
{
SetReward(-0.1f);
teammate.SetReward(-0.1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
}
EndEpisode();
teammate.EndEpisode();
}
}
}
}