浏览代码

Develop return float array (#3319)

* Decide Action to return float array

* Removing Debug statement

* Fixing the tests

* Fixing the format

* Renaming some variables

* Better memory allocation
/asymm-envs
GitHub 5 年前
当前提交
c6e5b23e
共有 12 个文件被更改,包括 193 次插入135 次删除
  1. 11
      com.unity.ml-agents/Runtime/Agent.cs
  2. 91
      com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs
  3. 7
      com.unity.ml-agents/Runtime/ICommunicator.cs
  4. 58
      com.unity.ml-agents/Runtime/InferenceBrain/ApplierImpl.cs
  5. 38
      com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs
  6. 13
      com.unity.ml-agents/Runtime/InferenceBrain/TensorApplier.cs
  7. 10
      com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs
  8. 14
      com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs
  9. 4
      com.unity.ml-agents/Runtime/Policy/IPolicy.cs
  10. 17
      com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs
  11. 45
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs
  12. 20
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs

11
com.unity.ml-agents/Runtime/Agent.cs


m_Info.maxStepReached = maxStepReached;
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors, (a) => {});
m_Brain?.RequestDecision(m_Info, sensors);
// The Agent is done, so we give it a new episode Id
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
m_Reward = 0f;

m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;
m_Brain.RequestDecision(m_Info, sensors, UpdateAgentAction);
m_Brain.RequestDecision(m_Info, sensors);
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{

if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
AgentAction(m_Action.vectorActions);
if (m_Action.vectorActions != null)
{
AgentAction(m_Action.vectorActions);
}
m_Brain?.DecideAction();
m_Action.vectorActions = m_Brain?.DecideAction();
}
}
}

91
com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs


/// Responsible for communication with External using gRPC.
public class RpcCommunicator : ICommunicator
{
/// <summary>
/// Associates an agent identifier with the callback that receives that agent's next action.
/// </summary>
public struct IdCallbackPair
{
/// <summary>Identifier of the agent; populated from <c>info.episodeId</c> when observations are queued.</summary>
public int AgentId;
/// <summary>Delegate invoked with the <c>AgentAction</c> received for this agent.</summary>
public Action<AgentAction> Callback;
}
public event QuitCommandHandler QuitCommandReceived;
public event ResetCommandHandler ResetCommandReceived;

/// The default number of agents in the scene
const int k_NumAgents = 32;
Dictionary<string, List<IdCallbackPair>> m_ActionCallbacks = new Dictionary<string, List<IdCallbackPair>>();
Dictionary<string, List<int>> m_OrderedAgentsRequestingDecisions = new Dictionary<string, List<int>>();
Dictionary<string, Dictionary<int, AgentAction>> m_LastActionsReceived =
new Dictionary<string, Dictionary<int, AgentAction>>();
Dictionary<string, Dictionary<int, float[]>> m_LastActionsReceived =
new Dictionary<string, Dictionary<int, float[]>>();
// Brains that we have sent over the communicator with agents.
HashSet<string> m_SentBrainKeys = new HashSet<string>();

}
case CommandProto.Reset:
{
foreach (var brainName in m_ActionCallbacks.Keys)
foreach (var brainName in m_OrderedAgentsRequestingDecisions.Keys)
m_ActionCallbacks[brainName].Clear();
m_OrderedAgentsRequestingDecisions[brainName].Clear();
}
ResetCommandReceived?.Invoke();
return;

/// <summary>
/// Sends the observations of one Agent.
/// </summary>
/// <param name="brainKey">Batch Key.</param>
/// <param name="behaviorName">Batch Key.</param>
public void PutObservations(string brainKey, AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
public void PutObservations(string behaviorName, AgentInfo info, List<ISensor> sensors)
if (!m_SensorShapeValidators.ContainsKey(brainKey))
if (!m_SensorShapeValidators.ContainsKey(behaviorName))
m_SensorShapeValidators[brainKey] = new SensorShapeValidator();
m_SensorShapeValidators[behaviorName] = new SensorShapeValidator();
m_SensorShapeValidators[brainKey].ValidateSensors(sensors);
m_SensorShapeValidators[behaviorName].ValidateSensors(sensors);
#endif
using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))

