
Remove {text,custom} {action,observations} (#2839)

* delete text actions and obs

* delete custom actions and obs

* regenerate protos

* cleanup C#

* format

* fix tests

* fix base env signature

* doc cleanup
/develop-newnormalization
GitHub · 5 years ago
Current commit
ccb7eab4
58 files changed, 81 insertions(+), 1,171 deletions(-)
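At the API level, the most visible effect of this commit is on user-written Agent subclasses: the `AgentAction` override loses its `textAction` parameter (and the `CustomActionProto` overload disappears entirely). Below is a minimal sketch of the change from a user's perspective; the `RollerExample` class and its movement logic are purely illustrative and not part of this diff:

```csharp
using MLAgents;
using UnityEngine;

public class RollerExample : Agent   // hypothetical agent, for illustration only
{
    // Before this commit the override also received a string textAction:
    // public override void AgentAction(float[] vectorAction, string textAction) { ... }

    // After this commit only the vector action remains.
    public override void AgentAction(float[] vectorAction)
    {
        // Interpret the single continuous action as a forward/backward push.
        var move = Mathf.Clamp(vectorAction[0], -1f, 1f);
        transform.Translate(Vector3.forward * move * Time.deltaTime);
        AddReward(-0.001f);   // small per-step penalty, purely illustrative
    }
}
```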
  1. UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs (2)
  2. UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs (2)
  3. UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs (2)
  4. UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs (2)
  5. UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs (2)
  6. UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs (2)
  7. UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs (2)
  8. UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs (2)
  9. UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (2)
  10. UnitySDK/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs (2)
  11. UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs (2)
  12. UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs (2)
  13. UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs (2)
  14. UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs (2)
  15. UnitySDK/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs (2)
  16. UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs (2)
  17. UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs (2)
  18. UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs (2)
  19. UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (75)
  20. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs (77)
  21. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs (114)
  22. UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs (5)
  23. docs/Glossary.md (2)
  24. docs/Learning-Environment-Create-New.md (2)
  25. docs/Python-API.md (8)
  26. docs/Readme.md (2)
  27. gym-unity/gym_unity/envs/__init__.py (4)
  28. gym-unity/gym_unity/tests/test_gym.py (1)
  29. ml-agents-envs/mlagents/envs/action_info.py (1)
  30. ml-agents-envs/mlagents/envs/base_unity_environment.py (6)
  31. ml-agents-envs/mlagents/envs/brain.py (9)
  32. ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py (27)
  33. ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi (17)
  34. ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py (45)
  35. ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi (19)
  36. ml-agents-envs/mlagents/envs/environment.py (78)
  37. ml-agents-envs/mlagents/envs/mock_communicator.py (2)
  38. ml-agents-envs/mlagents/envs/simple_env_manager.py (4)
  39. ml-agents-envs/mlagents/envs/subprocess_env_manager.py (6)
  40. ml-agents/mlagents/trainers/rl_trainer.py (8)
  41. ml-agents/mlagents/trainers/tests/mock_brain.py (2)
  42. ml-agents/mlagents/trainers/tests/test_policy.py (6)
  43. ml-agents/mlagents/trainers/tests/test_simple_rl.py (1)
  44. ml-agents/mlagents/trainers/tf_policy.py (7)
  45. protobuf-definitions/README.md (2)
  46. protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action.proto (6)
  47. protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info.proto (7)
  48. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs.meta (11)
  49. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs.meta (11)
  50. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs (146)
  51. UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs (146)
  52. docs/Creating-Custom-Protobuf-Messages.md (171)
  53. ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.py (64)
  54. ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.pyi (23)
  55. ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.py (64)
  56. ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.pyi (23)
  57. protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_action.proto (7)
  58. protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_observation.proto (7)

UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs (2)


id = 5,
maxStepReached = true,
floatObservations = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
textObservation = "TestAction",
};
demoStore.Record(agentInfo);

UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs (2)


AddVectorObs(0f);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
agentActionCalls += 1;
AddReward(0.1f);

UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs (2)


AddVectorObs(m_BallRb.velocity);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs (2)


AddVectorObs((ball.transform.position - gameObject.transform.position));
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs (2)


AddVectorObs(m_Position, 20);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var movement = (int)vectorAction[0];

UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs (2)


AddVectorObs(target.transform.localPosition);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
for (var i = 0; i < vectorAction.Length; i++)
{

UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs (2)


target.position = newTargetPos + ground.position;
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
if (detectTargets)
{

UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs (2)


gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
MoveAgent(vectorAction);
}

UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (2)


}
// to be implemented by the developer
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
AddReward(-0.01f);
var action = Mathf.FloorToInt(vectorAction[0]);

UnitySDK/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs (2)


m_AgentRb.AddForce(dirToGo * m_Academy.agentRunSpeed, ForceMode.VelocityChange);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
AddReward(-1f / agentParameters.maxStep);
MoveAgent(vectorAction);

UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs (2)


/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
// Move the agent using the action.
MoveAgent(vectorAction);

UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs (2)


m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
AddReward(-1f / agentParameters.maxStep);
MoveAgent(vectorAction);

UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs (2)


/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
m_GoalDegree += m_GoalSpeed;
UpdateGoalPosition();

UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs (2)


ForceMode.VelocityChange);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
// Existential penalty for strikers.
if (agentRole == AgentRole.Striker)

UnitySDK/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs (2)


{
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
}

UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs (2)


AddVectorObs(m_BallRb.velocity.y);
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);

UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs (2)


}
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
m_DirToTarget = target.position - m_JdController.bodyPartsDict[hips].rb.position;

UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs (2)


jumpingTime -= Time.fixedDeltaTime;
}
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
MoveAgent(vectorAction);
if ((!Physics.Raycast(m_AgentRb.position, Vector3.down, 20))

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (75)


public List<float> floatObservations;
/// <summary>
/// Most recent text observation.
/// </summary>
public string textObservation;
/// <summary>
/// <summary>
/// Keeps track of the last text action taken by the Brain.
/// </summary>
public string storedTextActions;
/// <summary>
/// For discrete control, specifies the actions that the agent cannot take. Is true if

/// to separate between different agents in the environment.
/// </summary>
public int id;
/// <summary>
/// User-customizable object for sending structured output from Unity to Python in response
/// to an action in addition to a scalar reward.
/// TODO(cgoy): All references to protobuf objects should be removed.
/// </summary>
public CommunicatorObjects.CustomObservationProto customObservation;
}
/// <summary>

public struct AgentAction
{
public float[] vectorActions;
public string textActions;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;
}
/// <summary>

}
}
if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";
m_Info.customObservation = null;
}
/// <summary>

}
m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.compressedObservations.Clear();
m_ActionMasker.ResetMask();
using (TimerStack.Instance.Scoped("CollectObservations"))

m_Recorder.WriteExperience(m_Info);
}
m_Info.textObservation = "";
}
/// <summary>

}
/// <summary>
/// Collects the (vector, visual, text) observations of the agent.
/// Collects the (vector, visual) observations of the agent.
/// The agent observation describes the current environment from the
/// perspective of the agent.
/// </summary>

/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
/// Recall that an Agent may attach vector, visual or textual observations.
/// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods:
/// - <see cref="AddVectorObs(int)"/>
/// - <see cref="AddVectorObs(float)"/>

/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
/// Lastly, textual observations are added using
/// <see cref="SetTextObs(string)"/>.
/// </remarks>
public virtual void CollectObservations()
{

}
/// <summary>
/// Sets the text observation.
/// </summary>
/// <param name="textObservation">The text observation.</param>
public void SetTextObs(string textObservation)
{
m_Info.textObservation = textObservation;
}
/// <summary>
/// Specifies the agent behavior at every step based on the provided
/// action.
/// </summary>

/// </param>
/// <param name="textAction">Text action.</param>
public virtual void AgentAction(float[] vectorAction, string textAction)
{
}
/// <summary>
/// Specifies the agent behavior at every step based on the provided
/// action.
/// </summary>
/// <param name="vectorAction">
/// Vector action. Note that for discrete actions, the provided array
/// will be of length 1.
/// </param>
/// <param name="textAction">Text action.</param>
/// <param name="customAction">
/// A custom action, defined by the user as custom protobuf message. Useful if the action is hard to encode
/// as either a flat vector or a single string.
/// </param>
public virtual void AgentAction(float[] vectorAction, string textAction, CommunicatorObjects.CustomActionProto customAction)
public virtual void AgentAction(float[] vectorAction)
// We fall back to not using the custom action if the subclassed Agent doesn't override this method.
AgentAction(vectorAction, textAction);
}
/// <summary>

if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
AgentAction(m_Action.vectorActions, m_Action.textActions, m_Action.customAction);
AgentAction(m_Action.vectorActions);
}
if ((m_StepCount >= agentParameters.maxStep)

void DecideAction()
{
m_Brain?.DecideAction();
}
/// <summary>
/// Sets the custom observation for the agent for this episode.
/// </summary>
/// <param name="customObservation">New value of the agent's custom observation.</param>
public void SetCustomObservation(CommunicatorObjects.CustomObservationProto customObservation)
{
m_Info.customObservation = customObservation;
}
}
}
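The updated remarks in Agent.cs above state that observations are now only vector (built up with the `AddVectorObs` helpers) or visual (added implicitly from cameras attached to the Agent), with `SetTextObs` and `SetCustomObservation` removed. A hedged sketch of how an agent's observation code looks under the trimmed API; the `ChaserExample` class and the quantities it observes are invented for illustration:

```csharp
using MLAgents;
using UnityEngine;

public class ChaserExample : Agent   // hypothetical agent, for illustration only
{
    public Transform target;          // assumed to be assigned in the Inspector
    Rigidbody m_Body;

    public override void InitializeAgent()
    {
        m_Body = GetComponent<Rigidbody>();
    }

    public override void CollectObservations()
    {
        // Vector observations are appended one value (or Vector3) at a time.
        AddVectorObs(target.position - transform.position);   // 3 floats
        AddVectorObs(m_Body.velocity);                         // 3 floats
        AddVectorObs(m_Body.velocity.magnitude > 1f);          // 1 float (bool encoded as 0/1)
        // SetTextObs("...") and SetCustomObservation(...) no longer exist after this commit.
    }
}
```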

UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs (77)


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaNm1sYWdlbnRzL2Vu",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKV",
"AQoQQWdlbnRBY3Rpb25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhIU",
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSDQoFdmFsdWUYBCABKAISPgoNY3VzdG9t",
"X2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRvcl9vYmplY3RzLkN1c3RvbUFj",
"dGlvblByb3RvSgQIAxAEQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
"Y3RzYgZwcm90bzM="));
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiSwoQQWdlbnRBY3Rp",
"b25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhINCgV2YWx1ZRgEIAEo",
"AkoECAIQA0oECAMQBEoECAUQBkIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9y",
"T2JqZWN0c2IGcHJvdG8z"));
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CustomActionReflection.Descriptor, },
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Value", "CustomAction" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "Value" }, null, null, null)
}));
}
#endregion

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
textActions_ = other.textActions_;
CustomAction = other.customAction_ != null ? other.CustomAction.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return vectorActions_; }
}
/// <summary>Field number for the "text_actions" field.</summary>
public const int TextActionsFieldNumber = 2;
private string textActions_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string TextActions {
get { return textActions_; }
set {
textActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
/// <summary>Field number for the "value" field.</summary>
public const int ValueFieldNumber = 4;
private float value_;

set {
value_ = value;
}
}
/// <summary>Field number for the "custom_action" field.</summary>
public const int CustomActionFieldNumber = 5;
private global::MLAgents.CommunicatorObjects.CustomActionProto customAction_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::MLAgents.CommunicatorObjects.CustomActionProto CustomAction {
get { return customAction_; }
set {
customAction_ = value;
}
}

return true;
}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
if (TextActions != other.TextActions) return false;
if (!object.Equals(CustomAction, other.CustomAction)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActions_.GetHashCode();
if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
if (customAction_ != null) hash ^= CustomAction.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActions_.WriteTo(output, _repeated_vectorActions_codec);
if (TextActions.Length != 0) {
output.WriteRawTag(18);
output.WriteString(TextActions);
}
if (customAction_ != null) {
output.WriteRawTag(42);
output.WriteMessage(CustomAction);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

public int CalculateSize() {
int size = 0;
size += vectorActions_.CalculateSize(_repeated_vectorActions_codec);
if (TextActions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
}
if (customAction_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(CustomAction);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

return;
}
vectorActions_.Add(other.vectorActions_);
if (other.TextActions.Length != 0) {
TextActions = other.TextActions;
}
}
if (other.customAction_ != null) {
if (customAction_ == null) {
customAction_ = new global::MLAgents.CommunicatorObjects.CustomActionProto();
}
CustomAction.MergeFrom(other.CustomAction);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

vectorActions_.AddEntriesFrom(input, _repeated_vectorActions_codec);
break;
}
case 18: {
TextActions = input.ReadString();
break;
}
break;
}
case 42: {
if (customAction_ == null) {
customAction_ = new global::MLAgents.CommunicatorObjects.CustomActionProto();
}
input.ReadMessage(customAction_);
break;
}
}

UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs (114)


"CjNtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGj9tbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL2NvbXByZXNzZWRfb2JzZXJ2YXRpb24u",
"cHJvdG8aO21sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3Vz",
"dG9tX29ic2VydmF0aW9uLnByb3RvIowDCg5BZ2VudEluZm9Qcm90bxIiChpz",
"dGFja2VkX3ZlY3Rvcl9vYnNlcnZhdGlvbhgBIAMoAhIYChB0ZXh0X29ic2Vy",
"dmF0aW9uGAMgASgJEh0KFXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMoAhIb",
"ChNzdG9yZWRfdGV4dF9hY3Rpb25zGAUgASgJEg4KBnJld2FyZBgHIAEoAhIM",
"CgRkb25lGAggASgIEhgKEG1heF9zdGVwX3JlYWNoZWQYCSABKAgSCgoCaWQY",
"CiABKAUSEwoLYWN0aW9uX21hc2sYCyADKAgSSAoSY3VzdG9tX29ic2VydmF0",
"aW9uGAwgASgLMiwuY29tbXVuaWNhdG9yX29iamVjdHMuQ3VzdG9tT2JzZXJ2",
"YXRpb25Qcm90bxJRChdjb21wcmVzc2VkX29ic2VydmF0aW9ucxgNIAMoCzIw",
"LmNvbW11bmljYXRvcl9vYmplY3RzLkNvbXByZXNzZWRPYnNlcnZhdGlvblBy",
"b3RvSgQIAhADSgQIBhAHQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
"Y3RzYgZwcm90bzM="));
"cHJvdG8inQIKDkFnZW50SW5mb1Byb3RvEiIKGnN0YWNrZWRfdmVjdG9yX29i",
"c2VydmF0aW9uGAEgAygCEh0KFXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMo",
"AhIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIYChBtYXhfc3RlcF9y",
"ZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlvbl9tYXNrGAsgAygI",
"ElEKF2NvbXByZXNzZWRfb2JzZXJ2YXRpb25zGA0gAygLMjAuY29tbXVuaWNh",
"dG9yX29iamVjdHMuQ29tcHJlc3NlZE9ic2VydmF0aW9uUHJvdG9KBAgCEANK",
"BAgDEARKBAgFEAZKBAgGEAdKBAgMEA1CH6oCHE1MQWdlbnRzLkNvbW11bmlj",
"YXRvck9iamVjdHNiBnByb3RvMw=="));
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CompressedObservationReflection.Descriptor, global::MLAgents.CommunicatorObjects.CustomObservationReflection.Descriptor, },
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CompressedObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "TextObservation", "StoredVectorActions", "StoredTextActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CustomObservation", "CompressedObservations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "StoredVectorActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CompressedObservations" }, null, null, null)
}));
}
#endregion

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoProto(AgentInfoProto other) : this() {
stackedVectorObservation_ = other.stackedVectorObservation_.Clone();
textObservation_ = other.textObservation_;
storedTextActions_ = other.storedTextActions_;
CustomObservation = other.customObservation_ != null ? other.CustomObservation.Clone() : null;
compressedObservations_ = other.compressedObservations_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return stackedVectorObservation_; }
}
/// <summary>Field number for the "text_observation" field.</summary>
public const int TextObservationFieldNumber = 3;
private string textObservation_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string TextObservation {
get { return textObservation_; }
set {
textObservation_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
/// <summary>Field number for the "stored_vector_actions" field.</summary>
public const int StoredVectorActionsFieldNumber = 4;
private static readonly pb::FieldCodec<float> _repeated_storedVectorActions_codec

public pbc::RepeatedField<float> StoredVectorActions {
get { return storedVectorActions_; }
}
/// <summary>Field number for the "stored_text_actions" field.</summary>
public const int StoredTextActionsFieldNumber = 5;
private string storedTextActions_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string StoredTextActions {
get { return storedTextActions_; }
set {
storedTextActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
/// <summary>Field number for the "reward" field.</summary>

get { return actionMask_; }
}
/// <summary>Field number for the "custom_observation" field.</summary>
public const int CustomObservationFieldNumber = 12;
private global::MLAgents.CommunicatorObjects.CustomObservationProto customObservation_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::MLAgents.CommunicatorObjects.CustomObservationProto CustomObservation {
get { return customObservation_; }
set {
customObservation_ = value;
}
}
/// <summary>Field number for the "compressed_observations" field.</summary>
public const int CompressedObservationsFieldNumber = 13;
private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.CompressedObservationProto> _repeated_compressedObservations_codec

return true;
}
if(!stackedVectorObservation_.Equals(other.stackedVectorObservation_)) return false;
if (TextObservation != other.TextObservation) return false;
if (StoredTextActions != other.StoredTextActions) return false;
if (!object.Equals(CustomObservation, other.CustomObservation)) return false;
if(!compressedObservations_.Equals(other.compressedObservations_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

int hash = 1;
hash ^= stackedVectorObservation_.GetHashCode();
if (TextObservation.Length != 0) hash ^= TextObservation.GetHashCode();
if (StoredTextActions.Length != 0) hash ^= StoredTextActions.GetHashCode();
if (customObservation_ != null) hash ^= CustomObservation.GetHashCode();
hash ^= compressedObservations_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
stackedVectorObservation_.WriteTo(output, _repeated_stackedVectorObservation_codec);
if (TextObservation.Length != 0) {
output.WriteRawTag(26);
output.WriteString(TextObservation);
}
if (StoredTextActions.Length != 0) {
output.WriteRawTag(42);
output.WriteString(StoredTextActions);
}
if (Reward != 0F) {
output.WriteRawTag(61);
output.WriteFloat(Reward);

output.WriteInt32(Id);
}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
if (customObservation_ != null) {
output.WriteRawTag(98);
output.WriteMessage(CustomObservation);
}
compressedObservations_.WriteTo(output, _repeated_compressedObservations_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

public int CalculateSize() {
int size = 0;
size += stackedVectorObservation_.CalculateSize(_repeated_stackedVectorObservation_codec);
if (TextObservation.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TextObservation);
}
if (StoredTextActions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(StoredTextActions);
}
if (Reward != 0F) {
size += 1 + 4;
}

size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
if (customObservation_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(CustomObservation);
}
size += compressedObservations_.CalculateSize(_repeated_compressedObservations_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

return;
}
stackedVectorObservation_.Add(other.stackedVectorObservation_);
if (other.TextObservation.Length != 0) {
TextObservation = other.TextObservation;
}
if (other.StoredTextActions.Length != 0) {
StoredTextActions = other.StoredTextActions;
}
if (other.Reward != 0F) {
Reward = other.Reward;
}

Id = other.Id;
}
actionMask_.Add(other.actionMask_);
if (other.customObservation_ != null) {
if (customObservation_ == null) {
customObservation_ = new global::MLAgents.CommunicatorObjects.CustomObservationProto();
}
CustomObservation.MergeFrom(other.CustomObservation);
}
compressedObservations_.Add(other.compressedObservations_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

stackedVectorObservation_.AddEntriesFrom(input, _repeated_stackedVectorObservation_codec);
break;
}
case 26: {
TextObservation = input.ReadString();
break;
}
break;
}
case 42: {
StoredTextActions = input.ReadString();
break;
}
case 61: {

case 90:
case 88: {
actionMask_.AddEntriesFrom(input, _repeated_actionMask_codec);
break;
}
case 98: {
if (customObservation_ == null) {
customObservation_ = new global::MLAgents.CommunicatorObjects.CustomObservationProto();
}
input.ReadMessage(customObservation_);
break;
}
case 106: {

UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs (5)


{
StackedVectorObservation = { ai.floatObservations },
StoredVectorActions = { ai.storedVectorActions },
StoredTextActions = ai.storedTextActions,
TextObservation = ai.textObservation,
CustomObservation = ai.customObservation
};
if (ai.actionMasks != null)

return new AgentAction
{
vectorActions = aap.VectorActions.ToArray(),
textActions = aap.TextActions,
customAction = aap.CustomAction
};
}

docs/Glossary.md (2)


* **Frame** - An instance of rendering the main camera for the display.
Corresponds to each `Update` call of the game engine.
* **Observation** - Partial information describing the state of the environment
available to a given agent. (e.g. Vector, Visual, Text)
available to a given agent. (e.g. Vector, Visual)
* **Policy** - Function for producing decisions from observations.
* **Reward** - Signal provided at every step used to indicate desirability of an
agent’s action within the current state of the environment.

docs/Learning-Environment-Create-New.md (2)


```csharp
public float speed = 10;
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
// Actions, size = 2
Vector3 controlSignal = Vector3.zero;

docs/Python-API.md (8)


the list corresponds to the n<sup>th</sup> observation of the Brain.
- **`vector_observations`** : A two dimensional numpy array of dimension `(batch
size, vector observation size)`.
- **`text_observations`** : A list of string corresponding to the Agents text
observations.
- **`memories`** : A two dimensional numpy array of dimension `(batch size,
memory size)` which corresponds to the memories sent at the previous step.
- **`rewards`** : A list as long as the number of Agents using the Brain

`resetParameters` and the values are their corresponding float values.
Define the reset parameters on the Academy Inspector window in the Unity
Editor.
- **Step : `env.step(action, memory=None, text_action=None)`**
- **Step : `env.step(action)`**
- `memory` is an optional input that can be used to send a list of floats per
Agents to be retrieved at the next step.
- `text_action` is an optional input that be used to send a single string per
Agent.
Returns a dictionary mapping Brain names to BrainInfo objects.

docs/Readme.md (2)


* [Using the Monitor](Feature-Monitor.md)
* [Using the Video Recorder](https://github.com/Unity-Technologies/video-recorder)
* [Using an Executable Environment](Learning-Environment-Executable.md)
* [Creating Custom Protobuf Messages](Creating-Custom-Protobuf-Messages.md)
## Training

* [API Reference](API-Reference.md)
* [How to use the Python API](Python-API.md)
* [Wrapping Learning Environment as a Gym (+Baselines/Dopamine Integration)](../gym-unity/README.md)
* [Creating custom protobuf messages](Creating-Custom-Protobuf-Messages.md)

gym-unity/gym_unity/envs/__init__.py (4)


default_observation,
info.rewards[0],
info.local_done[0],
{"text_observation": info.text_observations[0], "brain_info": info},
{"text_observation": None, "brain_info": info},
)
def _preprocess_single(self, single_visual_obs):

list(default_observation),
info.rewards,
info.local_done,
{"text_observation": info.text_observations, "brain_info": info},
{"text_observation": None, "brain_info": info},
)
def _preprocess_multi(self, multiple_visual_obs):

gym-unity/gym_unity/tests/test_gym.py (1)


mock_braininfo.return_value.visual_observations = [[np.zeros(shape=(8, 8, 3))]]
mock_braininfo.return_value.rewards = num_agents * [1.0]
mock_braininfo.return_value.local_done = num_agents * [False]
mock_braininfo.return_value.text_observations = num_agents * [""]
mock_braininfo.return_value.agents = range(0, num_agents)
return mock_braininfo()

ml-agents-envs/mlagents/envs/action_info.py (1)


class ActionInfo(NamedTuple):
action: Any
text: Any
value: Any
outputs: ActionInfoOutputs

ml-agents-envs/mlagents/envs/base_unity_environment.py (6)


class BaseUnityEnvironment(ABC):
@abstractmethod
def step(
self,
vector_action: Optional[Dict] = None,
text_action: Optional[Dict] = None,
value: Optional[Dict] = None,
custom_action: Dict[str, Any] = None,
self, vector_action: Optional[Dict] = None, value: Optional[Dict] = None
) -> AllBrainInfo:
pass

ml-agents-envs/mlagents/envs/brain.py (9)


self,
visual_observation,
vector_observation,
text_observations,
text_action=None,
custom_observations=None,
):
"""
Describes experience at current step of all agents linked to a brain.

self.text_observations = text_observations
self.previous_text_actions = text_action
self.custom_observations = custom_observations
@staticmethod
def merge_memories(m1, m2, agents1, agents2):

brain_info = BrainInfo(
visual_observation=vis_obs,
vector_observation=vector_obs,
text_observations=[x.text_observation for x in agent_info_list],
text_action=[list(x.stored_text_actions) for x in agent_info_list],
custom_observations=[x.custom_observation for x in agent_info_list],
action_mask=mask_actions,
)
return brain_info

ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py (27)


_sym_db = _symbol_database.Default()
from mlagents.envs.communicator_objects import custom_action_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2
DESCRIPTOR = _descriptor.FileDescriptor(

serialized_pb=_b('\n5mlagents/envs/communicator_objects/agent_action.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/custom_action.proto\"\x95\x01\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\r\n\x05value\x18\x04 \x01(\x02\x12>\n\rcustom_action\x18\x05 \x01(\x0b\x32\'.communicator_objects.CustomActionProtoJ\x04\x08\x03\x10\x04\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2.DESCRIPTOR,])
serialized_pb=_b('\n5mlagents/envs/communicator_objects/agent_action.proto\x12\x14\x63ommunicator_objects\"K\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x05\x10\x06\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
)

is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value', full_name='communicator_objects.AgentActionProto.value', index=2,
name='value', full_name='communicator_objects.AgentActionProto.value', index=1,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='custom_action', full_name='communicator_objects.AgentActionProto.custom_action', index=3,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),

extension_ranges=[],
oneofs=[
],
serialized_start=136,
serialized_end=285,
serialized_start=79,
serialized_end=154,
_AGENTACTIONPROTO.fields_by_name['custom_action'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2._CUSTOMACTIONPROTO
DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi (17)


Message as google___protobuf___message___Message,
)
from mlagents.envs.communicator_objects.custom_action_pb2 import (
CustomActionProto as mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto,
)
Text as typing___Text,
)
from typing_extensions import (

class AgentActionProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
vector_actions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
text_actions = ... # type: typing___Text
@property
def custom_action(self) -> mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto: ...
text_actions : typing___Optional[typing___Text] = None,
custom_action : typing___Optional[mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentActionProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"custom_action"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",u"text_actions",u"value",u"vector_actions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"value",u"vector_actions"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"custom_action",b"custom_action"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",b"custom_action",u"text_actions",b"text_actions",u"value",b"value",u"vector_actions",b"vector_actions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"value",b"value",u"vector_actions",b"vector_actions"]) -> None: ...

ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py (45)


from mlagents.envs.communicator_objects import compressed_observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2
from mlagents.envs.communicator_objects import custom_observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2
DESCRIPTOR = _descriptor.FileDescriptor(

serialized_pb=_b('\n3mlagents/envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/compressed_observation.proto\x1a;mlagents/envs/communicator_objects/custom_observation.proto\"\x8c\x03\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12H\n\x12\x63ustom_observation\x18\x0c \x01(\x0b\x32,.communicator_objects.CustomObservationProto\x12Q\n\x17\x63ompressed_observations\x18\r \x03(\x0b\x32\x30.communicator_objects.CompressedObservationProtoJ\x04\x08\x02\x10\x03J\x04\x08\x06\x10\x07\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents/envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/compressed_observation.proto\"\x9d\x02\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12Q\n\x17\x63ompressed_observations\x18\r \x03(\x0b\x32\x30.communicator_objects.CompressedObservationProtoJ\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2.DESCRIPTOR,])
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2.DESCRIPTOR,])

is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=1,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=2,
name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=1,
number=4, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,

name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=3,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=4,
name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=2,
number=7, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,

name='done', full_name='communicator_objects.AgentInfoProto.done', index=5,
name='done', full_name='communicator_objects.AgentInfoProto.done', index=3,
number=8, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,

name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=6,
name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=4,
number=9, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,

name='id', full_name='communicator_objects.AgentInfoProto.id', index=7,
name='id', full_name='communicator_objects.AgentInfoProto.id', index=5,
number=10, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,

name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=8,
name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=6,
number=11, type=8, cpp_type=7, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,

name='custom_observation', full_name='communicator_objects.AgentInfoProto.custom_observation', index=9,
number=12, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='compressed_observations', full_name='communicator_objects.AgentInfoProto.compressed_observations', index=10,
name='compressed_observations', full_name='communicator_objects.AgentInfoProto.compressed_observations', index=7,
number=13, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,

extension_ranges=[],
oneofs=[
],
serialized_start=204,
serialized_end=600,
serialized_start=143,
serialized_end=428,
_AGENTINFOPROTO.fields_by_name['custom_observation'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2._CUSTOMOBSERVATIONPROTO
_AGENTINFOPROTO.fields_by_name['compressed_observations'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2._COMPRESSEDOBSERVATIONPROTO
DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi (19)


CompressedObservationProto as mlagents___envs___communicator_objects___compressed_observation_pb2___CompressedObservationProto,
)
from mlagents.envs.communicator_objects.custom_observation_pb2 import (
CustomObservationProto as mlagents___envs___communicator_objects___custom_observation_pb2___CustomObservationProto,
)
Text as typing___Text,
)
from typing_extensions import (

class AgentInfoProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
stacked_vector_observation = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
text_observation = ... # type: typing___Text
stored_text_actions = ... # type: typing___Text
@property
def custom_observation(self) -> mlagents___envs___communicator_objects___custom_observation_pb2___CustomObservationProto: ...
@property
def compressed_observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents___envs___communicator_objects___compressed_observation_pb2___CompressedObservationProto]: ...

stacked_vector_observation : typing___Optional[typing___Iterable[builtin___float]] = None,
text_observation : typing___Optional[typing___Text] = None,
stored_text_actions : typing___Optional[typing___Text] = None,
custom_observation : typing___Optional[mlagents___envs___communicator_objects___custom_observation_pb2___CustomObservationProto] = None,
compressed_observations : typing___Optional[typing___Iterable[mlagents___envs___communicator_objects___compressed_observation_pb2___CompressedObservationProto]] = None,
) -> None: ...
@classmethod

if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"custom_observation"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"compressed_observations",u"custom_observation",u"done",u"id",u"max_step_reached",u"reward",u"stacked_vector_observation",u"stored_text_actions",u"stored_vector_actions",u"text_observation"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"compressed_observations",u"done",u"id",u"max_step_reached",u"reward",u"stacked_vector_observation",u"stored_vector_actions"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"custom_observation",b"custom_observation"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"compressed_observations",b"compressed_observations",u"custom_observation",b"custom_observation",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"reward",b"reward",u"stacked_vector_observation",b"stacked_vector_observation",u"stored_text_actions",b"stored_text_actions",u"stored_vector_actions",b"stored_vector_actions",u"text_observation",b"text_observation"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"compressed_observations",b"compressed_observations",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"reward",b"reward",u"stacked_vector_observation",b"stacked_vector_observation",u"stored_vector_actions",b"stored_vector_actions"]) -> None: ...

ml-agents-envs/mlagents/envs/environment.py (78)


)
from mlagents.envs.communicator_objects.unity_input_pb2 import UnityInputProto
from mlagents.envs.communicator_objects.custom_action_pb2 import CustomActionProto
from .rpc_communicator import RpcCommunicator
from sys import platform

class UnityEnvironment(BaseUnityEnvironment):
SCALAR_ACTION_TYPES = (int, np.int32, np.int64, float, np.float32, np.float64)
SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
SINGLE_BRAIN_TEXT_TYPES = list
API_VERSION = "API-11"
def __init__(

def step(
self,
vector_action: Dict[str, np.ndarray] = None,
text_action: Optional[Dict[str, List[str]]] = None,
custom_action: Dict[str, Any] = None,
) -> AllBrainInfo:
"""
Provides the environment with an action, moves the environment dynamics forward accordingly,

:param memory: Vector corresponding to memory used for recurrent policies.
:param text_action: Text action to send to environment for.
:param custom_action: Optional instance of a CustomAction protobuf message.
text_action = {} if text_action is None else text_action
custom_action = {} if custom_action is None else custom_action
# Check that environment is loaded, and episode is currently running.
if not self._loaded:

"step cannot take a vector_action input"
)
if isinstance(text_action, self.SINGLE_BRAIN_TEXT_TYPES):
if self._num_external_brains == 1:
text_action = {self._external_brain_names[0]: text_action}
elif self._num_external_brains > 1:
raise UnityActionException(
"You have {0} brains, you need to feed a dictionary of brain names as keys "
"and text_actions as values".format(self._num_external_brains)
)
else:
raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a value input"
)
if isinstance(value, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
value = {self._external_brain_names[0]: value}

"step cannot take a value input"
)
if isinstance(custom_action, CustomActionProto):
if self._num_external_brains == 1:
custom_action = {self._external_brain_names[0]: custom_action}
elif self._num_external_brains > 1:
raise UnityActionException(
"You have {0} brains, you need to feed a dictionary of brain names as keys "
"and CustomAction instances as values".format(
self._num_external_brains
)
)
else:
raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a custom_action input"
)
for brain_name in list(vector_action.keys()) + list(text_action.keys()):
for brain_name in list(vector_action.keys()):
if brain_name not in self._external_brain_names:
raise UnityActionException(
"The name {0} does not correspond to an external brain "

)
else:
vector_action[brain_name] = self._flatten(vector_action[brain_name])
if brain_name not in text_action:
text_action[brain_name] = [""] * n_agent
else:
if text_action[brain_name] is None:
text_action[brain_name] = [""] * n_agent
if brain_name not in custom_action:
custom_action[brain_name] = [None] * n_agent
else:
if custom_action[brain_name] is None:
custom_action[brain_name] = [None] * n_agent
if isinstance(custom_action[brain_name], CustomActionProto):
custom_action[brain_name] = [
custom_action[brain_name]
] * n_agent
number_text_actions = len(text_action[brain_name])
if not ((number_text_actions == n_agent) or number_text_actions == 0):
raise UnityActionException(
"There was a mismatch between the provided text_action and "
"the environment's expectation: "
"The brain {0} expected {1} text_action but was given {2}".format(
brain_name, n_agent, number_text_actions
)
)
discrete_check = (
self._brains[brain_name].vector_action_space_type == "discrete"

)
)
step_input = self._generate_step_input(
vector_action, text_action, value, custom_action
)
step_input = self._generate_step_input(vector_action, value)
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
if outputs is None: