
Develop remove memories (#2795)

* Initial commit removing memories from C# and deprecating memory fields in proto

* initial changes to Python

* Adding functionalities

* Fixes

* adding the memories to the dictionary

* Fixing bugs

* Tweaks

* Resolving bugs

* Recreating the proto

* Addressing comments

* Passing by reference does not work. Do not merge

* Fixing huge bug in Inference

* Applying patches

* fixing tests

* Addressing comments

* Renaming variable to reflect type

* test
/develop-gpu-test
GitHub · 5 years ago
Current commit: 0fe5adc2
34 changed files with 181 additions and 446 deletions
1. UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs (11 changes)
2. UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs (51 changes)
3. UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs (25 changes)
4. UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (31 changes)
5. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs (33 changes)
6. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs (41 changes)
7. UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs (7 changes)
8. UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs (22 changes)
9. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (44 changes)
10. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (78 changes)
11. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs (11 changes)
12. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (11 changes)
13. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (14 changes)
14. UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs (9 changes)
15. UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs (8 changes)
16. ml-agents-envs/mlagents/envs/action_info.py (1 change)
17. ml-agents-envs/mlagents/envs/brain.py (42 changes)
18. ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py (15 changes)
19. ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi (6 changes)
20. ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py (25 changes)
21. ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi (6 changes)
22. ml-agents-envs/mlagents/envs/environment.py (34 changes)
23. ml-agents-envs/mlagents/envs/mock_communicator.py (1 change)
24. ml-agents-envs/mlagents/envs/simple_env_manager.py (4 changes)
25. ml-agents-envs/mlagents/envs/subprocess_env_manager.py (4 changes)
26. ml-agents/mlagents/trainers/bc/policy.py (4 changes)
27. ml-agents/mlagents/trainers/ppo/policy.py (8 changes)
28. ml-agents/mlagents/trainers/rl_trainer.py (17 changes)
29. ml-agents/mlagents/trainers/sac/policy.py (4 changes)
30. ml-agents/mlagents/trainers/tests/test_policy.py (21 changes)
31. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (1 change)
32. ml-agents/mlagents/trainers/tf_policy.py (34 changes)
33. protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action.proto (2 changes)
34. protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info.proto (2 changes)
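
For orientation before the per-file diffs: the change removes the per-agent memories fields (AgentInfo.memories, AgentAction.memories, and the matching proto fields, which are deprecated) and replaces them with a single Dictionary<int, List<float>> keyed by agent id, owned by ModelRunner and handed to the tensor generators and appliers. Below is a minimal illustrative sketch of that pattern, not code from the PR; names follow the diff, but everything around them is condensed.

// Illustrative sketch only: a condensed view of the memory store that the
// ModelRunner / ApplierImpl / GeneratorImpl hunks below introduce.
using System.Collections.Generic;

public class AgentMemoryStoreSketch
{
    // Recurrent state per agent, keyed by agent id. This replaces the
    // List<float> memories fields that used to live on AgentInfo and AgentAction.
    private readonly Dictionary<int, List<float>> m_Memories =
        new Dictionary<int, List<float>>();

    // Memory output appliers write the model's new state here after inference.
    public void Write(int agentId, List<float> memory)
    {
        m_Memories[agentId] = memory;
    }

    // Recurrent input generators read it back before the next inference step;
    // unseen agents start from a zero-filled memory.
    public List<float> Read(int agentId, int memorySize)
    {
        List<float> memory;
        if (!m_Memories.TryGetValue(agentId, out memory))
        {
            memory = new List<float>(new float[memorySize]);
        }
        return memory;
    }
}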

UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs (11 changes)


{
vectorObservationSize = 3,
numStackedVectorObservations = 2,
vectorActionDescriptions = new[] {"TestActionA", "TestActionB"},
vectorActionSize = new[] {2, 2},
vectorActionDescriptions = new[] { "TestActionA", "TestActionB" },
vectorActionSize = new[] { 2, 2 },
vectorActionSpaceType = SpaceType.Discrete
};

var agentInfo = new AgentInfo
{
reward = 1f,
actionMasks = new[] {false, true},
actionMasks = new[] { false, true },
memories = new List<float>(),
stackedVectorObservation = new List<float>() {1f, 1f, 1f},
stackedVectorObservation = new List<float>() { 1f, 1f, 1f },
storedVectorActions = new[] {0f, 1f},
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
};

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs (51 changes)


{
public AgentAction GetAction()
{
var f = typeof(Agent).GetField(
var f = typeof(Agent).GetField(
"m_Action", BindingFlags.Instance | BindingFlags.NonPublic);
return (AgentAction)f.GetValue(this);
}

var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();
return new List<Agent> {agentA, agentB};
return new List<Agent> { agentA, agentB };
}
[Test]

var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorApplier(bp, 0, alloc);
var mem = new Dictionary<int, List<float>>();
var tensorGenerator = new TensorApplier(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}

{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 3},
data = new Tensor(2, 3, new float[] {1, 2, 3, 4, 5, 6})
shape = new long[] { 2, 3 },
data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var agentInfos = GetFakeAgentInfos();

{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
shape = new long[] { 2, 5 },
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
var applier = new DiscreteActionOutputApplier(new[] {2, 3}, 0, alloc);
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;

}
[Test]
public void ApplyMemoryOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
};
var agentInfos = GetFakeAgentInfos();
var applier = new MemoryOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;
var agent = agents[0] as TestAgent;
Assert.NotNull(agent);
var action = agent.GetAction();
Assert.AreEqual(action.memories[0], 0.5f);
Assert.AreEqual(action.memories[1], 22.5f);
agent = agents[1] as TestAgent;
Assert.NotNull(agent);
action = agent.GetAction();
Assert.AreEqual(action.memories[2], 6);
Assert.AreEqual(action.memories[3], 7);
}
[Test]
shape = new long[] {2, 1},
data = new Tensor(2, 1, new[] {0.5f, 8f})
shape = new long[] { 2, 1 },
data = new Tensor(2, 1, new[] { 0.5f, 8f })
};
var agentInfos = GetFakeAgentInfos();

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs (25 changes)


var infoA = new AgentInfo
{
stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
memories = null,
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};

{
stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
memories = new[] { 1f, 1f, 1f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};

{
var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorGenerator(bp, 0, alloc);
var mem = new Dictionary<int, List<float>>();
var tensorGenerator = new TensorGenerator(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}

Assert.AreEqual(inputTensor.data[0, 2], 3);
Assert.AreEqual(inputTensor.data[1, 0], 4);
Assert.AreEqual(inputTensor.data[1, 2], 6);
alloc.Dispose();
}
[Test]
public void GenerateRecurrentInput()
{
var inputTensor = new TensorProxy
{
shape = new long[] { 2, 5 }
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var generator = new RecurrentInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 0);
Assert.AreEqual(inputTensor.data[0, 4], 0);
Assert.AreEqual(inputTensor.data[1, 0], 1);
Assert.AreEqual(inputTensor.data[1, 4], 0);
alloc.Dispose();
}

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (31 changes)


public bool[] actionMasks;
/// <summary>
/// Used by the Trainer to store information about the agent. This data
/// structure is not consumed or modified by the agent directly, they are
/// just the owners of their trainer's memory. Currently, however, the
/// size of the memory is in the Brain properties.
/// </summary>
public List<float> memories;
/// <summary>
/// Current agent reward.
/// </summary>
public float reward;

{
public float[] vectorActions;
public string textActions;
public List<float> memories;
public float value;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;

if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";
m_Info.memories = new List<float>();
m_Action.memories = new List<float>();
m_Info.vectorObservation =
new List<float>(param.vectorObservationSize);
m_Info.stackedVectorObservation =

return;
}
m_Info.memories = m_Action.memories;
m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.vectorObservation.Clear();

public void UpdateVectorAction(float[] vectorActions)
{
m_Action.vectorActions = vectorActions;
}
/// <summary>
/// Updates the memories action.
/// </summary>
/// <param name="memories">Memories.</param>
public void UpdateMemoriesAction(List<float> memories)
{
m_Action.memories = memories;
}
public void AppendMemoriesAction(List<float> memories)
{
m_Action.memories.AddRange(memories);
}
public List<float> GetMemoriesAction()
{
return m_Action.memories;
}
/// <summary>
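
The Agent.cs hunk above removes memories from AgentInfo and AgentAction together with the on-agent accessors (UpdateMemoriesAction, AppendMemoriesAction, GetMemoriesAction). As a hedged sketch of the migration for code that used those accessors, with the helper below being hypothetical rather than part of the PR:

// Hypothetical helper, for illustration only.
using System.Collections.Generic;

public static class MemoryMigrationSketch
{
    // Old pattern (removed): agent.GetMemoriesAction() / agent.UpdateMemoriesAction(next).
    // New pattern: keep the state beside the model, keyed by the agent's id.
    public static List<float> Swap(
        Dictionary<int, List<float>> memories, int agentId, List<float> next)
    {
        List<float> previous;
        memories.TryGetValue(agentId, out previous);  // last step's state, or null if unseen
        memories[agentId] = next;                     // store this step's state
        return previous;
    }
}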

UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs (33 changes)


string.Concat(
"CjVtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaNm1sYWdlbnRzL2Vu",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKh",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKV",
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSEAoIbWVtb3JpZXMYAyADKAISDQoFdmFs",
"dWUYBCABKAISPgoNY3VzdG9tX2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRv",
"cl9vYmplY3RzLkN1c3RvbUFjdGlvblByb3RvQh+qAhxNTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSDQoFdmFsdWUYBCABKAISPgoNY3VzdG9t",
"X2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRvcl9vYmplY3RzLkN1c3RvbUFj",
"dGlvblByb3RvSgQIAxAEQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
"Y3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories", "Value", "CustomAction" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Value", "CustomAction" }, null, null, null)
}));
}
#endregion

public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
textActions_ = other.textActions_;
memories_ = other.memories_.Clone();
value_ = other.value_;
CustomAction = other.customAction_ != null ? other.CustomAction.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);

}
}
/// <summary>Field number for the "memories" field.</summary>
public const int MemoriesFieldNumber = 3;
private static readonly pb::FieldCodec<float> _repeated_memories_codec
= pb::FieldCodec.ForFloat(26);
private readonly pbc::RepeatedField<float> memories_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> Memories {
get { return memories_; }
}
/// <summary>Field number for the "value" field.</summary>
public const int ValueFieldNumber = 4;
private float value_;

}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
if (TextActions != other.TextActions) return false;
if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false;
if (!object.Equals(CustomAction, other.CustomAction)) return false;
return Equals(_unknownFields, other._unknownFields);

int hash = 1;
hash ^= vectorActions_.GetHashCode();
if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
hash ^= memories_.GetHashCode();
if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value);
if (customAction_ != null) hash ^= CustomAction.GetHashCode();
if (_unknownFields != null) {

output.WriteRawTag(18);
output.WriteString(TextActions);
}
memories_.WriteTo(output, _repeated_memories_codec);
if (Value != 0F) {
output.WriteRawTag(37);
output.WriteFloat(Value);

if (TextActions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
}
size += memories_.CalculateSize(_repeated_memories_codec);
if (Value != 0F) {
size += 1 + 4;
}

if (other.TextActions.Length != 0) {
TextActions = other.TextActions;
}
memories_.Add(other.memories_);
if (other.Value != 0F) {
Value = other.Value;
}

}
case 18: {
TextActions = input.ReadString();
break;
}
case 26:
case 29: {
memories_.AddEntriesFrom(input, _repeated_memories_codec);
break;
}
case 37: {

UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs (41 changes)


"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGj9tbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL2NvbXByZXNzZWRfb2JzZXJ2YXRpb24u",
"cHJvdG8aO21sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3Vz",
"dG9tX29ic2VydmF0aW9uLnByb3RvIpgDCg5BZ2VudEluZm9Qcm90bxIiChpz",
"dG9tX29ic2VydmF0aW9uLnByb3RvIowDCg5BZ2VudEluZm9Qcm90bxIiChpz",
"ChNzdG9yZWRfdGV4dF9hY3Rpb25zGAUgASgJEhAKCG1lbW9yaWVzGAYgAygC",
"Eg4KBnJld2FyZBgHIAEoAhIMCgRkb25lGAggASgIEhgKEG1heF9zdGVwX3Jl",
"YWNoZWQYCSABKAgSCgoCaWQYCiABKAUSEwoLYWN0aW9uX21hc2sYCyADKAgS",
"SAoSY3VzdG9tX29ic2VydmF0aW9uGAwgASgLMiwuY29tbXVuaWNhdG9yX29i",
"amVjdHMuQ3VzdG9tT2JzZXJ2YXRpb25Qcm90bxJRChdjb21wcmVzc2VkX29i",
"c2VydmF0aW9ucxgNIAMoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLkNvbXBy",
"ZXNzZWRPYnNlcnZhdGlvblByb3RvSgQIAhADQh+qAhxNTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"ChNzdG9yZWRfdGV4dF9hY3Rpb25zGAUgASgJEg4KBnJld2FyZBgHIAEoAhIM",
"CgRkb25lGAggASgIEhgKEG1heF9zdGVwX3JlYWNoZWQYCSABKAgSCgoCaWQY",
"CiABKAUSEwoLYWN0aW9uX21hc2sYCyADKAgSSAoSY3VzdG9tX29ic2VydmF0",
"aW9uGAwgASgLMiwuY29tbXVuaWNhdG9yX29iamVjdHMuQ3VzdG9tT2JzZXJ2",
"YXRpb25Qcm90bxJRChdjb21wcmVzc2VkX29ic2VydmF0aW9ucxgNIAMoCzIw",
"LmNvbW11bmljYXRvcl9vYmplY3RzLkNvbXByZXNzZWRPYnNlcnZhdGlvblBy",
"b3RvSgQIAhADSgQIBhAHQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
"Y3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "TextObservation", "StoredVectorActions", "StoredTextActions", "Memories", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CustomObservation", "CompressedObservations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "TextObservation", "StoredVectorActions", "StoredTextActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CustomObservation", "CompressedObservations" }, null, null, null)
}));
}
#endregion

textObservation_ = other.textObservation_;
storedVectorActions_ = other.storedVectorActions_.Clone();
storedTextActions_ = other.storedTextActions_;
memories_ = other.memories_.Clone();
reward_ = other.reward_;
done_ = other.done_;
maxStepReached_ = other.maxStepReached_;

}
}
/// <summary>Field number for the "memories" field.</summary>
public const int MemoriesFieldNumber = 6;
private static readonly pb::FieldCodec<float> _repeated_memories_codec
= pb::FieldCodec.ForFloat(50);
private readonly pbc::RepeatedField<float> memories_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> Memories {
get { return memories_; }
}
/// <summary>Field number for the "reward" field.</summary>
public const int RewardFieldNumber = 7;
private float reward_;

if (TextObservation != other.TextObservation) return false;
if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false;
if (StoredTextActions != other.StoredTextActions) return false;
if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false;
if (Done != other.Done) return false;
if (MaxStepReached != other.MaxStepReached) return false;

if (TextObservation.Length != 0) hash ^= TextObservation.GetHashCode();
hash ^= storedVectorActions_.GetHashCode();
if (StoredTextActions.Length != 0) hash ^= StoredTextActions.GetHashCode();
hash ^= memories_.GetHashCode();
if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward);
if (Done != false) hash ^= Done.GetHashCode();
if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode();

output.WriteRawTag(42);
output.WriteString(StoredTextActions);
}
memories_.WriteTo(output, _repeated_memories_codec);
if (Reward != 0F) {
output.WriteRawTag(61);
output.WriteFloat(Reward);

if (StoredTextActions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(StoredTextActions);
}
size += memories_.CalculateSize(_repeated_memories_codec);
if (Reward != 0F) {
size += 1 + 4;
}

if (other.StoredTextActions.Length != 0) {
StoredTextActions = other.StoredTextActions;
}
memories_.Add(other.memories_);
if (other.Reward != 0F) {
Reward = other.Reward;
}

}
case 42: {
StoredTextActions = input.ReadString();
break;
}
case 50:
case 53: {
memories_.AddEntriesFrom(input, _repeated_memories_codec);
break;
}
case 61: {

UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs (7 changes)


Id = ai.id,
CustomObservation = ai.customObservation
};
if (ai.memories != null)
{
agentInfoProto.Memories.Add(ai.memories);
}
if (ai.actionMasks != null)
{

{
vectorActions = aap.VectorActions.ToArray(),
textActions = aap.TextActions,
memories = aap.Memories.ToList(),
value = aap.Value,
customAction = aap.CustomAction
};

var obsProto = new CompressedObservationProto
{
Data = ByteString.CopyFrom(obs.Data),
CompressionType = (CompressionTypeProto) obs.CompressionType,
CompressionType = (CompressionTypeProto)obs.CompressionType,
};
obsProto.Shape.AddRange(obs.Shape);
return obsProto;

UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs (22 changes)


UnityOutput and UnityInput can be extended to provide functionalities beyond RL
UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities
*/
public interface ICommunicator : IBatchedDecisionMaker
public interface ICommunicator
{
/// <summary>
/// Quit was received by the communicator.

void SubscribeBrain(string name, BrainParameters brainParameters);
/// <summary>
/// Sends the observations of one Agent.
/// </summary>
/// <param name="key">Batch Key.</param>
/// <param name="agents">Agent info.</param>
void PutObservations(string brainKey, Agent agent);
/// <summary>
/// Signals the ICommunicator that the Agents are now ready to receive their action
/// and that if the communicator has not yet received an action for one of the Agents
/// it needs to get one at this point.
/// </summary>
void DecideBatch();
/// <summary>
}
public interface IBatchedDecisionMaker : IDisposable
{
void PutObservations(string key, Agent agent);
void DecideBatch();
}
}
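
Condensing the ICommunicator hunk above: IBatchedDecisionMaker is deleted, ICommunicator keeps the batching methods itself, and (as the ModelRunner hunk further down shows) ModelRunner exposes the same pair without the brain key. A sketch of the resulting shapes, with Agent and BrainParameters as stand-in types and the bodies reduced to comments:

// Shape sketch only; stand-in types, no real implementation.
public class Agent { }
public class BrainParameters { }

public interface ICommunicatorSketch
{
    void SubscribeBrain(string name, BrainParameters brainParameters);
    void PutObservations(string brainKey, Agent agent);  // per-agent observations for a behavior
    void DecideBatch();                                   // flush the batch and fetch actions
}

public class ModelRunnerSketch
{
    public void PutObservations(Agent agent) { /* queue the agent for the next batch */ }
    public void DecideBatch() { /* run inference and apply the outputs */ }
}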

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (44 changes)


var actionProbs = new TensorProxy()
{
valueType = TensorProxy.TensorType.FloatingPoint,
shape = new long[] {batchSize, nBranchAction},
shape = new long[] { batchSize, nBranchAction },
data = m_Allocator.Alloc(new TensorShape(batchSize, nBranchAction))
};

var outputTensor = new TensorProxy()
{
valueType = TensorProxy.TensorType.FloatingPoint,
shape = new long[] {batchSize, 1},
shape = new long[] { batchSize, 1 },
data = m_Allocator.Alloc(new TensorShape(batchSize, 1))
};

private readonly int m_MemoriesCount;
private readonly int m_MemoryIndex;
public BarracudaMemoryOutputApplier(int memoriesCount, int memoryIndex)
private Dictionary<int, List<float>> m_Memories;
public BarracudaMemoryOutputApplier(
int memoriesCount,
int memoryIndex,
Dictionary<int, List<float>> memories)
m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)

foreach (var agent in agents)
{
var memory = agent.GetMemoriesAction();
if (memory == null || memory.Count < memorySize * m_MemoriesCount)
List<float> memory = null;
if (!m_Memories.TryGetValue(agent.Info.id, out memory)
|| memory.Count < memorySize * m_MemoriesCount)
{
memory = new List<float>();
memory.AddRange(Enumerable.Repeat(0f, memorySize * m_MemoriesCount));

memory[memorySize * m_MemoryIndex + j] = tensorProxy.data[agentIndex, j];
}
agent.UpdateMemoriesAction(memory);
m_Memories[agent.Info.id] = memory;
/// <summary>
/// The Applier for the Memory output tensor. Tensor is assumed to contain the new
/// memory data of the agents in the batch.
/// </summary>
public class MemoryOutputApplier : TensorApplier.IApplier
{
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
{
var agentIndex = 0;
var memorySize = tensorProxy.shape[tensorProxy.shape.Length - 1];
foreach (var agent in agents)
{
var memory = new List<float>();
for (var j = 0; j < memorySize; j++)
{
memory.Add(tensorProxy.data[agentIndex, j]);
}
agent.UpdateMemoriesAction(memory);
agentIndex++;
}
}
}
/// <summary>
/// The Applier for the Value Estimate output tensor. Tensor is assumed to contain the
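
The BarracudaMemoryOutputApplier hunk above is interleaved with the version it removes; read together, the new Apply writes each agent's slice of the LSTM output into the shared dictionary instead of onto the agent. A hedged reconstruction for readability, with float[,] tensorData and the agentIds list standing in for TensorProxy.data and the Agent enumeration:

// Reconstruction for readability only; the authoritative code is the diff above.
using System.Collections.Generic;
using System.Linq;

public class BarracudaMemoryOutputApplierSketch
{
    private readonly int m_MemoriesCount;   // number of memory tensors the model exposes
    private readonly int m_MemoryIndex;     // which of those tensors this applier handles
    private readonly Dictionary<int, List<float>> m_Memories;

    public BarracudaMemoryOutputApplierSketch(
        int memoriesCount, int memoryIndex, Dictionary<int, List<float>> memories)
    {
        m_MemoriesCount = memoriesCount;
        m_MemoryIndex = memoryIndex;
        m_Memories = memories;
    }

    public void Apply(float[,] tensorData, IList<int> agentIds)
    {
        int memorySize = tensorData.GetLength(1);
        for (int agentIndex = 0; agentIndex < agentIds.Count; agentIndex++)
        {
            List<float> memory;
            if (!m_Memories.TryGetValue(agentIds[agentIndex], out memory)
                || memory.Count < memorySize * m_MemoriesCount)
            {
                // First time this agent is seen: start from a zero-filled memory.
                memory = Enumerable.Repeat(0f, memorySize * m_MemoriesCount).ToList();
            }
            for (int j = 0; j < memorySize; j++)
            {
                memory[memorySize * m_MemoryIndex + j] = tensorData[agentIndex, j];
            }
            m_Memories[agentIds[agentIndex]] = memory;
        }
    }
}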

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (78 changes)


m_Allocator = allocator;
}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var vecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];

}
}
/// <summary>
/// Generates the Tensor corresponding to the Recurrent input : Will be a two
/// dimensional float array of dimension [batchSize x memorySize].
/// It will use the Memory data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
public class RecurrentInputGenerator : TensorGenerator.IGenerator
{
private readonly ITensorAllocator m_Allocator;
public RecurrentInputGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var memorySize = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;
foreach (var agent in agents)
{
var info = agent.Info;
var memory = info.memories;
if (memory == null)
{
agentIndex++;
continue;
}
for (var j = 0; j < Math.Min(memorySize, memory.Count); j++)
{
if (j >= memory.Count)
{
break;
}
tensorProxy.data[agentIndex, j] = memory[j];
}
agentIndex++;
}
}
}
public class BarracudaRecurrentInputGenerator : TensorGenerator.IGenerator
{

public BarracudaRecurrentInputGenerator(int memoryIndex, ITensorAllocator allocator)
private Dictionary<int, List<float>> m_Memories;
public BarracudaRecurrentInputGenerator(
int memoryIndex,
ITensorAllocator allocator,
Dictionary<int, List<float>> memories)
m_Memories = memories;
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

{
var agentInfo = agent.Info;
var memory = agentInfo.memories;
if (memory == null)
List<float> memory = null;
if (!m_Memories.TryGetValue(agent.Info.id, out memory))
{
agentIndex++;
continue;

{
break;
}
tensorProxy.data[agentIndex, j] = memory[j + offset];
}
agentIndex++;

m_Allocator = allocator;
}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

m_Allocator = allocator;
}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

m_Allocator = allocator;
}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal);

m_Allocator = allocator;
}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var agentIndex = 0;
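
Mirroring the applier, the new BarracudaRecurrentInputGenerator reads each agent's stored memory back out of the same dictionary when filling the recurrent input tensor; agents without an entry are skipped. A hedged reconstruction with the tensor and agent types simplified in the same way as the applier sketch above:

// Reconstruction for readability only; see the diff above for the real implementation.
using System;
using System.Collections.Generic;

public class BarracudaRecurrentInputGeneratorSketch
{
    private readonly int m_MemoryIndex;
    private readonly Dictionary<int, List<float>> m_Memories;

    public BarracudaRecurrentInputGeneratorSketch(
        int memoryIndex, Dictionary<int, List<float>> memories)
    {
        m_MemoryIndex = memoryIndex;
        m_Memories = memories;
    }

    public void Generate(float[,] tensorData, IList<int> agentIds)
    {
        int memorySize = tensorData.GetLength(1);
        int offset = memorySize * m_MemoryIndex;  // this generator's slice of the flat memory
        for (int agentIndex = 0; agentIndex < agentIds.Count; agentIndex++)
        {
            List<float> memory;
            if (!m_Memories.TryGetValue(agentIds[agentIndex], out memory))
            {
                continue;  // no stored state for this agent yet: leave its row untouched
            }
            for (int j = 0; j < Math.Min(memorySize, memory.Count - offset); j++)
            {
                tensorData[agentIndex, j] = memory[j + offset];
            }
        }
    }
}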

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs (11 changes)


namespace MLAgents.InferenceBrain
{
public class ModelRunner : IBatchedDecisionMaker
public class ModelRunner
{
private List<Agent> m_Agents = new List<Agent>();
private ITensorAllocator m_TensorAllocator;

private string[] m_OutputNames;
private IReadOnlyList<TensorProxy> m_InferenceInputs;
private IReadOnlyList<TensorProxy> m_InferenceOutputs;
private Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();
private bool m_visualObservationsInitialized = false;

m_InferenceInputs = BarracudaModelParamLoader.GetInputTensors(barracudaModel);
m_OutputNames = BarracudaModelParamLoader.GetOutputNames(barracudaModel);
m_TensorGenerator = new TensorGenerator(brainParameters, seed, m_TensorAllocator, barracudaModel);
m_TensorApplier = new TensorApplier(brainParameters, seed, m_TensorAllocator, barracudaModel);
m_TensorGenerator = new TensorGenerator(
brainParameters, seed, m_TensorAllocator, m_Memories, barracudaModel);
m_TensorApplier = new TensorApplier(
brainParameters, seed, m_TensorAllocator, m_Memories, barracudaModel);
}
private static Dictionary<string, Tensor> PrepareBarracudaInputs(IEnumerable<TensorProxy> infInputs)

return outputs;
}
public void PutObservations(string key, Agent agent)
public void PutObservations(Agent agent)
{
m_Agents.Add(agent);
}

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (11 changes)


/// <param name="allocator"> Tensor allocator</param>
/// <param name="barracudaModel"></param>
public TensorApplier(
BrainParameters bp, int seed, ITensorAllocator allocator, object barracudaModel = null)
BrainParameters bp,
int seed,
ITensorAllocator allocator,
Dictionary<int, List<float>> memories,
object barracudaModel = null)
{
m_Dict[TensorNames.ValueEstimateOutput] = new ValueEstimateApplier();
if (bp.vectorActionSpaceType == SpaceType.Continuous)

m_Dict[TensorNames.ActionOutput] =
new DiscreteActionOutputApplier(bp.vectorActionSize, seed, allocator);
}
m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier();
if (barracudaModel != null)
{

{
m_Dict[model.memories[i].output] =
new BarracudaMemoryOutputApplier(model.memories.Length, i);
new BarracudaMemoryOutputApplier(model.memories.Length, i, memories);
}
}
}

/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated applier.</exception>
public void ApplyTensors(
IEnumerable<TensorProxy> tensors, IEnumerable<Agent> agents)
IEnumerable<TensorProxy> tensors, IEnumerable<Agent> agents)
{
foreach (var tensor in tensors)
{

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (14 changes)


/// <param name="allocator"> Tensor allocator</param>
/// <param name="barracudaModel"></param>
public TensorGenerator(
BrainParameters bp, int seed, ITensorAllocator allocator, object barracudaModel = null)
BrainParameters bp,
int seed,
ITensorAllocator allocator,
Dictionary<int, List<float>> memories,
object barracudaModel = null)
{
// Generator for Inputs
m_Dict[TensorNames.BatchSizePlaceholder] =

m_Dict[TensorNames.VectorObservationPlacholder] =
new VectorObservationGenerator(allocator);
m_Dict[TensorNames.RecurrentInPlaceholder] =
new RecurrentInputGenerator(allocator);
if (barracudaModel != null)
{

m_Dict[model.memories[i].input] =
new BarracudaRecurrentInputGenerator(i, allocator);
new BarracudaRecurrentInputGenerator(i, allocator, memories);
}
}

/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated generator.</exception>
public void GenerateTensors(
IEnumerable<TensorProxy> tensors,
int currentBatchSize,
IEnumerable<Agent> agents)
IEnumerable<TensorProxy> tensors, int currentBatchSize, IEnumerable<Agent> agents)
{
foreach (var tensor in tensors)
{

UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs (9 changes)


using UnityEngine;
using Barracuda;
using System.Collections.Generic;
using MLAgents.InferenceBrain;
namespace MLAgents
{

public class BarracudaPolicy : IPolicy
{
protected IBatchedDecisionMaker m_BatchedDecisionMaker;
protected ModelRunner m_ModelRunner;
/// <summary>
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their sensors.

var aca = GameObject.FindObjectOfType<Academy>();
aca.LazyInitialization();
var modelRunner = aca.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
m_BatchedDecisionMaker = modelRunner;
m_ModelRunner = modelRunner;
}
/// <inheritdoc />

ValidateAgentSensorShapes(agent);
#endif
m_BatchedDecisionMaker?.PutObservations(null, agent);
m_ModelRunner?.PutObservations(agent);
m_BatchedDecisionMaker?.DecideBatch();
m_ModelRunner?.DecideBatch();
}
/// <summary>

UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs (8 changes)


{
private string m_BehaviorName;
protected IBatchedDecisionMaker m_BatchedDecisionMaker;
protected ICommunicator m_Communicator;
/// <summary>
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their sensors.

m_BehaviorName = behaviorName;
var aca = GameObject.FindObjectOfType<Academy>();
aca.LazyInitialization();
m_BatchedDecisionMaker = aca.Communicator;
m_Communicator = aca.Communicator;
aca.Communicator.SubscribeBrain(m_BehaviorName, brainParameters);
}

#if DEBUG
ValidateAgentSensorShapes(agent);
#endif
m_BatchedDecisionMaker?.PutObservations(m_BehaviorName, agent);
m_Communicator?.PutObservations(m_BehaviorName, agent);
m_BatchedDecisionMaker?.DecideBatch();
m_Communicator?.DecideBatch();
}
/// <summary>

ml-agents-envs/mlagents/envs/action_info.py (1 change)


class ActionInfo(NamedTuple):
action: Any
memory: Any
text: Any
value: Any
outputs: ActionInfoOutputs

ml-agents-envs/mlagents/envs/brain.py (42 changes)


visual_observation,
vector_observation,
text_observations,
memory=None,
reward=None,
agents=None,
local_done=None,

self.visual_observations = visual_observation
self.vector_observations = vector_observation
self.text_observations = text_observations
self.memories = memory
self.rewards = reward
self.local_done = local_done
self.max_reached = max_reached

self.action_masks = action_mask
self.custom_observations = custom_observations
def merge(self, other):
for i in range(len(self.visual_observations)):
self.visual_observations[i].extend(other.visual_observations[i])
self.vector_observations = np.append(
self.vector_observations, other.vector_observations, axis=0
)
self.text_observations.extend(other.text_observations)
self.memories = self.merge_memories(
self.memories, other.memories, self.agents, other.agents
)
self.rewards = safe_concat_lists(self.rewards, other.rewards)
self.local_done = safe_concat_lists(self.local_done, other.local_done)
self.max_reached = safe_concat_lists(self.max_reached, other.max_reached)
self.agents = safe_concat_lists(self.agents, other.agents)
self.previous_vector_actions = safe_concat_np_ndarray(
self.previous_vector_actions, other.previous_vector_actions
)
self.previous_text_actions = safe_concat_lists(
self.previous_text_actions, other.previous_text_actions
)
self.action_masks = safe_concat_np_ndarray(
self.action_masks, other.action_masks
)
self.custom_observations = safe_concat_lists(
self.custom_observations, other.custom_observations
)
@staticmethod
def merge_memories(m1, m2, agents1, agents2):
if len(m1) == 0 and len(m2) != 0:

for x in agent_info_list
]
vis_obs += [obs]
if len(agent_info_list) == 0:
memory_size = 0
else:
memory_size = max(len(x.memories) for x in agent_info_list)
if memory_size == 0:
memory = np.zeros((0, 0))
else:
[
x.memories.extend([0] * (memory_size - len(x.memories)))
for x in agent_info_list
]
memory = np.array([list(x.memories) for x in agent_info_list])
total_num_actions = sum(brain_params.vector_action_space_size)
mask_actions = np.ones((len(agent_info_list), total_num_actions))
for agent_index, agent_info in enumerate(agent_info_list):

visual_observation=vis_obs,
vector_observation=vector_obs,
text_observations=[x.text_observation for x in agent_info_list],
memory=memory,
reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list],
agents=agents,
local_done=[x.done for x in agent_info_list],

ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py (15 changes)


name='mlagents/envs/communicator_objects/agent_action.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n5mlagents/envs/communicator_objects/agent_action.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/custom_action.proto\"\xa1\x01\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02\x12>\n\rcustom_action\x18\x05 \x01(\x0b\x32\'.communicator_objects.CustomActionProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n5mlagents/envs/communicator_objects/agent_action.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/custom_action.proto\"\x95\x01\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\r\n\x05value\x18\x04 \x01(\x02\x12>\n\rcustom_action\x18\x05 \x01(\x0b\x32\'.communicator_objects.CustomActionProtoJ\x04\x08\x03\x10\x04\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2.DESCRIPTOR,])

is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2,
number=3, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value', full_name='communicator_objects.AgentActionProto.value', index=3,
name='value', full_name='communicator_objects.AgentActionProto.value', index=2,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,

name='custom_action', full_name='communicator_objects.AgentActionProto.custom_action', index=4,
name='custom_action', full_name='communicator_objects.AgentActionProto.custom_action', index=3,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,

oneofs=[
],
serialized_start=136,
serialized_end=297,
serialized_end=285,
)
_AGENTACTIONPROTO.fields_by_name['custom_action'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2._CUSTOMACTIONPROTO

ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi (6 changes)


DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
vector_actions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
text_actions = ... # type: typing___Text
memories = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
value = ... # type: builtin___float
@property

*,
vector_actions : typing___Optional[typing___Iterable[builtin___float]] = None,
text_actions : typing___Optional[typing___Text] = None,
memories : typing___Optional[typing___Iterable[builtin___float]] = None,
value : typing___Optional[builtin___float] = None,
custom_action : typing___Optional[mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto] = None,
) -> None: ...

def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"custom_action"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",u"memories",u"text_actions",u"value",u"vector_actions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",u"text_actions",u"value",u"vector_actions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",b"custom_action",u"memories",b"memories",u"text_actions",b"text_actions",u"value",b"value",u"vector_actions",b"vector_actions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",b"custom_action",u"text_actions",b"text_actions",u"value",b"value",u"vector_actions",b"vector_actions"]) -> None: ...

ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py (25 changes)


name='mlagents/envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents/envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/compressed_observation.proto\x1a;mlagents/envs/communicator_objects/custom_observation.proto\"\x98\x03\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12H\n\x12\x63ustom_observation\x18\x0c \x01(\x0b\x32,.communicator_objects.CustomObservationProto\x12Q\n\x17\x63ompressed_observations\x18\r \x03(\x0b\x32\x30.communicator_objects.CompressedObservationProtoJ\x04\x08\x02\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents/envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/compressed_observation.proto\x1a;mlagents/envs/communicator_objects/custom_observation.proto\"\x8c\x03\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12H\n\x12\x63ustom_observation\x18\x0c \x01(\x0b\x32,.communicator_objects.CustomObservationProto\x12Q\n\x17\x63ompressed_observations\x18\r \x03(\x0b\x32\x30.communicator_objects.CompressedObservationProtoJ\x04\x08\x02\x10\x03J\x04\x08\x06\x10\x07\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2.DESCRIPTOR,])

is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=4,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=5,
name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=4,
number=7, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,

name='done', full_name='communicator_objects.AgentInfoProto.done', index=6,
name='done', full_name='communicator_objects.AgentInfoProto.done', index=5,
number=8, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,

name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=7,
name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=6,
number=9, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,

name='id', full_name='communicator_objects.AgentInfoProto.id', index=8,
name='id', full_name='communicator_objects.AgentInfoProto.id', index=7,
number=10, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,

name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=9,
name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=8,
number=11, type=8, cpp_type=7, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,

name='custom_observation', full_name='communicator_objects.AgentInfoProto.custom_observation', index=10,
name='custom_observation', full_name='communicator_objects.AgentInfoProto.custom_observation', index=9,
number=12, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,

name='compressed_observations', full_name='communicator_objects.AgentInfoProto.compressed_observations', index=11,
name='compressed_observations', full_name='communicator_objects.AgentInfoProto.compressed_observations', index=10,
number=13, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,

oneofs=[
],
serialized_start=204,
serialized_end=612,
serialized_end=600,
)
_AGENTINFOPROTO.fields_by_name['custom_observation'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2._CUSTOMOBSERVATIONPROTO

ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi (6 changes)


text_observation = ... # type: typing___Text
stored_vector_actions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
stored_text_actions = ... # type: typing___Text
memories = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
reward = ... # type: builtin___float
done = ... # type: builtin___bool
max_step_reached = ... # type: builtin___bool

text_observation : typing___Optional[typing___Text] = None,
stored_vector_actions : typing___Optional[typing___Iterable[builtin___float]] = None,
stored_text_actions : typing___Optional[typing___Text] = None,
memories : typing___Optional[typing___Iterable[builtin___float]] = None,
reward : typing___Optional[builtin___float] = None,
done : typing___Optional[builtin___bool] = None,
max_step_reached : typing___Optional[builtin___bool] = None,

def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"custom_observation"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"compressed_observations",u"custom_observation",u"done",u"id",u"max_step_reached",u"memories",u"reward",u"stacked_vector_observation",u"stored_text_actions",u"stored_vector_actions",u"text_observation"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"compressed_observations",u"custom_observation",u"done",u"id",u"max_step_reached",u"reward",u"stacked_vector_observation",u"stored_text_actions",u"stored_vector_actions",u"text_observation"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"compressed_observations",b"compressed_observations",u"custom_observation",b"custom_observation",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"memories",b"memories",u"reward",b"reward",u"stacked_vector_observation",b"stacked_vector_observation",u"stored_text_actions",b"stored_text_actions",u"stored_vector_actions",b"stored_vector_actions",u"text_observation",b"text_observation"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"compressed_observations",b"compressed_observations",u"custom_observation",b"custom_observation",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"reward",b"reward",u"stacked_vector_observation",b"stacked_vector_observation",u"stored_text_actions",b"stored_text_actions",u"stored_vector_actions",b"stored_vector_actions",u"text_observation",b"text_observation"]) -> None: ...

ml-agents-envs/mlagents/envs/environment.py (34 changes)


def step(
self,
vector_action: Dict[str, np.ndarray] = None,
memory: Optional[Dict[str, np.ndarray]] = None,
text_action: Optional[Dict[str, List[str]]] = None,
value: Optional[Dict[str, np.ndarray]] = None,
custom_action: Dict[str, Any] = None,

if self._is_first_message:
return self.reset()
vector_action = {} if vector_action is None else vector_action
memory = {} if memory is None else memory
text_action = {} if text_action is None else text_action
value = {} if value is None else value
custom_action = {} if custom_action is None else custom_action

"step cannot take a vector_action input"
)
if isinstance(memory, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
memory = {self._external_brain_names[0]: memory}
elif self._num_external_brains > 1:
raise UnityActionException(
"You have {0} brains, you need to feed a dictionary of brain names as keys "
"and memories as values".format(self._num_external_brains)
)
else:
raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a memory input"
)
if isinstance(text_action, self.SINGLE_BRAIN_TEXT_TYPES):
if self._num_external_brains == 1:
text_action = {self._external_brain_names[0]: text_action}

"step cannot take a custom_action input"
)
for brain_name in (
list(vector_action.keys())
+ list(memory.keys())
+ list(text_action.keys())
):
for brain_name in list(vector_action.keys()) + list(text_action.keys()):
if brain_name not in self._external_brain_names:
raise UnityActionException(
"The name {0} does not correspond to an external brain "

)
else:
vector_action[brain_name] = self._flatten(vector_action[brain_name])
if brain_name not in memory:
memory[brain_name] = []
else:
if memory[brain_name] is None:
memory[brain_name] = []
else:
memory[brain_name] = self._flatten(memory[brain_name])
if brain_name not in text_action:
text_action[brain_name] = [""] * n_agent
else:

)
step_input = self._generate_step_input(
vector_action, memory, text_action, value, custom_action
vector_action, text_action, value, custom_action
)
with