agentInfoProto.Observations.Add(obsProto);
}
}
m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
m_CurrentUnityRlOutput.AgentInfos[behaviorName].Value.Add(agentInfoProto);
if (!m_ActionCallbacks.ContainsKey(brainKey))
if (!m_OrderedAgentsRequestingDecisions.ContainsKey(behaviorName))
{
m_OrderedAgentsRequestingDecisions[behaviorName] = new List<int>();
}
m_OrderedAgentsRequestingDecisions[behaviorName].Add(info.episodeId);
if (!m_LastActionsReceived.ContainsKey(behaviorName))
{
m_LastActionsReceived[behaviorName] = new Dictionary<int, float[]>();
}
m_LastActionsReceived[behaviorName][info.episodeId] = null;
if (info.done)
m_ActionCallbacks[brainKey] = new List<IdCallbackPair>();
m_LastActionsReceived[behaviorName].Remove(info.episodeId);
m_ActionCallbacks[brainKey].Add(new IdCallbackPair { AgentId = info.episodeId, Callback = action });
}
/// <summary>

UpdateEnvironmentWithInput(rlInput);
m_LastActionsReceived.Clear();
if (!m_ActionCallbacks[brainName].Any())
if (!m_OrderedAgentsRequestingDecisions[brainName].Any())
{
continue;
}

}
var agentActions = rlInput.AgentActions[brainName].ToAgentActionList();
var numAgents = m_ActionCallbacks[brainName].Count;
var agentActionDict = new Dictionary<int, AgentAction>(numAgents);
m_LastActionsReceived[brainName] = agentActionDict;
var numAgents = m_OrderedAgentsRequestingDecisions[brainName].Count;
var agentId = m_ActionCallbacks[brainName][i].AgentId;
agentActionDict[agentId] = agentAction;
m_ActionCallbacks[brainName][i].Callback.Invoke(agentAction);
var agentId = m_OrderedAgentsRequestingDecisions[brainName][i];
if (m_LastActionsReceived[brainName].ContainsKey(agentId))
{
m_LastActionsReceived[brainName][agentId] = agentAction.vectorActions;
}
foreach (var brainName in m_ActionCallbacks.Keys)
foreach (var brainName in m_OrderedAgentsRequestingDecisions.Keys)
m_ActionCallbacks[brainName].Clear();
m_OrderedAgentsRequestingDecisions[brainName].Clear();
public Dictionary<int, AgentAction> GetActions(string key)
public float[] GetActions(string behaviorName, int agentId)
return m_LastActionsReceived[key];
if (m_LastActionsReceived.ContainsKey(behaviorName))
{
if (m_LastActionsReceived[behaviorName].ContainsKey(agentId))
{
return m_LastActionsReceived[behaviorName][agentId];
}
}
return null;
}
/// <summary>

}
/// <summary>
/// Wraps the UnityOuptut into a message with the appropriate status.
/// Wraps the UnityOutput into a message with the appropriate status.
/// </summary>
/// <returns>The UnityMessage corresponding.</returns>
/// <param name="content">The UnityOutput to be wrapped.</param>

};
}
void CacheBrainParameters(string brainKey, BrainParameters brainParameters)
void CacheBrainParameters(string behaviorName, BrainParameters brainParameters)
if (m_SentBrainKeys.Contains(brainKey))
if (m_SentBrainKeys.Contains(behaviorName))
m_UnsentBrainKeys[brainKey] = brainParameters;
m_UnsentBrainKeys[behaviorName] = brainParameters;
foreach (var brainKey in m_UnsentBrainKeys.Keys)
foreach (var behaviorName in m_UnsentBrainKeys.Keys)
if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(brainKey))
if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
{
if (output == null)
{

var brainParameters = m_UnsentBrainKeys[brainKey];
output.BrainParameters.Add(brainParameters.ToProto(brainKey, true));
var brainParameters = m_UnsentBrainKeys[behaviorName];
output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
}
}

7
com.unity.ml-agents/Runtime/ICommunicator.cs


/// <param name="info">Agent info.</param>
/// <param name="sensors">The list of ISensors of the Agent.</param>
/// <param name="action">The action that will be called once the next AgentAction is ready.</param>
void PutObservations(string brainKey, AgentInfo info, List<ISensor> sensors, Action<AgentAction> action);
void PutObservations(string brainKey, AgentInfo info, List<ISensor> sensors);
/// <summary>
/// Signals the ICommunicator that the Agents are now ready to receive their action

/// <summary>
/// Gets the AgentActions based on the batching key.
/// </summary>
/// <param name="key">A key to identify which actions to get</param>
/// <param name="key">A key to identify which behavior actions to get</param>
/// <param name="agentId">A key to identify which Agent actions to get</param>
Dictionary<int, AgentAction> GetActions(string key);
float[] GetActions(string key, int agentId);
/// <summary>
/// Registers a side channel to the communicator. The side channel will exchange

