浏览代码
Refactor ICommunicator API (#2675)
Refactor ICommunicator API (#2675)
- Push (almost) all references to protobuf objects into the RpcCommunicator. - Simplify the passing around of Agents and Agent Infos. - Delete all references to the Batcher. - Simplify the Environment Step by removing all of the reset and message counting logic. - Finishes MLA-27 and MLA-28/develop-gpu-test
GitHub
5 年前
当前提交
2d92a49b
共有 28 个文件被更改,包括 755 次插入 和 865 次删除
-
2UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs
-
15UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
-
26UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
-
33UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
-
1UnitySDK/Assets/ML-Agents/Editor/Tests/TimerTest.cs
-
2UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
-
202UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
-
46UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
-
24UnitySDK/Assets/ML-Agents/Scripts/Brain.cs
-
41UnitySDK/Assets/ML-Agents/Scripts/BrainParameters.cs
-
1UnitySDK/Assets/ML-Agents/Scripts/Demonstration.cs
-
100UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
-
305UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
-
27UnitySDK/Assets/ML-Agents/Scripts/HeuristicBrain.cs
-
123UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
-
23UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs
-
49UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
-
14UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs
-
10UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
-
26UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs
-
5UnitySDK/Assets/ML-Agents/Scripts/PlayerBrain.cs
-
25UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs
-
21UnitySDK/Assets/ML-Agents/Scripts/Timer.cs
-
1UnitySDK/UnitySDK.sln.DotSettings
-
13UnitySDK/Assets/ML-Agents/Scripts/Batcher.cs.meta
-
181UnitySDK/Assets/ML-Agents/Scripts/SocketCommunicator.cs
-
13UnitySDK/Assets/ML-Agents/Scripts/SocketCommunicator.cs.meta
-
291UnitySDK/Assets/ML-Agents/Scripts/Batcher.cs
|
|||
fileFormatVersion: 2 |
|||
guid: 4243d5dc0ad5746cba578575182f8c17 |
|||
timeCreated: 1523045876 |
|||
licenseType: Free |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using Google.Protobuf; |
|||
using System.Net.Sockets; |
|||
using UnityEngine; |
|||
using MLAgents.CommunicatorObjects; |
|||
using System.Threading.Tasks; |
|||
#if UNITY_EDITOR
|
|||
using UnityEditor; |
|||
#endif
|
|||
|
|||
namespace MLAgents |
|||
{ |
|||
public class SocketCommunicator : ICommunicator |
|||
{ |
|||
private const float k_TimeOut = 10f; |
|||
private const int k_MessageLength = 12000; |
|||
byte[] m_MessageHolder = new byte[k_MessageLength]; |
|||
int m_ComPort; |
|||
Socket m_Sender; |
|||
byte[] m_LengthHolder = new byte[4]; |
|||
CommunicatorParameters m_CommunicatorParameters; |
|||
|
|||
|
|||
/// <summary>
/// Creates a communicator and remembers the supplied parameters. No connection
/// is opened here; the stored parameters are read later when connecting.
/// </summary>
/// <param name="communicatorParameters">Connection parameters to keep.</param>
public SocketCommunicator(CommunicatorParameters communicatorParameters)
{
    this.m_CommunicatorParameters = communicatorParameters;
}
|||
|
|||
/// <summary>
/// Opens a TCP connection to localhost on the configured port and performs the
/// initial handshake: one message is received, the first UnityOutput is sent with
/// status 200, and a second message is received.
/// </summary>
/// <param name="unityOutput">The first UnityOutput to send.</param>
/// <param name="unityInput">Receives the UnityInput carried by the second message.</param>
/// <returns>The UnityInput carried by the first message received.</returns>
public UnityInputProto Initialize(UnityOutputProto unityOutput,
    out UnityInputProto unityInput)
{
    m_Sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
    m_Sender.Connect("localhost", m_CommunicatorParameters.port);

    // The statement order below is the handshake itself: receive, send, receive.
    var firstMessage = UnityMessageProto.Parser.ParseFrom(Receive());

    var outgoing = WrapMessage(unityOutput, 200);
    Send(outgoing.ToByteArray());

    var secondMessage = UnityMessageProto.Parser.ParseFrom(Receive());
    unityInput = secondMessage.UnityInput;

    // In the editor, register for play-mode changes so the handler can close
    // the communicator when play mode ends.
#if UNITY_EDITOR
#if UNITY_2017_2_OR_NEWER
    EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
#else
    EditorApplication.playmodeStateChanged += HandleOnPlayModeChanged;
#endif
#endif
    return firstMessage.UnityInput;
}
|||
|
|||
/// <summary>
/// Receives one length-prefixed message from External: a 4-byte length header
/// followed by exactly that many payload bytes. The payload is reassembled from
/// as many socket reads as it takes.
/// </summary>
/// <returns>The payload bytes sent by External (without the length header).</returns>
/// <exception cref="System.IO.IOException">The length header is negative.</exception>
/// <exception cref="System.IO.EndOfStreamException">
/// The connection was closed before the whole message arrived.
/// </exception>
byte[] Receive()
{
    // A stream socket may deliver fewer bytes than requested, even for the header.
    ReceiveExactly(m_LengthHolder, m_LengthHolder.Length);
    var totalLength = System.BitConverter.ToInt32(m_LengthHolder, 0);
    if (totalLength < 0)
    {
        throw new System.IO.IOException(
            "Received an invalid message length: " + totalLength);
    }

    var result = new byte[totalLength];
    ReceiveExactly(result, totalLength);
    return result;
}

/// <summary>
/// Blocks until exactly <paramref name="count"/> bytes have been read from the
/// socket into the start of <paramref name="buffer"/>.
/// </summary>
/// <param name="buffer">Destination buffer; must hold at least count bytes.</param>
/// <param name="count">Number of bytes to read.</param>
void ReceiveExactly(byte[] buffer, int count)
{
    var location = 0;
    while (location < count)
    {
        // Ask only for the bytes still missing so that data belonging to the
        // next message is never consumed and the buffer cannot be overrun.
        var fragment = m_Sender.Receive(
            buffer, location, count - location, SocketFlags.None);
        if (fragment == 0)
        {
            // A zero-byte read means the remote side closed the connection;
            // without this check the loop would never terminate.
            throw new System.IO.EndOfStreamException(
                "The connection was closed before the full message was received.");
        }
        location += fragment;
    }
}
|||
|
|||
/// <summary>
/// Sends the given bytes to External through the socket, prefixed with a 4-byte
/// header holding the payload length. Header and payload go out in a single
/// framed buffer.
/// </summary>
/// <param name="input">The payload bytes to send.</param>
void Send(byte[] input)
{
    var payloadLength = input.Length;
    var framed = new byte[payloadLength + 4];

    // Header first (bytes 0..3), then the payload right after it.
    var header = System.BitConverter.GetBytes(payloadLength);
    System.Buffer.BlockCopy(header, 0, framed, 0, header.Length);
    System.Buffer.BlockCopy(input, 0, framed, 4, payloadLength);

    m_Sender.Send(framed);
}
|||
|
|||
/// <summary>
/// Closes the communicator: notifies External with a status-400 message that
/// carries no UnityOutput, then releases the socket. Calling this when no
/// connection is open (never initialized, or already closed) does nothing.
/// </summary>
public void Close()
{
    if (m_Sender == null)
    {
        return;
    }

    // Stop listening for play-mode changes; the handler only exists to call Close.
#if UNITY_EDITOR
#if UNITY_2017_2_OR_NEWER
    EditorApplication.playModeStateChanged -= HandleOnPlayModeChanged;
#else
    EditorApplication.playmodeStateChanged -= HandleOnPlayModeChanged;
#endif
#endif

    try
    {
        Send(WrapMessage(null, 400).ToByteArray());
    }
    finally
    {
        // Release the socket even when the farewell message could not be sent.
        m_Sender.Close();
        m_Sender = null;
    }
}
|||
|
|||
/// <summary>
/// Sends a UnityOutput with status 200 and waits for the reply from External.
/// </summary>
/// <param name="unityOutput">The UnityOutput to send.</param>
/// <returns>
/// The UnityInput of the reply, or null when the reply's header status is not 200.
/// </returns>
/// <exception cref="UnityAgentsException">
/// No reply arrived within the time-out.
/// </exception>
public UnityInputProto Exchange(UnityOutputProto unityOutput)
{
    var outgoing = WrapMessage(unityOutput, 200).ToByteArray();
    Send(outgoing);

    // The blocking receive runs on a worker task so the wait can be bounded.
    var receiveTask = Task.Run(() => Receive());
    var timeout = System.TimeSpan.FromSeconds(k_TimeOut);
    var finishedInTime = receiveTask.Wait(timeout);
    if (!finishedInTime)
    {
        throw new UnityAgentsException(
            "The communicator took too long to respond.");
    }

    var reply = UnityMessageProto.Parser.ParseFrom(receiveTask.Result);
    return reply.Header.Status == 200 ? reply.UnityInput : null;
}
|||
|
|||
/// <summary>
/// Builds a UnityMessage whose header carries the given status and whose
/// UnityOutput is the given content.
/// </summary>
/// <param name="content">The UnityOutput to wrap; stored as given.</param>
/// <param name="status">The status to put in the message header.</param>
/// <returns>The newly created UnityMessage.</returns>
private static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
    var header = new HeaderProto();
    header.Status = status;

    var message = new UnityMessageProto();
    message.Header = header;
    message.UnityOutput = content;
    return message;
}
|||
|
|||
/// <summary>
/// Closes the communicator; intended to run when the Unity application quits.
/// </summary>
// NOTE(review): nothing visible in this file calls this method, and this class
// only implements ICommunicator - confirm that something actually invokes it.
private void OnApplicationQuit()
{
    this.Close();
}
|||
|
|||
#if UNITY_EDITOR
#if UNITY_2017_2_OR_NEWER
/// <summary>
/// Editor callback for play-mode state changes: closes the communicator when
/// the editor is exiting play mode and ignores every other state.
/// </summary>
/// <param name="state">The play-mode state change reported by the editor.</param>
private void HandleOnPlayModeChanged(PlayModeStateChange state)
{
    if (state != PlayModeStateChange.ExitingPlayMode)
    {
        return;
    }

    Close();
}

#else
/// <summary>
/// Editor callback for play-mode state changes: closes the communicator once
/// the editor is neither playing nor about to enter play mode.
/// </summary>
private void HandleOnPlayModeChanged()
{
    if (EditorApplication.isPlayingOrWillChangePlaymode)
    {
        return;
    }

    Close();
}

#endif
#endif
|
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: f0901c57c84a54f25aa5955165072493 |
|||
timeCreated: 1523046536 |
|||
licenseType: Free |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// The batcher is an RL specific class that makes sure that the information each object in
|
|||
/// Unity (Academy and Brains) wants to send to External is appropriately batched together
|
|||
/// and sent only when necessary.
|
|||
///
|
|||
/// The Batcher will only send a Message to the Communicator when either :
|
|||
/// 1 - The academy is done
|
|||
/// 2 - At least one brain has data to send
|
|||
///
|
|||
/// At each step, the batcher will keep track of the brains that queried the batcher for that
|
|||
/// step. The batcher can only send the batched data when all the Brains have queried the
|
|||
/// Batcher.
|
|||
/// </summary>
|
|||
public class Batcher |
|||
{ |
|||
/// The default number of agents in the scene
|
|||
private const int k_NumAgents = 32; |
|||
|
|||
/// Keeps track of which brains have data to send on the current step
|
|||
Dictionary<string, bool> m_HasData = |
|||
new Dictionary<string, bool>(); |
|||
|
|||
/// Keeps track of which brains queried the batcher on the current step
|
|||
Dictionary<string, bool> m_HasQueried = |
|||
new Dictionary<string, bool>(); |
|||
|
|||
/// Keeps track of the agents of each brain on the current step
|
|||
Dictionary<string, List<Agent>> m_CurrentAgents = |
|||
new Dictionary<string, List<Agent>>(); |
|||
|
|||
/// The Communicator of the batcher, sends a message at most once per step
|
|||
ICommunicator m_Communicator; |
|||
|
|||
/// The current UnityRLOutput to be sent when all the brains queried the batcher
|
|||
CommunicatorObjects.UnityRLOutputProto m_CurrentUnityRlOutput = |
|||
new CommunicatorObjects.UnityRLOutputProto(); |
|||
|
|||
/// Keeps track of last CommandProto sent by External
|
|||
CommunicatorObjects.CommandProto m_Command; |
|||
|
|||
/// Keeps track of last EnvironmentParametersProto sent by External
|
|||
CommunicatorObjects.EnvironmentParametersProto m_EnvironmentParameters; |
|||
|
|||
/// Keeps track of last training mode sent by External
|
|||
bool m_IsTraining; |
|||
|
|||
/// Keeps track of the number of messages received
|
|||
private ulong m_MessagesReceived; |
|||
|
|||
/// <summary>
/// Creates a Batcher that exchanges its messages through the given communicator.
/// </summary>
/// <param name="communicator">
/// The communicator to use. When it is null, SendBrainInfo returns without
/// transmitting anything.
/// </param>
public Batcher(ICommunicator communicator)
{
    this.m_Communicator = communicator;
}
|||
|
|||
/// <summary>
/// Initializes the communicator with the academy parameters and stores the
/// command, environment parameters and training flag of the first RL input
/// returned by External.
/// </summary>
/// <param name="academyParameters">The Unity initialization parameters to send.</param>
/// <returns>The initialization parameters received from External.</returns>
/// <exception cref="UnityAgentsException">
/// The communicator failed to initialize, whatever the underlying error was.
/// </exception>
public CommunicatorObjects.UnityRLInitializationInputProto SendAcademyParameters(
    CommunicatorObjects.UnityRLInitializationOutputProto academyParameters)
{
    CommunicatorObjects.UnityInputProto input;
    CommunicatorObjects.UnityInputProto initializationInput;
    try
    {
        initializationInput = m_Communicator.Initialize(
            new CommunicatorObjects.UnityOutputProto
            {
                RlInitializationOutput = academyParameters
            },
            out input);
    }
    catch
    {
        // Any failure during initialization is reported as a connection problem.
        var exceptionMessage = "The Communicator was unable to connect. Please make sure the External " +
            "process is ready to accept communication with Unity.";

        // Check for a common error condition and add details to the exception message.
        var httpProxy = Environment.GetEnvironmentVariable("HTTP_PROXY");
        var httpsProxy = Environment.GetEnvironmentVariable("HTTPS_PROXY");
        if (httpProxy != null || httpsProxy != null)
        {
            // The trailing space after "the" keeps the two fragments from
            // running together as "theenvironment".
            exceptionMessage += " Try removing HTTP_PROXY and HTTPS_PROXY from the " +
                "environment variables and try again.";
        }
        throw new UnityAgentsException(exceptionMessage);
    }

    var firstRlInput = input.RlInput;
    m_Command = firstRlInput.Command;
    m_EnvironmentParameters = firstRlInput.EnvironmentParameters;
    m_IsTraining = firstRlInput.IsTraining;
    return initializationInput.RlInitializationInput;
}
|||
|
|||
/// <summary>
/// Returns the most recently stored command, as set during initialization or
/// by the latest exchange with External.
/// </summary>
/// <returns>The current command.</returns>
public CommunicatorObjects.CommandProto GetCommand()
{
    return this.m_Command;
}
|||
|
|||
/// <summary>
/// Returns the running count of exchanges performed with the communicator.
/// Comparing two readings tells whether a new message has been received.
/// </summary>
/// <returns>The number of messages received so far.</returns>
public ulong GetNumberMessageReceived()
{
    return this.m_MessagesReceived;
}
|||
|
|||
/// <summary>
/// Returns the most recently stored environment parameters, as set during
/// initialization or by the latest exchange with External.
/// </summary>
/// <returns>The current environment parameters.</returns>
public CommunicatorObjects.EnvironmentParametersProto GetEnvironmentParameters()
{
    return this.m_EnvironmentParameters;
}
|||
|
|||
/// <summary>
/// Returns the most recently stored training-mode flag sent by External.
/// </summary>
/// <returns><c>true</c> if training mode is requested, <c>false</c> otherwise.</returns>
public bool GetIsTraining()
{
    return this.m_IsTraining;
}
|||
|
|||
/// <summary>
/// Registers a brain with the batcher: clears its "has queried" and "has data"
/// flags, gives it an empty agent list, and adds an empty agent-info list for
/// it to the pending RL output.
/// </summary>
/// <param name="brainKey">Key identifying the brain.</param>
public void SubscribeBrain(string brainKey)
{
    m_HasQueried[brainKey] = false;
    m_HasData[brainKey] = false;

    var agents = new List<Agent>(k_NumAgents);
    m_CurrentAgents[brainKey] = agents;

    var emptyInfos = new CommunicatorObjects.UnityRLOutputProto.Types.ListAgentInfoProto();
    m_CurrentUnityRlOutput.AgentInfos.Add(brainKey, emptyInfos);
}
|||
|
|||
/// <summary>
/// Records the agent infos of one brain for the current step. Once every
/// subscribed brain has called this method, the batched message is sent if at
/// least one brain contributed data, and the per-step flags are reset.
/// </summary>
/// <param name="brainKey">Key of the brain reporting its agents.</param>
/// <param name="agentInfo">The agents of that brain and their infos; may be empty.</param>
public void SendBrainInfo(
    string brainKey, Dictionary<Agent, AgentInfo> agentInfo)
{
    // Without a communicator nothing is ever transmitted.
    if (m_Communicator == null)
    {
        return;
    }

    // Remember that this brain has reported for the current step.
    m_HasQueried[brainKey] = true;

    // Replace the brain's agent list with the agents reported now.
    var agents = m_CurrentAgents[brainKey];
    agents.Clear();
    agents.AddRange(agentInfo.Keys);

    if (agents.Count > 0)
    {
        foreach (var agent in agents)
        {
            var proto = agentInfo[agent].ToProto();
            m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(proto);
            // Avoid visual obs memory leak. This must happen AFTER the visual obs
            // have been converted above.
            agentInfo[agent].ClearVisualObs();
        }

        m_HasData[brainKey] = true;
    }

    // Wait until every subscribed brain has reported.
    if (m_HasQueried.ContainsValue(false))
    {
        return;
    }

    // Only talk to External when at least one brain actually has data.
    if (m_HasData.ContainsValue(true))
    {
        SendBatchedMessageHelper();
    }

    // The step is complete: clear both flags for every brain.
    foreach (var key in m_CurrentAgents.Keys)
    {
        m_HasData[key] = false;
        m_HasQueried[key] = false;
    }
}
|||
|
|||
/// <summary>
/// Sends the pending UnityRLOutput through the communicator, stores the command,
/// environment parameters and training flag of the reply, and applies the
/// returned actions to the agents that were reported for the current step.
/// A missing reply, or a reply without RL input, sets the command to Quit.
/// </summary>
/// <exception cref="UnityAgentsException">
/// External returned fewer actions for a brain than that brain has agents.
/// </exception>
void SendBatchedMessageHelper()
{
    var input = m_Communicator.Exchange(
        new CommunicatorObjects.UnityOutputProto
        {
            RlOutput = m_CurrentUnityRlOutput
        });
    m_MessagesReceived += 1;

    // The agent infos have been sent; empty the lists for the next step.
    foreach (var k in m_CurrentUnityRlOutput.AgentInfos.Keys)
    {
        m_CurrentUnityRlOutput.AgentInfos[k].Value.Clear();
    }

    var rlInput = input == null ? null : input.RlInput;
    if (rlInput == null)
    {
        m_Command = CommunicatorObjects.CommandProto.Quit;
        return;
    }

    m_Command = rlInput.Command;
    m_EnvironmentParameters = rlInput.EnvironmentParameters;
    m_IsTraining = rlInput.IsTraining;

    if (rlInput.AgentActions == null)
    {
        return;
    }

    foreach (var brainName in rlInput.AgentActions.Keys)
    {
        // Skip brains that never subscribed or that reported no agents this
        // step; a plain indexer lookup would throw for an unknown brain name.
        List<Agent> agents;
        if (!m_CurrentAgents.TryGetValue(brainName, out agents) || agents.Count == 0)
        {
            continue;
        }

        var actions = rlInput.AgentActions[brainName].Value;
        if (actions.Count == 0)
        {
            continue;
        }

        // Check the counts up front so a short reply fails with a clear message
        // before any agent is modified, instead of an index error mid-loop.
        if (actions.Count < agents.Count)
        {
            throw new UnityAgentsException(string.Format(
                "Received {0} actions for brain {1} but {2} agents are waiting for one.",
                actions.Count, brainName, agents.Count));
        }

        for (var i = 0; i < agents.Count; i++)
        {
            var agent = agents[i];
            var action = actions[i];
            agent.UpdateVectorAction(action.VectorActions.ToArray());
            agent.UpdateMemoriesAction(action.Memories.ToList());
            agent.UpdateTextAction(action.TextActions);
            agent.UpdateValueAction(action.Value);
            agent.UpdateCustomAction(action.CustomAction);
        }
    }
}
|||
} |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue