using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

public class FoodCollectorAgent : Agent
{
    FoodCollectorSettings m_FoodCollecterSettings;
    public GameObject area;
    FoodCollectorArea m_MyArea;
    bool m_Frozen;
    bool m_Poisoned;
    bool m_Satiated;
    bool m_Shoot;
    // Time.time at which the agent was last frozen.
    float m_FrozenTime;
    // Time.time at which the poisoned or satiated effect last started.
    float m_EffectTime;
    Rigidbody m_AgentRb;
    float m_LaserLength;
    // Speed of agent rotation.
    public float turnSpeed = 300;

    // Speed of agent movement.
    public float moveSpeed = 2;
    public Material normalMaterial;
    public Material badMaterial;
    public Material goodMaterial;
    public Material frozenMaterial;
    public GameObject myLaser;
    public bool contribute;
    public bool useVectorObs;
    [Tooltip("Use only the frozen flag in vector observations. If \"Use Vector Obs\" " +
             "is checked, this option has no effect. This option is necessary for the " +
             "VisualFoodCollector scene.")]
    public bool useVectorFrozenFlag;

    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_MyArea = area.GetComponent<FoodCollectorArea>();
        m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        if (useVectorObs)
        {
            // Velocity in the agent's local frame; only x and z are observed.
            var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
            sensor.AddObservation(localVelocity.x);
            sensor.AddObservation(localVelocity.z);
            sensor.AddObservation(m_Frozen);
            sensor.AddObservation(m_Shoot);
        }
        else if (useVectorFrozenFlag)
        {
            sensor.AddObservation(m_Frozen);
        }
    }

    // Converts a 0xRRGGBB integer into a fully opaque Color32.
    public Color32 ToColor(int hexVal)
    {
        var r = (byte)((hexVal >> 16) & 0xFF);
        var g = (byte)((hexVal >> 8) & 0xFF);
        var b = (byte)(hexVal & 0xFF);
        return new Color32(r, g, b, 255);
    }

    public void MoveAgent(ActionBuffers actionBuffers)
    {
        m_Shoot = false;

        // The frozen state expires after 4 seconds.
        if (Time.time > m_FrozenTime + 4f && m_Frozen)
        {
            Unfreeze();
        }
        // The poisoned and satiated states expire after 0.5 seconds.
        if (Time.time > m_EffectTime + 0.5f)
        {
            if (m_Poisoned)
            {
                Unpoison();
            }
            if (m_Satiated)
            {
                Unsatiate();
            }
        }

        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        var continuousActions = actionBuffers.ContinuousActions;
        var discreteActions = actionBuffers.DiscreteActions;

        if (!m_Frozen)
        {
            var forward = Mathf.Clamp(continuousActions[0], -1f, 1f);
            var right = Mathf.Clamp(continuousActions[1], -1f, 1f);
            var rotate = Mathf.Clamp(continuousActions[2], -1f, 1f);

            dirToGo = transform.forward * forward;
            dirToGo += transform.right * right;
            rotateDir = -transform.up * rotate;

            var shootCommand = discreteActions[0] > 0;
            if (shootCommand)
            {
                // Shooting halves the commanded movement and damps the current velocity.
                m_Shoot = true;
                dirToGo *= 0.5f;
                m_AgentRb.velocity *= 0.75f;
            }
            m_AgentRb.AddForce(dirToGo * moveSpeed, ForceMode.VelocityChange);
            transform.Rotate(rotateDir, Time.fixedDeltaTime * turnSpeed);
        }

        if (m_AgentRb.velocity.sqrMagnitude > 25f) // slow it down
        {
            m_AgentRb.velocity *= 0.95f;
        }

        if (m_Shoot)
        {
            var myTransform = transform;
            myLaser.transform.localScale = new Vector3(1f, 1f, m_LaserLength);
            var rayDir = 25.0f * myTransform.forward;
            Debug.DrawRay(myTransform.position, rayDir, Color.red, 0f, true);
            RaycastHit hit;
            // Sweep a sphere of radius 2 forward up to 25 units and freeze any agent it hits.
            if (Physics.SphereCast(transform.position, 2f, rayDir, out hit, 25f))
            {
                if (hit.collider.gameObject.CompareTag("agent"))
                {
                    hit.collider.gameObject.GetComponent<FoodCollectorAgent>().Freeze();
                }
            }
        }
        else
        {
            // Hide the laser when not shooting.
            myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
        }
    }

    void Freeze()
    {
        gameObject.tag = "frozenAgent";
        m_Frozen = true;
        m_FrozenTime = Time.time;
        gameObject.GetComponentInChildren<Renderer>().material = frozenMaterial;
    }

    void Unfreeze()
    {
        m_Frozen = false;
        gameObject.tag = "agent";
        gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
    }

    void Poison()
    {
        m_Poisoned = true;
        m_EffectTime = Time.time;
        gameObject.GetComponentInChildren<Renderer>().material = badMaterial;
    }

    void Unpoison()
    {
        m_Poisoned = false;
        gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
    }

    void Satiate()
    {
        m_Satiated = true;
        m_EffectTime = Time.time;
        gameObject.GetComponentInChildren<Renderer>().material = goodMaterial;
    }

    void Unsatiate()
    {
        m_Satiated = false;
        gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
    }

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        MoveAgent(actionBuffers);
    }

    // Keyboard control: W/S drive continuous action 0 (forward/back),
    // D/A drive continuous action 2 (rotation), Space triggers the shoot action.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        if (Input.GetKey(KeyCode.D))
        {
            continuousActionsOut[2] = 1;
        }
        if (Input.GetKey(KeyCode.W))
        {
            continuousActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            continuousActionsOut[2] = -1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            continuousActionsOut[0] = -1;
        }
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }

    public override void OnEpisodeBegin()
    {
        Unfreeze();
        Unpoison();
        Unsatiate();
        m_Shoot = false;
        m_AgentRb.velocity = Vector3.zero;
        myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
        // Respawn at a random point inside the area with a random heading.
        transform.position = new Vector3(Random.Range(-m_MyArea.range, m_MyArea.range),
            2f, Random.Range(-m_MyArea.range, m_MyArea.range))
            + area.transform.position;
        transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));

        SetResetParameters();
    }

    void OnCollisionEnter(Collision collision)
    {
        // Good food: +1 reward, and +1 to the shared score if this agent contributes.
        if (collision.gameObject.CompareTag("food"))
        {
            Satiate();
            collision.gameObject.GetComponent<FoodLogic>().OnEaten();
            AddReward(1f);
            if (contribute)
            {
                m_FoodCollecterSettings.totalScore += 1;
            }
        }
        // Bad food: -1 reward, and -1 to the shared score if this agent contributes.
        if (collision.gameObject.CompareTag("badFood"))
        {
            Poison();
            collision.gameObject.GetComponent<FoodLogic>().OnEaten();

            AddReward(-1f);
            if (contribute)
            {
                m_FoodCollecterSettings.totalScore -= 1;
            }
        }
    }

    public void SetLaserLengths()
    {
        m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f);
    }

    public void SetAgentScale()
    {
        float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f);
        gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
    }

    public void SetResetParameters()
    {
        SetLaserLengths();
        SetAgentScale();
    }
}