58
com.unity.ml-agents/Runtime/InferenceBrain/ApplierImpl.cs


/// </summary>
public class ContinuousActionOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, IEnumerable<AgentIdActionPair> actions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
foreach (var idActionPair in actions)
foreach (int agentId in actionIds)
var actionValue = new float[actionSize];
for (var j = 0; j < actionSize; j++)
if (lastActions.ContainsKey(agentId))
actionValue[j] = tensorProxy.data[agentIndex, j];
var actionValue = lastActions[agentId];
if (actionValue == null)
{
actionValue = new float[actionSize];
lastActions[agentId] = actionValue;
}
for (var j = 0; j < actionSize; j++)
{
actionValue[j] = tensorProxy.data[agentIndex, j];
}
idActionPair.action.Invoke(new AgentAction { vectorActions = actionValue });
agentIndex++;
}
}

m_Allocator = allocator;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<AgentIdActionPair> actions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
var idActionPairList = actions as List<AgentIdActionPair> ?? actions.ToList();
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();
var batchSize = idActionPairList.Count;
var actionValues = new float[batchSize, m_ActionSize.Length];
var startActionIndices = Utilities.CumSum(m_ActionSize);

outputTensor.data.Dispose();
}
var agentIndex = 0;
foreach (var idActionPair in idActionPairList)
foreach (int agentId in actionIds)
var actionVal = new float[m_ActionSize.Length];
for (var j = 0; j < m_ActionSize.Length; j++)
if (lastActions.ContainsKey(agentId))
actionVal[j] = actionValues[agentIndex, j];
var actionVal = lastActions[agentId];
if (actionVal == null)
{
actionVal = new float[m_ActionSize.Length];
lastActions[agentId] = actionVal;
}
for (var j = 0; j < m_ActionSize.Length; j++)
{
actionVal[j] = actionValues[agentIndex, j];
}
idActionPair.action.Invoke(new AgentAction { vectorActions = actionVal });
agentIndex++;
}
}

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<AgentIdActionPair> actions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
foreach (var idActionPair in actions)
foreach (int agentId in actionIds)
if (!m_Memories.TryGetValue(idActionPair.agentId, out memory)
if (!m_Memories.TryGetValue(agentId, out memory)
|| memory.Count < memorySize)
{
memory = new List<float>();

m_Memories[idActionPair.agentId] = memory;
m_Memories[agentId] = memory;
agentIndex++;
}
}

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<AgentIdActionPair> actions)
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
foreach (var idActionPair in actions)
foreach (int agentId in actionIds)
if (!m_Memories.TryGetValue(idActionPair.agentId, out memory)
if (!m_Memories.TryGetValue(agentId, out memory)
|| memory.Count < memorySize * m_MemoriesCount)
{
memory = new List<float>();

memory[memorySize * m_MemoryIndex + j] = tensorProxy.data[agentIndex, j];
}
m_Memories[idActionPair.agentId] = memory;
m_Memories[agentId] = memory;
agentIndex++;
}
}

38
com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs


public AgentInfo agentInfo;
public List<ISensor> sensors;
}
/// <summary>
/// Associates an agent identifier with the callback that receives the action computed for that agent.
/// </summary>
public struct AgentIdActionPair
{
/// <summary>Identifier of the agent; populated from <c>info.episodeId</c> when observations are queued.</summary>
public int agentId;
/// <summary>Delegate invoked with the <c>AgentAction</c> produced for this agent.</summary>
public Action<AgentAction> action;
}
List<AgentIdActionPair> m_ActionFuncs = new List<AgentIdActionPair>();
Dictionary<int, float[]> m_LastActionsReceived = new Dictionary<int, float[]>();
List<int> m_OrderedAgentsRequestingDecisions = new List<int>();
ITensorAllocator m_TensorAllocator;
TensorGenerator m_TensorGenerator;
TensorApplier m_TensorApplier;

