GitHub
5 年前
当前提交
e7ec5007
共有 10 个文件被更改,包括 312 次插入 和 117 次删除
-
53Project/Assets/ML-Agents/Examples/Basic/Scenes/Basic.unity
-
2com.unity.ml-agents/CHANGELOG.md
-
13com.unity.ml-agents/Runtime/Agent.cs
-
110Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs
-
69Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
-
3Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs.meta
-
56Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
-
11Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs.meta
-
112Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
-
0/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs.meta
|
|||
using UnityEngine; |
|||
using MLAgents; |
|||
|
|||
/// <summary>
|
|||
/// An example of how to use ML-Agents without inheriting from the Agent class.
|
|||
/// Observations are generated by the attached SensorComponent, and the actions
|
|||
/// are retrieved from the Agent.
|
|||
/// </summary>
|
|||
/// <summary>
/// An example of how to use ML-Agents without inheriting from the Agent class.
/// Observations are generated by the attached SensorComponent, and the actions
/// are retrieved from the Agent.
/// </summary>
public class BasicController : MonoBehaviour
{
    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;

    [HideInInspector]
    public int m_Position;

    const int k_SmallGoalPosition = 7;
    const int k_LargeGoalPosition = 17;
    public GameObject largeGoal;
    public GameObject smallGoal;
    const int k_MinPosition = 0;
    const int k_MaxPosition = 20;
    public const int k_Extents = k_MaxPosition - k_MinPosition;

    Agent m_Agent;

    public void OnEnable()
    {
        m_Agent = GetComponent<Agent>();
        ResetAgent();
    }

    /// <summary>
    /// Controls the movement of the GameObject based on the actions received.
    /// </summary>
    /// <param name="vectorAction">Action buffer; index 0 selects the move (1 = left, 2 = right, other = stay).</param>
    public void ApplyAction(float[] vectorAction)
    {
        var move = (int)vectorAction[0];
        var step = move == 1 ? -1 : move == 2 ? 1 : 0;

        m_Position = Mathf.Clamp(m_Position + step, k_MinPosition, k_MaxPosition);
        gameObject.transform.position = new Vector3(m_Position - 10f, 0f, 0f);

        // Small time penalty every step to encourage reaching a goal quickly.
        m_Agent.AddReward(-0.01f);

        if (m_Position == k_SmallGoalPosition)
        {
            m_Agent.AddReward(0.1f);
            m_Agent.Done();
            ResetAgent();
        }

        if (m_Position == k_LargeGoalPosition)
        {
            m_Agent.AddReward(1f);
            m_Agent.Done();
            ResetAgent();
        }
    }

