Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

166 行
5.4 KiB

using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class HallwayAgent : Agent
{
public GameObject ground;
public GameObject area;
public GameObject symbolOGoal;
public GameObject symbolXGoal;
public GameObject symbolO;
public GameObject symbolX;
public bool useVectorObs;
Rigidbody m_AgentRb;
Material m_GroundMaterial;
Renderer m_GroundRenderer;
HallwaySettings m_HallwaySettings;
int m_Selection;
StatsRecorder m_statsRecorder;
public override void Initialize()
{
m_HallwaySettings = FindObjectOfType<HallwaySettings>();
m_AgentRb = GetComponent<Rigidbody>();
m_GroundRenderer = ground.GetComponent<Renderer>();
m_GroundMaterial = m_GroundRenderer.material;
m_statsRecorder = Academy.Instance.StatsRecorder;
}
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
sensor.AddObservation(StepCount / (float)MaxStep);
}
}
IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
{
m_GroundRenderer.material = mat;
yield return new WaitForSeconds(time);
m_GroundRenderer.material = m_GroundMaterial;
}
public void MoveAgent(ActionSegment<int> act)
{
var dirToGo = Vector3.zero;
var rotateDir = Vector3.zero;
var action = act[0];
switch (action)
{
case 1:
dirToGo = transform.forward * 1f;
break;
case 2:
dirToGo = transform.forward * -1f;
break;
case 3:
rotateDir = transform.up * 1f;
break;
case 4:
rotateDir = transform.up * -1f;
break;
}
transform.Rotate(rotateDir, Time.deltaTime * 150f);
m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
}
public override void OnActionReceived(ActionBuffers actionBuffers)
{
AddReward(-1f / MaxStep);
MoveAgent(actionBuffers.DiscreteActions);
}
void OnCollisionEnter(Collision col)
{
if (col.gameObject.CompareTag("symbol_O_Goal") || col.gameObject.CompareTag("symbol_X_Goal"))
{
if ((m_Selection == 0 && col.gameObject.CompareTag("symbol_O_Goal")) ||
(m_Selection == 1 && col.gameObject.CompareTag("symbol_X_Goal")))
{
SetReward(1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, 0.5f));
m_statsRecorder.Add("Goal/Correct", 1, StatAggregationMethod.Sum);
}
else
{
SetReward(-0.1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
m_statsRecorder.Add("Goal/Wrong", 1, StatAggregationMethod.Sum);
}
EndEpisode();
}
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
}
else if (Input.GetKey(KeyCode.W))
{
discreteActionsOut[0] = 1;
}
else if (Input.GetKey(KeyCode.A))
{
discreteActionsOut[0] = 4;
}
else if (Input.GetKey(KeyCode.S))
{
discreteActionsOut[0] = 2;
}
}
public override void OnEpisodeBegin()
{
var agentOffset = -15f;
var blockOffset = 0f;
m_Selection = Random.Range(0, 2);
if (m_Selection == 0)
{
symbolO.transform.position =
new Vector3(0f + Random.Range(-3f, 3f), 2f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
symbolX.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
}
else
{
symbolO.transform.position =
new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
symbolX.transform.position =
new Vector3(0f, 2f, blockOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
}
transform.position = new Vector3(0f + Random.Range(-3f, 3f),
1f, agentOffset + Random.Range(-5f, 5f))
+ ground.transform.position;
transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);
m_AgentRb.velocity *= 0f;
var goalPos = Random.Range(0, 2);
if (goalPos == 0)
{
symbolOGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolXGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
}
else
{
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
}
m_statsRecorder.Add("Goal/Correct", 0, StatAggregationMethod.Sum);
m_statsRecorder.Add("Goal/Wrong", 0, StatAggregationMethod.Sum);
}
}