return outputs;
}
public void PutObservations(AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
public void PutObservations(AgentInfo info, List<ISensor> sensors)
{
#if DEBUG
m_SensorShapeValidator.ValidateSensors(sensors);

sensors = sensors
});
m_ActionFuncs.Add(new AgentIdActionPair { action = action, agentId = info.episodeId });
// We add the episodeId to this list to maintain the order in which the decisions were requested
m_OrderedAgentsRequestingDecisions.Add(info.episodeId);
if (!m_LastActionsReceived.ContainsKey(info.episodeId))
{
m_LastActionsReceived[info.episodeId] = null;
}
if (info.done)
{
// If the agent is done, we remove the key from the last action dictionary since no action
// should be taken.
m_LastActionsReceived.Remove(info.episodeId);
}
}
public void DecideBatch()

Profiler.BeginSample($"MLAgents.{m_Model.name}.ApplyTensors");
// Update the outputs
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_ActionFuncs);
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
Profiler.EndSample();
Profiler.EndSample();

m_ActionFuncs.Clear();
m_OrderedAgentsRequestingDecisions.Clear();
}
/// <summary>
/// Looks up the last action stored for the given agent.
/// </summary>
/// <param name="agentId">Identifier of the agent whose action is requested.</param>
/// <returns>
/// The value stored in <c>m_LastActionsReceived</c> for <paramref name="agentId"/>
/// (which may itself be null), or null when no entry exists for that id.
/// </returns>
public float[] GetAction(int agentId)
{
    // A single TryGetValue replaces the ContainsKey + indexer pair, avoiding a second hash lookup.
    float[] lastAction;
    if (m_LastActionsReceived.TryGetValue(agentId, out lastAction))
    {
        return lastAction;
    }
    return null;
}
}
}

13
com.unity.ml-agents/Runtime/InferenceBrain/TensorApplier.cs


/// <param name="tensorProxy">
/// The Tensor containing the data to be applied to the Agents
/// </param>
/// <param name="agents">
/// List of Agents that will receive the values of the Tensor.
/// <param name="actionIds"> List of Agent Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
void Apply(TensorProxy tensorProxy, IEnumerable<AgentIdActionPair> actions);
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions);
}
readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

/// Updates the state of the agents based on the data present in the tensor.
/// </summary>
/// <param name="tensors"> Enumerable of tensors containing the data.</param>
/// <param name="agents"> List of Agents that will be updated using the tensor's data</param>
/// <param name="actionIds"> List of Agent Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
IEnumerable<TensorProxy> tensors, IEnumerable<AgentIdActionPair> actions)
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, float[]> lastActions)
{
foreach (var tensor in tensors)
{

$"Unknown tensorProxy expected as output : {tensor.name}");
}
m_Dict[tensor.name].Apply(tensor, actions);
m_Dict[tensor.name].Apply(tensor, actionIds, lastActions);
}
}
}

10
com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs


{
protected ModelRunner m_ModelRunner;
private int m_AgentId;
/// <summary>
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
/// </summary>

}
/// <inheritdoc />
public void RequestDecision(AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
m_ModelRunner?.PutObservations(info, sensors, action);
m_AgentId = info.episodeId;
m_ModelRunner?.PutObservations(info, sensors);
public void DecideAction()
public float[] DecideAction()
return m_ModelRunner?.GetAction(m_AgentId);
}
public void Dispose()

14
com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs


public class HeuristicPolicy : IPolicy
{
Func<float[]> m_Heuristic;
Action<AgentAction> m_ActionFunc;
float[] m_LastDecision;
/// <inheritdoc />
public HeuristicPolicy(Func<float[]> heuristic)

/// <inheritdoc />
public void RequestDecision(AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
m_ActionFunc = action;
m_LastDecision = m_Heuristic.Invoke();
public void DecideAction()
public float[] DecideAction()
if (m_ActionFunc != null)
{
m_ActionFunc.Invoke(new AgentAction { vectorActions = m_Heuristic.Invoke() });
m_ActionFunc = null;
}
return m_LastDecision;
}
public void Dispose()

4
com.unity.ml-agents/Runtime/Policy/IPolicy.cs


/// batching of requests.
/// </summary>
/// <param name="info">Agent info.</param>
/// <param name="sensors">The list of ISensors of the Agent.</param>
void RequestDecision(AgentInfo info, List<ISensor> sensors, Action<AgentAction> action);
void RequestDecision(AgentInfo info, List<ISensor> sensors);
/// <summary>
/// Signals the Policy that if the Decision has not been taken yet,

void DecideAction();
float[] DecideAction();
}
}

17
com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs


