Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

223 行
7.1 KiB

//Put this script on your blue cube.
using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class DodgeAgent : Agent
{
/// <summary>
/// The ground. The bounds are used to spawn the elements.
/// </summary>
public GameObject ground;
public GameObject area;
/// <summary>
/// The area bounds.
/// </summary>
[HideInInspector]
public Bounds areaBounds;
BulletSettings m_BulletSettings;
Rigidbody m_AgentRb; //cached on initialization
Material m_GroundMaterial; //cached on Awake()
/// <summary>
/// We will be changing the ground material based on success/failue
/// </summary>
Renderer m_GroundRenderer;
EnvironmentParameters m_ResetParams;
BufferSensorComponent m_BufferSensor;
void Awake()
{
m_BulletSettings = FindObjectOfType<BulletSettings>();
}
public override void Initialize()
{
m_BufferSensor = GetComponent<BufferSensorComponent>();
// Cache the agent rigidbody
m_AgentRb = GetComponent<Rigidbody>();
// Cache the block rigidbody
// Get the ground's bounds
areaBounds = ground.GetComponent<Collider>().bounds;
// Get the ground renderer so we can change the material when a goal is scored
m_GroundRenderer = ground.GetComponent<Renderer>();
// Starting material
m_GroundMaterial = m_GroundRenderer.material;
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}
/// <summary>
/// Use the ground's bounds to pick a random spawn position.
/// </summary>
public Vector3 GetRandomSpawnPos()
{
var foundNewSpawnLocation = false;
var randomSpawnPos = Vector3.zero;
while (foundNewSpawnLocation == false)
{
var randomPosX = Random.Range(-areaBounds.extents.x * m_BulletSettings.spawnAreaMarginMultiplier,
areaBounds.extents.x * m_BulletSettings.spawnAreaMarginMultiplier);
var randomPosZ = Random.Range(-areaBounds.extents.z * m_BulletSettings.spawnAreaMarginMultiplier,
areaBounds.extents.z * m_BulletSettings.spawnAreaMarginMultiplier);
randomSpawnPos = ground.transform.position + new Vector3(randomPosX, 1f, randomPosZ);
if (Physics.CheckBox(randomSpawnPos, new Vector3(2.5f, 0.01f, 2.5f)) == false)
{
foundNewSpawnLocation = true;
}
}
return randomSpawnPos;
}
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation((transform.position.x - area.transform.position.x) / 10f);
sensor.AddObservation((transform.position.z - area.transform.position.z) / 10f);
sensor.AddObservation(transform.forward.x);
sensor.AddObservation(transform.forward.z);
// Collect observation about the 20 closest Bullets
var bullets = transform.parent.GetComponentsInChildren<Bullet>();
// Sort by closest :
System.Array.Sort(bullets, (a, b) => (Vector3.Distance(a.transform.position, transform.position)).CompareTo(Vector3.Distance(b.transform.position, transform.position)));
int numBulletAdded = 0;
// foreach (Bullet b in bullets)
// {
// b.transform.localScale = new Vector3(1, 1, 1);
// }
foreach (Bullet b in bullets)
{
if (numBulletAdded >= 20)
{
break;
}
float[] bulletObservation = new float[]{
(b.transform.position.x - transform.position.x) / 10f, // relative position
(b.transform.position.z - transform.position.z) / 10f,
b.transform.forward.x,
b.transform.forward.z
};
numBulletAdded += 1;
m_BufferSensor.AppendObservation(bulletObservation);
// b.transform.localScale = new Vector3(2, 2, 2);
};
}
/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
public override void OnActionReceived(ActionBuffers actionBuffers)
{
var forwardForce = Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
var lateralForce = Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
var rotationForce = Mathf.Clamp(actionBuffers.ContinuousActions[2], -1f, 1f);
//Vector3 dirToGo = transform.forward * forwardForce + transform.right * lateralForce;
Vector3 rotateDir = transform.up * rotationForce;
transform.Rotate(rotateDir, Time.fixedDeltaTime * 200f);
Vector3 dirToGo = new Vector3(1, 0, 0) * forwardForce + new Vector3(0, 0, 1) * lateralForce;
m_AgentRb.AddForce(dirToGo * m_BulletSettings.agentRunSpeed,
ForceMode.VelocityChange);
//Vector3 dirToCenter = new Vector3((transform.position.x - area.transform.position.x) / 10f, 0f, (transform.position.z - area.transform.position.z) / 10f);
//AddReward(.001f / (dirToCenter.magnitude + .0000001f));
AddReward(.001f);
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = 0f;
continuousActionsOut[1] = 0f;
continuousActionsOut[2] = 0f;
if (Input.GetKey(KeyCode.D))
{
continuousActionsOut[2] = 1f;
}
else if (Input.GetKey(KeyCode.W))
{
continuousActionsOut[0] = 1f;
}
else if (Input.GetKey(KeyCode.A))
{
continuousActionsOut[2] = -1f;
}
else if (Input.GetKey(KeyCode.S))
{
continuousActionsOut[0] = -1f;
}
else if (Input.GetKey(KeyCode.Q))
{
continuousActionsOut[1] = -1f;
}
else if (Input.GetKey(KeyCode.E))
{
continuousActionsOut[1] = 1f;
}
}
void OnCollisionEnter(Collision collision)
{
if (collision.gameObject.CompareTag("bullet") || collision.gameObject.CompareTag("wall"))
{
SetReward(0f);
EndEpisode();
}
}
/// <summary>
/// In the editor, if "Reset On Done" is checked then AgentReset() will be
/// called automatically anytime we mark done = true in an agent script.
/// </summary>
public override void OnEpisodeBegin()
{
// var rotation = Random.Range(0, 4);
// var rotationAngle = rotation * 90f;
// area.transform.Rotate(new Vector3(0f, rotationAngle, 0f));
transform.position = GetRandomSpawnPos();//
m_AgentRb.velocity = Vector3.zero;
m_AgentRb.angularVelocity = Vector3.zero;
SetResetParameters();
}
public void SetGroundMaterialFriction()
{
var groundCollider = ground.GetComponent<Collider>();
groundCollider.material.dynamicFriction = m_ResetParams.GetWithDefault("dynamic_friction", 0);
groundCollider.material.staticFriction = m_ResetParams.GetWithDefault("static_friction", 0);
}
void SetResetParameters()
{
//SetGroundMaterialFriction();
}
}