using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

public class HallwayAgent : Agent
{
    public GameObject ground;
    public GameObject area;
    public GameObject symbolOGoal;
    public GameObject symbolXGoal;
    public GameObject symbolO;
    public GameObject symbolX;
    public bool useVectorObs;
    protected Rigidbody m_AgentRb;
    protected Material m_GroundMaterial;
    protected Renderer m_GroundRenderer;
    protected HallwaySettings m_HallwaySettings;
    // Symbol shown this episode: 0 = O, 1 = X.
    protected int m_Selection;
    // Used to log correct/wrong goal counts; assigned in Initialize().
    protected StatsRecorder m_statsRecorder;

    public override void Initialize()
    {
        // Cache scene references once instead of looking them up every step.
        m_HallwaySettings = FindObjectOfType<HallwaySettings>();
        m_AgentRb = GetComponent<Rigidbody>();
        m_GroundRenderer = ground.GetComponent<Renderer>();
        // Remember the original floor material so it can be restored after a goal flash.
        m_GroundMaterial = m_GroundRenderer.material;
        m_statsRecorder = Academy.Instance.StatsRecorder;
    }
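
    public override void CollectObservations(VectorSensor sensor)
    {
        if (useVectorObs)
        {
            // Fraction of the episode elapsed. Assumes MaxStep > 0: with MaxStep == 0
            // (no step limit in ML-Agents) this division does not produce a finite value.
            sensor.AddObservation(StepCount / (float)MaxStep);
        }
    }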
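
    // Sketch (hypothetical helper, not part of ML-Agents or the original example):
    // a guarded version of the progress observation above. It assumes that for
    // MaxStep <= 0 (unlimited episodes) reporting 0 is preferable to Infinity/NaN,
    // and it clamps the result to [0, 1].
    public static class ProgressObservationSketch
    {
        public static float Progress(int stepCount, int maxStep)
        {
            if (maxStep <= 0)
            {
                return 0f; // no step limit: no meaningful progress fraction
            }
            float p = stepCount / (float)maxStep;
            return p < 0f ? 0f : (p > 1f ? 1f : p);
        }
    }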

    // Temporarily swaps the floor material (to signal success or failure),
    // then restores the original material after `time` seconds.
    protected IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
    {
        m_GroundRenderer.material = mat;
        yield return new WaitForSeconds(time);
        m_GroundRenderer.material = m_GroundMaterial;
    }

    // Discrete action branch 0: 0 = no-op, 1 = forward, 2 = backward,
    // 3 = turn right, 4 = turn left.
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        var action = act[0];
        switch (action)
        {
            case 1:
                dirToGo = transform.forward * 1f;
                break;
            case 2:
                dirToGo = transform.forward * -1f;
                break;
            case 3:
                rotateDir = transform.up * 1f;
                break;
            case 4:
                rotateDir = transform.up * -1f;
                break;
        }
        // Rotate 150 * deltaTime degrees; rotateDir stays zero unless action 3 or 4 was chosen.
        transform.Rotate(rotateDir, Time.deltaTime * 150f);
        // VelocityChange applies an instantaneous change in velocity that ignores the body's mass.
        m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
    }
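
    // Sketch (hypothetical helper, for illustration only): the action mapping used by
    // MoveAgent(), expressed as plain data so it can be checked without a scene.
    // move: +1 forward, -1 backward; turn: +1 / -1 is the sign applied to transform.up.
    // Any other index (including 0) means "do nothing", matching the switch's fall-through.
    public static class ActionDecodeSketch
    {
        public static (int move, int turn) Decode(int action)
        {
            switch (action)
            {
                case 1: return (1, 0);
                case 2: return (-1, 0);
                case 3: return (0, 1);
                case 4: return (0, -1);
                default: return (0, 0);
            }
        }
    }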
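
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Small per-step penalty: over a full MaxStep-long episode it adds up to about -1,
        // which pushes the agent to reach a goal quickly.
        AddReward(-1f / MaxStep);
        MoveAgent(actionBuffers.DiscreteActions);
    }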
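
    // Sketch (hypothetical helper, for illustration only): the cumulative effect of the
    // per-step penalty AddReward(-1f / MaxStep). After n steps the agent has collected
    // n * (-1 / MaxStep), i.e. roughly -1 if it never touches a goal before MaxStep.
    // Assumes maxStep > 0; the loop mirrors step-by-step float accumulation.
    public static class StepPenaltySketch
    {
        public static float AccumulatedPenalty(int steps, int maxStep)
        {
            float total = 0f;
            for (int i = 0; i < steps; i++)
            {
                total += -1f / maxStep; // same amount OnActionReceived adds each step
            }
            return total;
        }
    }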

    void OnCollisionEnter(Collision col)
    {
        if (col.gameObject.CompareTag("symbol_O_Goal") || col.gameObject.CompareTag("symbol_X_Goal"))
        {
            // Correct when the touched goal matches the symbol shown this episode.
            if ((m_Selection == 0 && col.gameObject.CompareTag("symbol_O_Goal")) ||
                (m_Selection == 1 && col.gameObject.CompareTag("symbol_X_Goal")))
            {
                SetReward(1f);
                StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, 0.5f));
                m_statsRecorder.Add("Goal/Correct", 1, StatAggregationMethod.Sum);
            }
            else
            {
                SetReward(-0.1f);
                StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
                m_statsRecorder.Add("Goal/Wrong", 1, StatAggregationMethod.Sum);
            }
            // Touching either goal ends the episode.
            EndEpisode();
        }
    }
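
    // Sketch (hypothetical helper, for illustration only): the reward rule from
    // OnCollisionEnter() as a pure function of the episode's symbol and the goal's tag.
    // Returns null for objects that are not goals (no reward, episode continues).
    public static class GoalRewardSketch
    {
        public static float? RewardFor(int selection, string goalTag)
        {
            bool isO = goalTag == "symbol_O_Goal";
            bool isX = goalTag == "symbol_X_Goal";
            if (!isO && !isX)
            {
                return null;
            }
            bool correct = (selection == 0 && isO) || (selection == 1 && isX);
            return correct ? 1f : -0.1f;
        }
    }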

    // Keyboard control for testing: W/S move forward/backward, D/A turn right/left.
    // Only one key is honored per step, checked in the order D, W, A, S.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = 0;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = 3;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = 4;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
    }
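
    // Sketch (hypothetical helper, for illustration only): the Heuristic() key handling
    // as a pure function of which keys are held. It makes the if/else priority explicit:
    // D wins over W, W over A, and A over S.
    public static class HeuristicKeySketch
    {
        public static int ActionForKeys(bool w, bool a, bool s, bool d)
        {
            if (d) return 3; // turn right
            if (w) return 1; // forward
            if (a) return 4; // turn left
            if (s) return 2; // backward
            return 0;        // no key held: no-op
        }
    }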

    public override void OnEpisodeBegin()
    {
        var agentOffset = -15f;
        var blockOffset = 0f;
        // Pick the symbol (0 = O, 1 = X) the agent must remember this episode.
        m_Selection = Random.Range(0, 2);
        // Place the chosen symbol in the hallway and hide the other one far below the floor.
        if (m_Selection == 0)
        {
            symbolO.transform.position =
                new Vector3(0f + Random.Range(-3f, 3f), 2f, blockOffset + Random.Range(-5f, 5f))
                + ground.transform.position;
            symbolX.transform.position =
                new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
                + ground.transform.position;
        }
        else
        {
            symbolO.transform.position =
                new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
                + ground.transform.position;
            symbolX.transform.position =
                new Vector3(0f, 2f, blockOffset + Random.Range(-5f, 5f))
                + ground.transform.position;
        }

        // Spawn the agent near the start of the hallway with a random heading, at rest.
        transform.position = new Vector3(0f + Random.Range(-3f, 3f),
            1f, agentOffset + Random.Range(-5f, 5f))
            + ground.transform.position;
        transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);
        m_AgentRb.velocity *= 0f;

        // Randomly choose which side (x = +7 or -7) each goal appears on.
        var goalPos = Random.Range(0, 2);
        if (goalPos == 0)
        {
            symbolOGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
            symbolXGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
        }
        else
        {
            symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
            symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
        }
        // Add 0 to both counters each episode, presumably so both stats are reported
        // even in summary periods where no goal was reached.
        m_statsRecorder.Add("Goal/Correct", 0, StatAggregationMethod.Sum);
        m_statsRecorder.Add("Goal/Wrong", 0, StatAggregationMethod.Sum);
    }
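
    // Sketch (hypothetical helper, for illustration only): the goal layout chosen in
    // OnEpisodeBegin(). Both goals sit at z = 22.29 relative to `area`; goalPos only
    // decides which one is on the +x side. Returns the local x offsets of the O goal
    // and the X goal.
    public static class GoalLayoutSketch
    {
        public static (float oGoalX, float xGoalX) GoalXOffsets(int goalPos)
        {
            return goalPos == 0 ? (7f, -7f) : (-7f, 7f);
        }
    }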
}