    /// <summary>
    /// Moves the agent back to its start position and places both goals.
    /// </summary>
    public void ResetAgent()
    {
        m_Position = 10;
        smallGoal.transform.position = new Vector3(k_SmallGoalPosition - 10f, 0f, 0f);
        largeGoal.transform.position = new Vector3(k_LargeGoalPosition - 10f, 0f, 0f);
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// While training (communicator on) a decision is requested every fixed
    /// step; otherwise decisions are throttled by timeBetweenDecisionsAtInference.
    /// </summary>
    void WaitTimeInference()
    {
        if (Academy.Instance.IsCommunicatorOn)
        {
            // Apply the previous step's actions
            ApplyAction(m_Agent.GetAction());
            m_Agent.RequestDecision();
            return;
        }

        if (m_TimeSinceDecision < timeBetweenDecisionsAtInference)
        {
            m_TimeSinceDecision += Time.fixedDeltaTime;
            return;
        }

        // Apply the previous step's actions
        ApplyAction(m_Agent.GetAction());
        m_TimeSinceDecision = 0f;
        m_Agent.RequestDecision();
    }
}
|
|||
using System; |
|||
using MLAgents.Sensors; |
|||
using UnityEngine.Serialization; |
|||
|
|||
namespace MLAgentsExamples |
|||
{ |
|||
/// <summary>
|
|||
/// A simple example of a SensorComponent.
|
|||
/// This should be added to the same GameObject as the BasicController
|
|||
/// </summary>
|
|||
/// <summary>
/// A simple example of a SensorComponent.
/// This should be added to the same GameObject as the BasicController.
/// </summary>
public class BasicSensorComponent : SensorComponent
{
    public BasicController basicController;

    /// <summary>
    /// Creates a BasicSensor wired to the referenced controller.
    /// </summary>
    /// <returns>A new <see cref="BasicSensor"/>.</returns>
    public override ISensor CreateSensor() => new BasicSensor(basicController);

    /// <inheritdoc/>
    public override int[] GetObservationShape() => new[] { BasicController.k_Extents };
}
|||
|
|||
/// <summary>
|
|||
/// Simple Sensor implementation that uses a one-hot encoding of the Agent's
|
|||
/// position as the observation.
|
|||
/// </summary>
|
|||
/// <summary>
/// Simple Sensor implementation that uses a one-hot encoding of the Agent's
/// position as the observation.
/// </summary>
public class BasicSensor : SensorBase
{
    public BasicController basicController;

    public BasicSensor(BasicController controller)
    {
        basicController = controller;
    }

    /// <summary>
    /// Generate the observations for the sensor: all zeros except a 1 at the
    /// controller's current position.
    /// NOTE(review): m_Position is clamped to k_MaxPosition (20) while the
    /// buffer length is k_Extents (20, valid indices 0..19) — looks like the
    /// upper bound is unreachable because episodes end at the goals, but
    /// confirm, otherwise this indexes out of range.
    /// </summary>
    /// <param name="output">Buffer of length BasicController.k_Extents.</param>
    public override void WriteObservation(float[] output)
    {
        // Zero the whole buffer, then set the single one-hot slot.
        for (var i = 0; i < output.Length; i++)
        {
            output[i] = 0f;
        }
        output[basicController.m_Position] = 1;
    }

    /// <inheritdoc/>
    public override string GetName() => "Basic";

    /// <inheritdoc/>
    public override int[] GetObservationShape() => new[] { BasicController.k_Extents };
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 6ee410d6d45349218d5e69bb2a347c63 |
|||
timeCreated: 1582857786 |
|
|||
using MLAgents.Sensors; |
|||
|
|||
namespace MLAgentsExamples |
|||
{ |
|||
/// <summary>
|
|||
/// A simple sensor that provides a number default implementations.
|
|||
/// </summary>
|
|||
/// <summary>
/// A simple sensor base class that provides a number of default implementations
/// of the ISensor interface.
/// </summary>
public abstract class SensorBase : ISensor
{
    /// <summary>
    /// Write the observations to the output buffer. The size of the buffer
    /// will be the product of the sizes returned by <see cref="GetObservationShape"/>.
    /// </summary>
    /// <param name="output">Destination buffer for the observation values.</param>
    public abstract void WriteObservation(float[] output);

    /// <inheritdoc/>
    public abstract int[] GetObservationShape();

    /// <inheritdoc/>
    public abstract string GetName();

    /// <summary>
    /// Default implementation of the Write interface: allocates a temporary
    /// array, calls WriteObservation, then copies the result to the WriteAdapter.
    /// </summary>
    /// <param name="adapter">Adapter that receives the observation values.</param>
    /// <returns>The number of elements written.</returns>
    public virtual int Write(WriteAdapter adapter)
    {
        // TODO reuse buffer for similar agents, don't call GetObservationShape()
        var count = this.ObservationSize();
        var scratch = new float[count];
        WriteObservation(scratch);
        adapter.AddRange(scratch);
        return count;
    }

    /// <inheritdoc/>
    public void Update() {}

    /// <inheritdoc/>
    public virtual byte[] GetCompressedObservation() => null;

    /// <inheritdoc/>
    public virtual SensorCompressionType GetCompressionType() => SensorCompressionType.None;
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 553b05a1b59a94260b3e545f13190389 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using UnityEngine; |
|||
using MLAgents; |
|||
using MLAgents.Sensors; |
|||
|
|||
/// <summary>
/// A simple 1-D agent: moves left/right along a line, with a small goal at
/// position 7 and a large goal at position 17. Reaching either goal ends the
/// episode.
/// </summary>
public class BasicAgent : Agent
{
    [Header("Specific to Basic")]
    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;
    int m_Position;
    int m_SmallGoalPosition;
    int m_LargeGoalPosition;
    public GameObject largeGoal;
    public GameObject smallGoal;
    int m_MinPosition;
    int m_MaxPosition;

    public override void InitializeAgent()
    {
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // One-hot encoding of the agent's position.
        // NOTE(review): m_Position can be clamped to m_MaxPosition (20) while the
        // one-hot range is 20 (valid values 0..19) — presumably the upper bound is
        // unreachable because episodes end at the goals; confirm.
        sensor.AddOneHotObservation(m_Position, 20);
    }

    /// <summary>
    /// Moves the agent based on the received action: 1 = left, 2 = right,
    /// anything else = stay. Rewards a goal if reached and ends the episode.
    /// </summary>
    /// <param name="vectorAction">Action buffer; index 0 selects the move.</param>
    public override void AgentAction(float[] vectorAction)
    {
        var movement = (int)vectorAction[0];

        var direction = 0;

        switch (movement)
        {
            case 1:
                direction = -1;
                break;
            case 2:
                direction = 1;
                break;
        }

        m_Position += direction;
        if (m_Position < m_MinPosition) { m_Position = m_MinPosition; }
        if (m_Position > m_MaxPosition) { m_Position = m_MaxPosition; }

        gameObject.transform.position = new Vector3(m_Position - 10f, 0f, 0f);

        // Small time penalty every step to encourage reaching a goal quickly.
        AddReward(-0.01f);

        if (m_Position == m_SmallGoalPosition)
        {
            AddReward(0.1f);
            Done();
        }

        if (m_Position == m_LargeGoalPosition)
        {
            AddReward(1f);
            Done();
        }
    }

    /// <summary>
    /// Resets the agent to the middle of the line and places both goals.
    /// </summary>
    public override void AgentReset()
    {
        m_Position = 10;
        m_MinPosition = 0;
        m_MaxPosition = 20;
        m_SmallGoalPosition = 7;
        m_LargeGoalPosition = 17;
        smallGoal.transform.position = new Vector3(m_SmallGoalPosition - 10f, 0f, 0f);
        largeGoal.transform.position = new Vector3(m_LargeGoalPosition - 10f, 0f, 0f);
    }

    /// <summary>
    /// Keyboard control for manual play: D = right (action 2), A = left (action 1).
    /// </summary>
    public override float[] Heuristic()
    {
        if (Input.GetKey(KeyCode.D))
        {
            return new float[] { 2 };
        }
        if (Input.GetKey(KeyCode.A))
        {
            return new float[] { 1 };
        }
        return new float[] { 0 };
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// While training (communicator on) a decision is requested every fixed
    /// step; otherwise decisions are throttled by timeBetweenDecisionsAtInference.
    /// BUGFIX: the communicator check was inverted — it throttled during
    /// training and requested every step at inference, the opposite of the
    /// field name's intent and of BasicController.WaitTimeInference.
    /// </summary>
    void WaitTimeInference()
    {
        if (Academy.Instance.IsCommunicatorOn)
        {
            // Training: a trainer is connected, request a decision every step.
            RequestDecision();
        }
        else
        {
            // Inference: only request a decision once enough time has elapsed.
            if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                m_TimeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                m_TimeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}
撰写
预览
正在加载...
取消
保存
Reference in new issue