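// The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables
// games and simulations to serve as environments for training intelligent agents.
//
// (Repository page text) You may select at most 25 topics. A topic must start with a Chinese
// character, a letter, or a digit, may contain hyphens (-), and must not exceed 35 characters.
//
// (Repository page text) 597 lines, 23 KiB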

# if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
using Grpc.Core;
#endif
#if UNITY_EDITOR
using UnityEditor;
#endif
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;
using Unity.MLAgents.SideChannels;
using Google.Protobuf;
namespace Unity.MLAgents
{
/// <summary>
/// Responsible for communication with External using gRPC.
/// </summary>
internal class RpcCommunicator : ICommunicator
{
// The python package version must be >= s_MinSupportedPythonPackageVersion
// and <= s_MaxSupportedPythonPackageVersion.
// Both bounds are fixed for the lifetime of the process, so they are readonly to
// rule out accidental reassignment.
static readonly Version s_MinSupportedPythonPackageVersion = new Version("0.16.1");
static readonly Version s_MaxSupportedPythonPackageVersion = new Version("0.20.0");

/// <summary>
/// Raised when a quit command is received, or when an exchange with the external
/// process fails or returns a non-200 status.
/// </summary>
public event QuitCommandHandler QuitCommandReceived;

/// <summary>
/// Raised when a reset command is received.
/// </summary>
public event ResetCommandHandler ResetCommandReceived;

/// <summary>
/// If true, the communication is active.
/// </summary>
bool m_IsOpen;

// Behavior names registered through SubscribeBrain.
List<string> m_BehaviorNames = new List<string>();

// Set by PutObservations and consumed by DecideBatch: true when there is something to send.
bool m_NeedCommunicateThisStep;

// Writer handed to every sensor when its observation proto is generated.
ObservationWriter m_ObservationWriter = new ObservationWriter();

// One validator per behavior name; only used in DEBUG builds.
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();

// Per behavior name, the episode ids of the not-done agents that put observations,
// in the order they were put. Received actions are matched to agents by this order.
Dictionary<string, List<int>> m_OrderedAgentsRequestingDecisions = new Dictionary<string, List<int>>();

/// <summary>
/// The current UnityRLOutput to be sent when all the brains queried the communicator
/// </summary>
UnityRLOutputProto m_CurrentUnityRlOutput =
    new UnityRLOutputProto();

// Per behavior name, episode id -> last action vector received (null until one arrives).
Dictionary<string, Dictionary<int, float[]>> m_LastActionsReceived =
    new Dictionary<string, Dictionary<int, float[]>>();

// Brains that we have sent over the communicator with agents.
HashSet<string> m_SentBrainKeys = new HashSet<string>();

// Brain parameters cached by CacheBrainParameters that have not been sent yet.
Dictionary<string, BrainParameters> m_UnsentBrainKeys = new Dictionary<string, BrainParameters>();

#if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
/// <summary>
/// The Unity to External client.
/// </summary>
UnityToExternalProto.UnityToExternalProtoClient m_Client;
#endif

/// <summary>
/// The communicator parameters sent at construction
/// </summary>
CommunicatorInitParameters m_CommunicatorInitParameters;
/// <summary>
/// Creates a communicator and stores the supplied parameters; the stored port is read
/// later, when the gRPC channel is opened.
/// </summary>
/// <param name="communicatorInitParameters">Communicator parameters to keep for the connection.</param>
public RpcCommunicator(CommunicatorInitParameters communicatorInitParameters)
{
    this.m_CommunicatorInitParameters = communicatorInitParameters;
}
#region Initialization
/// <summary>
/// Decides whether the Unity and python communication API versions can talk to each other.
/// For a Unity major version of 0, both major and minor must match. Otherwise only the major
/// version must match, and a minor mismatch is reported with a warning.
/// </summary>
/// <param name="unityCommunicationVersion">Communication API version of the Unity side.</param>
/// <param name="pythonApiVersion">Communication API version of the python side.</param>
/// <param name="pythonLibraryVersion">Python library version; only used in the warning text.</param>
/// <returns>True when the two API versions are considered compatible.</returns>
internal static bool CheckCommunicationVersionsAreCompatible(
    string unityCommunicationVersion,
    string pythonApiVersion,
    string pythonLibraryVersion)
{
    var unitySide = new Version(unityCommunicationVersion);
    var pythonSide = new Version(pythonApiVersion);

    var sameMajor = unitySide.Major == pythonSide.Major;
    var sameMinor = unitySide.Minor == pythonSide.Minor;

    if (unitySide.Major == 0)
    {
        // For a 0.x Unity version, any major or minor difference is treated as incompatible.
        return sameMajor && sameMinor;
    }

    if (!sameMajor)
    {
        return false;
    }

    if (!sameMinor)
    {
        // Even if we initialize, we still want to check to make sure that we inform users of minor version
        // changes. This will surface any features that may not work due to minor version incompatibilities.
        Debug.LogWarningFormat(
            "WARNING: The communication API versions between Unity and python differ at the minor version level. " +
            "Python API: {0}, Unity API: {1} Python Library Version: {2} .\n" +
            "This means that some features may not work unless you upgrade the package with the lower version." +
            "Please find the versions that work best together from our release page.\n" +
            "https://github.com/Unity-Technologies/ml-agents/releases",
            pythonApiVersion, unityCommunicationVersion, pythonLibraryVersion
        );
    }

    return true;
}
/// <summary>
/// Checks that the python package version parses as a <see cref="Version"/> and lies within
/// [s_MinSupportedPythonPackageVersion, s_MaxSupportedPythonPackageVersion].
/// </summary>
/// <param name="pythonLibraryVersion">Version string reported by the python package.</param>
/// <returns>
/// True when the version is parseable and inside the supported range; false otherwise
/// (including null or unparseable input).
/// </returns>
internal static bool CheckPythonPackageVersionIsCompatible(string pythonLibraryVersion)
{
    Version pythonVersion;
    // TryParse instead of try/catch: an unparseable string is an expected input here, not an
    // exceptional one. This also rejects things like "0.20.0-dev0" which we don't want to support.
    if (!Version.TryParse(pythonLibraryVersion, out pythonVersion))
    {
        return false;
    }

    return pythonVersion >= s_MinSupportedPythonPackageVersion &&
        pythonVersion <= s_MaxSupportedPythonPackageVersion;
}
/// <summary>
/// Sends the initialization parameters through the Communicator.
/// Is used by the academy to send initialization parameters to the communicator.
/// </summary>
/// <returns>The External Initialization Parameters received.</returns>
/// <param name="initParameters">The Unity Initialization Parameters to be sent.</param>
/// <exception cref="UnityAgentsException">
/// Thrown when the handshake only succeeds part-way, when training is not possible on this
/// platform, or when the connection to the external process cannot be established.
/// </exception>
public UnityRLInitParameters Initialize(CommunicatorInitParameters initParameters)
{
    var academyParameters = new UnityRLInitializationOutputProto
    {
        Name = initParameters.name,
        PackageVersion = initParameters.unityPackageVersion,
        CommunicationVersion = initParameters.unityCommunicationVersion,
        Capabilities = initParameters.CSharpCapabilities.ToProto()
    };

    UnityInputProto input;
    UnityInputProto initializationInput;
    try
    {
        initializationInput = Initialize(
            new UnityOutputProto
            {
                RlInitializationOutput = academyParameters
            },
            out input);

        var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion;
        var pythonPackageVersion = initializationInput.RlInitializationInput.PackageVersion;
        var unityCommunicationVersion = initParameters.unityCommunicationVersion;

        TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);

        var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(unityCommunicationVersion,
            pythonCommunicationVersion,
            pythonPackageVersion);

        // Initialization succeeded part-way. The most likely cause is a mismatch between the communicator
        // API strings, so log an explicit warning if that's the case.
        if (initializationInput != null && input == null)
        {
            if (!communicationIsCompatible)
            {
                Debug.LogWarningFormat(
                    "Communication protocol between python ({0}) and Unity ({1}) have different " +
                    "versions which make them incompatible. Python library version: {2}.",
                    pythonCommunicationVersion, initParameters.unityCommunicationVersion,
                    pythonPackageVersion
                );
            }
            else
            {
                Debug.LogWarningFormat(
                    "Unknown communication error between Python. Python communication protocol: {0}, " +
                    "Python library version: {1}.",
                    pythonCommunicationVersion,
                    pythonPackageVersion
                );
            }

            throw new UnityAgentsException("ICommunicator.Initialize() failed.");
        }

        var packageVersionSupported = CheckPythonPackageVersionIsCompatible(pythonPackageVersion);
        if (!packageVersionSupported)
        {
            Debug.LogWarningFormat(
                "Python package version ({0}) is out of the supported range or not from an official release. " +
                "It is strongly recommended that you use a Python package between {1} and {2}. " +
                "Training will proceed, but the output format may be different.",
                pythonPackageVersion,
                s_MinSupportedPythonPackageVersion,
                s_MaxSupportedPythonPackageVersion
            );
        }
    }
    catch (UnityAgentsException)
    {
        // Already a specific failure raised on purpose (the part-way handshake above, or the
        // unsupported-platform error). Let it through instead of relabelling it as a
        // connection problem.
        throw;
    }
    catch
    {
        var exceptionMessage = "The Communicator was unable to connect. Please make sure the External " +
            "process is ready to accept communication with Unity.";

        // Check for common error condition and add details to the exception message.
        var httpProxy = Environment.GetEnvironmentVariable("HTTP_PROXY");
        var httpsProxy = Environment.GetEnvironmentVariable("HTTPS_PROXY");
        if (httpProxy != null || httpsProxy != null)
        {
            exceptionMessage += " Try removing HTTP_PROXY and HTTPS_PROXY from the " +
                "environment variables and try again.";
        }
        throw new UnityAgentsException(exceptionMessage);
    }

    UpdateEnvironmentWithInput(input.RlInput);
    return initializationInput.RlInitializationInput.ToUnityRLInitParameters();
}
/// <summary>
/// Registers a brain so that its agents' information can be sent to External.
/// Registering the same key again has no effect.
/// </summary>
/// <param name="brainKey">Brain key.</param>
/// <param name="brainParameters">Brain parameters needed to send to the trainer.</param>
public void SubscribeBrain(string brainKey, BrainParameters brainParameters)
{
    var alreadySubscribed = m_BehaviorNames.Contains(brainKey);
    if (!alreadySubscribed)
    {
        m_BehaviorNames.Add(brainKey);

        // Start the brain off with an empty list of agent infos in the outgoing message.
        var emptyAgentInfos = new UnityRLOutputProto.Types.ListAgentInfoProto();
        m_CurrentUnityRlOutput.AgentInfos.Add(brainKey, emptyAgentInfos);

        CacheBrainParameters(brainKey, brainParameters);
    }
}
/// <summary>
/// Applies an incoming RL input: forwards its side channel payload to the
/// SideChannelsManager, then dispatches its command.
/// </summary>
/// <param name="rlInput">The RL input received from External.</param>
void UpdateEnvironmentWithInput(UnityRLInputProto rlInput)
{
    var sideChannelBytes = rlInput.SideChannel.ToArray();
    SideChannelsManager.ProcessSideChannelData(sideChannelBytes);

    var command = rlInput.Command;
    SendCommandEvent(command);
}
/// <summary>
/// Opens the gRPC channel to localhost on the configured port and performs the two initial
/// exchanges with External. On platforms outside the supported list this always throws.
/// </summary>
/// <param name="unityOutput">The first message to send.</param>
/// <param name="unityInput">Receives the UnityInput of the second exchange.</param>
/// <returns>The UnityInput of the first exchange.</returns>
UnityInputProto Initialize(UnityOutputProto unityOutput,
    out UnityInputProto unityInput)
{
#if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
    m_IsOpen = true;

    var address = "localhost:" + m_CommunicatorInitParameters.port;
    var grpcChannel = new Channel(address, ChannelCredentials.Insecure);
    m_Client = new UnityToExternalProto.UnityToExternalProtoClient(grpcChannel);

    // First exchange carries the Unity output; the second one carries no payload.
    var firstReply = m_Client.Exchange(WrapMessage(unityOutput, 200));
    var secondReply = m_Client.Exchange(WrapMessage(null, 200));
    unityInput = secondReply.UnityInput;
#if UNITY_EDITOR
    EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
#endif
    return firstReply.UnityInput;
#else
    throw new UnityAgentsException(
        "You cannot perform training on this platform.");
#endif
}
#endregion
#region Destruction
/// <summary>
/// Close the communicator gracefully on both sides of the communication.
/// Sends a status-400 message to External on a best-effort basis and always leaves the
/// communicator marked as closed afterwards. Calling it again once closed does nothing.
/// </summary>
/// <exception cref="UnityAgentsException">
/// Thrown on platforms where training is not supported.
/// </exception>
public void Dispose()
{
#if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
#if UNITY_EDITOR
    // The handler is subscribed when the connection is initialized. Remove it here so the
    // editor's static event no longer references this communicator once it is disposed.
    // Removing a handler that was never added is a no-op.
    EditorApplication.playModeStateChanged -= HandleOnPlayModeChanged;
#endif
    if (!m_IsOpen)
    {
        return;
    }

    try
    {
        m_Client.Exchange(WrapMessage(null, 400));
    }
    catch
    {
        // Best effort: the other side may already be gone, in which case there is
        // nobody left to notify.
    }
    finally
    {
        // Mark closed even when the final exchange failed, so later calls do not keep
        // retrying on a dead connection.
        m_IsOpen = false;
    }
#else
    throw new UnityAgentsException(
        "You cannot perform training on this platform.");
#endif
}
#endregion
#region Sending Events
/// <summary>
/// Raises the event matching a received command. Quit raises QuitCommandReceived; Reset
/// first empties every pending decision-request list and then raises ResetCommandReceived.
/// Any other command is ignored.
/// </summary>
/// <param name="command">The command received from External.</param>
void SendCommandEvent(CommandProto command)
{
    if (command == CommandProto.Quit)
    {
        QuitCommandReceived?.Invoke();
    }
    else if (command == CommandProto.Reset)
    {
        foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
        {
            pendingAgentIds.Clear();
        }

        ResetCommandReceived?.Invoke();
    }
}
#endregion
#region Sending and retreiving data
/// <summary>
/// Sends the batched message if observations were put since the last send;
/// otherwise does nothing.
/// </summary>
public void DecideBatch()
{
    if (m_NeedCommunicateThisStep)
    {
        // Reset the flag before sending, as the original ordering does.
        m_NeedCommunicateThisStep = false;
        SendBatchedMessageHelper();
    }
}
/// <summary>
/// Queues the observations of one Agent in the outgoing message and records whether the
/// agent is waiting for a decision.
/// </summary>
/// <param name="behaviorName">Batch Key.</param>
/// <param name="info">Agent info.</param>
/// <param name="sensors">Sensors that will produce the observations</param>
public void PutObservations(string behaviorName, AgentInfo info, List<ISensor> sensors)
{
#if DEBUG
    SensorShapeValidator shapeValidator;
    if (!m_SensorShapeValidators.TryGetValue(behaviorName, out shapeValidator))
    {
        shapeValidator = new SensorShapeValidator();
        m_SensorShapeValidators[behaviorName] = shapeValidator;
    }
    shapeValidator.ValidateSensors(sensors);
#endif

    using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))
    {
        var infoProto = info.ToAgentInfoProto();

        using (TimerStack.Instance.Scoped("GenerateSensorData"))
        {
            foreach (var currentSensor in sensors)
            {
                infoProto.Observations.Add(currentSensor.GetObservationProto(m_ObservationWriter));
            }
        }

        m_CurrentUnityRlOutput.AgentInfos[behaviorName].Value.Add(infoProto);
    }

    m_NeedCommunicateThisStep = true;

    // Make sure both per-behavior containers exist before touching them.
    List<int> pendingAgentIds;
    if (!m_OrderedAgentsRequestingDecisions.TryGetValue(behaviorName, out pendingAgentIds))
    {
        pendingAgentIds = new List<int>();
        m_OrderedAgentsRequestingDecisions[behaviorName] = pendingAgentIds;
    }

    Dictionary<int, float[]> actionsByEpisode;
    if (!m_LastActionsReceived.TryGetValue(behaviorName, out actionsByEpisode))
    {
        actionsByEpisode = new Dictionary<int, float[]>();
        m_LastActionsReceived[behaviorName] = actionsByEpisode;
    }

    if (info.done)
    {
        // A done agent does not request a decision and keeps no stored action.
        actionsByEpisode.Remove(info.episodeId);
    }
    else
    {
        pendingAgentIds.Add(info.episodeId);
        // No action yet for this episode until the next exchange provides one.
        actionsByEpisode[info.episodeId] = null;
    }
}
/// <summary>
/// Helper method that sends the current UnityRLOutput, receives the next UnityInput and
/// Applies the appropriate AgentAction to the agents.
/// The pending decision requests are always emptied before returning normally, including
/// when no RL input came back, so stale agent ids cannot be paired with a later batch.
/// </summary>
void SendBatchedMessageHelper()
{
    var message = new UnityOutputProto
    {
        RlOutput = m_CurrentUnityRlOutput,
    };
    var tempUnityRlInitializationOutput = GetTempUnityRlInitializationOutput();
    if (tempUnityRlInitializationOutput != null)
    {
        message.RlInitializationOutput = tempUnityRlInitializationOutput;
    }

    byte[] messageAggregated = SideChannelsManager.GetSideChannelMessage();
    message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated);

    var input = Exchange(message);
    UpdateSentBrainParameters(tempUnityRlInitializationOutput);

    // The agent infos were handed to Exchange; empty the lists for the next step.
    foreach (var k in m_CurrentUnityRlOutput.AgentInfos.Keys)
    {
        m_CurrentUnityRlOutput.AgentInfos[k].Value.Clear();
    }

    var rlInput = input?.RlInput;
    if (rlInput?.AgentActions == null)
    {
        // Nothing came back, so the requests recorded for this step can never be answered.
        // Drop them; otherwise they would be matched by position against the next batch.
        ClearAgentsRequestingDecisions();
        return;
    }

    UpdateEnvironmentWithInput(rlInput);

    foreach (var brainName in rlInput.AgentActions.Keys)
    {
        // A behavior with no recorded requests (missing entry or empty list) has nobody
        // to give actions to.
        List<int> orderedAgentIds;
        if (!m_OrderedAgentsRequestingDecisions.TryGetValue(brainName, out orderedAgentIds) ||
            !orderedAgentIds.Any())
        {
            continue;
        }

        if (!rlInput.AgentActions[brainName].Value.Any())
        {
            continue;
        }

        var agentActions = rlInput.AgentActions[brainName].ToAgentActionList();
        var numAgents = orderedAgentIds.Count;
        // Actions are matched to agents by position in the request order.
        for (var i = 0; i < numAgents; i++)
        {
            var agentAction = agentActions[i];
            var agentId = orderedAgentIds[i];
            if (m_LastActionsReceived[brainName].ContainsKey(agentId))
            {
                m_LastActionsReceived[brainName][agentId] = agentAction.vectorActions;
            }
        }
    }

    ClearAgentsRequestingDecisions();
}

/// <summary>
/// Empties the pending decision-request list of every behavior.
/// </summary>
void ClearAgentsRequestingDecisions()
{
    foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
    {
        pendingAgentIds.Clear();
    }
}
/// <summary>
/// Looks up the last action vector stored for an agent.
/// </summary>
/// <param name="behaviorName">Behavior name the agent belongs to.</param>
/// <param name="agentId">Id under which the agent's action was stored.</param>
/// <returns>
/// The stored action vector, or null when the behavior or the agent has no entry
/// (the stored value itself may also be null).
/// </returns>
public float[] GetActions(string behaviorName, int agentId)
{
    Dictionary<int, float[]> actionsForBehavior;
    if (!m_LastActionsReceived.TryGetValue(behaviorName, out actionsForBehavior))
    {
        return null;
    }

    float[] storedAction;
    return actionsForBehavior.TryGetValue(agentId, out storedAction) ? storedAction : null;
}
/// <summary>
/// Sends a UnityOutput to External and returns the UnityInput of the reply.
/// A non-200 reply or a failed exchange closes the communicator and raises
/// QuitCommandReceived. On platforms outside the supported list this always throws.
/// </summary>
/// <returns>
/// The next UnityInput; null when the communicator is not open or the exchange failed.
/// </returns>
/// <param name="unityOutput">The UnityOutput to be sent.</param>
UnityInputProto Exchange(UnityOutputProto unityOutput)
{
#if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
    if (m_IsOpen)
    {
        try
        {
            var reply = m_Client.Exchange(WrapMessage(unityOutput, 200));
            if (reply.Header.Status != 200)
            {
                m_IsOpen = false;
                // Not sure if the quit command is actually sent when a
                // non 200 message is received. Notify that we are indeed
                // quitting.
                QuitCommandReceived?.Invoke();
            }

            return reply.UnityInput;
        }
        catch
        {
            m_IsOpen = false;
            QuitCommandReceived?.Invoke();
            return null;
        }
    }

    return null;
#else
    throw new UnityAgentsException(
        "You cannot perform training on this platform.");
#endif
}
/// <summary>
/// Builds a UnityMessage whose header carries the given status and whose payload is the
/// given UnityOutput.
/// </summary>
/// <returns>The UnityMessage corresponding.</returns>
/// <param name="content">The UnityOutput to be wrapped.</param>
/// <param name="status">The status of the message.</param>
static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
    var header = new HeaderProto();
    header.Status = status;

    var wrapped = new UnityMessageProto();
    wrapped.Header = header;
    wrapped.UnityOutput = content;
    return wrapped;
}
/// <summary>
/// Remembers the brain parameters of a behavior so they can be sent later, unless
/// that behavior has already been sent.
/// </summary>
/// <param name="behaviorName">Name of the behavior.</param>
/// <param name="brainParameters">Brain parameters to cache.</param>
void CacheBrainParameters(string behaviorName, BrainParameters brainParameters)
{
    var alreadySent = m_SentBrainKeys.Contains(behaviorName);
    if (!alreadySent)
    {
        // TODO We should check that if m_unsentBrainKeys has brainKey, it equals brainParameters
        m_UnsentBrainKeys[behaviorName] = brainParameters;
    }
}
/// <summary>
/// Collects the brain parameters that have not been sent yet and whose behavior currently
/// has a non-empty list of AgentInfos ready to go out.
/// </summary>
/// <returns>
/// An initialization output holding those brain parameters, or null when there are none.
/// </returns>
UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
{
    UnityRLInitializationOutputProto pending = null;
    var queuedAgentInfos = m_CurrentUnityRlOutput.AgentInfos;

    foreach (var unsentBrain in m_UnsentBrainKeys)
    {
        var name = unsentBrain.Key;
        if (!queuedAgentInfos.ContainsKey(name))
        {
            continue;
        }

        // Only send the BrainParameters if there is a non empty list of
        // AgentInfos ready to be sent.
        // This is to ensure that The Python side will always have a first
        // observation when receiving the BrainParameters
        if (queuedAgentInfos[name].CalculateSize() <= 0)
        {
            continue;
        }

        if (pending == null)
        {
            pending = new UnityRLInitializationOutputProto();
        }

        pending.BrainParameters.Add(unsentBrain.Value.ToProto(name, true));
    }

    return pending;
}
/// <summary>
/// Moves every brain listed in the given initialization output from the unsent set to the
/// sent set. A null output means nothing was sent and nothing changes.
/// </summary>
/// <param name="output">The initialization output that was sent, or null.</param>
void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
{
    if (output != null)
    {
        foreach (var sentBrain in output.BrainParameters)
        {
            var name = sentBrain.BrainName;
            m_SentBrainKeys.Add(name);
            m_UnsentBrainKeys.Remove(name);
        }
    }
}
#endregion
#if UNITY_EDITOR
/// <summary>
/// Editor play-mode callback: closes the communicator when play mode is being exited.
/// </summary>
/// <param name="state">The play-mode transition reported by the editor.</param>
void HandleOnPlayModeChanged(PlayModeStateChange state)
{
    // Called for every play-mode state change; only the exit transition matters here.
    var leavingPlayMode = state == PlayModeStateChange.ExitingPlayMode;
    if (!leavingPlayMode)
    {
        return;
    }

    Dispose();
}
#endif
}
}