/// </summary>
public class RemotePolicy : IPolicy
{
int m_AgentId;
protected ICommunicator m_Communicator;
/// <summary>
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
/// </summary>
List<int[]> m_SensorShapes;
protected ICommunicator m_Communicator;
/// <inheritdoc />
public RemotePolicy(

}
/// <inheritdoc />
public void RequestDecision(AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors, action);
m_AgentId = info.episodeId;
m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors);
public void DecideAction()
public float[] DecideAction()
return m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
}
public void Dispose()

45
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorApplier.cs


var applier = new ContinuousActionOutputApplier();
var action0 = new AgentAction();
var action1 = new AgentAction();
var callbacks = new List<AgentIdActionPair>()
{
new AgentIdActionPair {agentId = 0, action = (a) => action0 = a},
new AgentIdActionPair {agentId = 1, action = (a) => action1 = a}
};
var agentIds = new List<int>() { 0, 1 };
// Dictionary from AgentId to Action
var actionDict = new Dictionary<int, float[]>() { { 0, null }, { 1, null } };
applier.Apply(inputTensor, callbacks);
applier.Apply(inputTensor, agentIds, actionDict);
Assert.AreEqual(action0.vectorActions[0], 1);
Assert.AreEqual(action0.vectorActions[1], 2);
Assert.AreEqual(action0.vectorActions[2], 3);
Assert.AreEqual(actionDict[0][0], 1);
Assert.AreEqual(actionDict[0][1], 2);
Assert.AreEqual(actionDict[0][2], 3);
Assert.AreEqual(action1.vectorActions[0], 4);
Assert.AreEqual(action1.vectorActions[1], 5);
Assert.AreEqual(action1.vectorActions[2], 6);
Assert.AreEqual(actionDict[1][0], 4);
Assert.AreEqual(actionDict[1][1], 5);
Assert.AreEqual(actionDict[1][2], 6);
}
[Test]

var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
var action0 = new AgentAction();
var action1 = new AgentAction();
var callbacks = new List<AgentIdActionPair>()
{
new AgentIdActionPair {agentId = 0, action = (a) => action0 = a},
new AgentIdActionPair {agentId = 1, action = (a) => action1 = a}
};
var agentIds = new List<int>() { 0, 1 };
// Dictionary from AgentId to Action
var actionDict = new Dictionary<int, float[]>() { { 0, null }, { 1, null } };
applier.Apply(inputTensor, callbacks);
applier.Apply(inputTensor, agentIds, actionDict);
Assert.AreEqual(action0.vectorActions[0], 1);
Assert.AreEqual(action0.vectorActions[1], 1);
Assert.AreEqual(actionDict[0][0], 1);
Assert.AreEqual(actionDict[0][1], 1);
Assert.AreEqual(action1.vectorActions[0], 1);
Assert.AreEqual(action1.vectorActions[1], 2);
Assert.AreEqual(actionDict[1][0], 1);
Assert.AreEqual(actionDict[1][1], 2);
alloc.Dispose();
}
}

20
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


using NUnit.Framework;
using System.Reflection;
using MLAgents.Sensor;
using System.Collections.Generic;
/// <summary>
/// Minimal <see cref="IPolicy"/> implementation for tests: it ignores decision requests
/// and always hands back a zero-length action array.
/// </summary>
public class TestPolicy : IPolicy
{
    /// <summary>Accepts a decision request and does nothing with it.</summary>
    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
    }

    /// <summary>Returns a newly allocated empty action array on every call.</summary>
    public float[] DecideAction()
    {
        var emptyAction = new float[0];
        return emptyAction;
    }

    /// <summary>No resources are held, so there is nothing to release.</summary>
    public void Dispose()
    {
    }
}
public class TestAgent : Agent
{
public AgentInfo _Info

{
typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, value);
}
}
/// <summary>
/// Overwrites the non-public <c>m_Brain</c> instance field declared on <c>Agent</c>
/// with the supplied policy, using reflection.
/// </summary>
/// <param name="policy">The policy instance to store in <c>m_Brain</c>.</param>
public void SetPolicy(IPolicy policy)
{
    const BindingFlags privateInstance = BindingFlags.NonPublic | BindingFlags.Instance;
    var brainField = typeof(Agent).GetField("m_Brain", privateInstance);
    brainField.SetValue(this, policy);
}
public int initializeAgentCalls;

return sensorName;
}
public void Update() {}
public void Update() { }
}
[TestFixture]

decisionRequester.DecisionPeriod = 2;
decisionRequester.Awake();
agent2.SetPolicy(new TestPolicy());
var j = 0;
for (var i = 0; i < 500; i++)

正在加载...
取消
保存