Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

354 行
12 KiB

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Newtonsoft.Json;
using System.Linq;
using System.Net.Sockets;
using System.Text;
/// Responsible for communication with the Python API over a TCP socket.
/// Brains subscribe to the communicator; their per-agent data is serialized
/// to JSON and sent through the socket, and the actions, memories and values
/// received in return are stored per brain until they are queried.
public class ExternalCommunicator : Communicator
{
    /// Academy whose name, reset parameters, done flag and inference flag
    /// are exchanged with the external process.
    private readonly Academy academy;

    /// Agent ids of each brain (keyed by brain GameObject name) as of the
    /// last call to giveBrainInfo for that brain.
    private readonly Dictionary<string, List<int>> current_agents;

    /// Every brain registered through SubscribeBrain.
    private readonly List<Brain> brains;

    /// For each subscribed brain name, whether its state has been sent
    /// during the current step.
    private readonly Dictionary<string, bool> hasSentState;

    /// Last received actions / memories / values, keyed by brain name and
    /// then by agent id.
    private readonly Dictionary<string, Dictionary<int, float[]>> storedActions;
    private readonly Dictionary<string, Dictionary<int, float[]>> storedMemories;
    private readonly Dictionary<string, Dictionary<int, float>> storedValues;

    /// Port parsed from the "--port" command line argument.
    private int comPort;

    /// Socket connected to the external process; created in
    /// InitializeCommunicator.
    private Socket sender;

    /// Reusable receive buffer.
    private byte[] messageHolder;

    /// Size in bytes of the receive buffer. A single Receive call reads at
    /// most this many bytes.
    /// NOTE(review): incoming messages are read with one Receive call and no
    /// framing, so a reply larger than this buffer (or split by the network)
    /// would be truncated - confirm the sender never exceeds it.
    const int messageLength = 12000;

    /// API version string reported to the external process.
    const string api = "API-2";

    /// JSON payload describing the state of every agent of one brain.
    /// Property names are the JSON keys and must not be renamed.
    private class StepMessage
    {
        public string brain_name { get; set; }
        public List<int> agents { get; set; }
        public List<float> states { get; set; }
        public List<float> rewards { get; set; }
        public List<float> actions { get; set; }
        public List<float> memories { get; set; }
        public List<bool> dones { get; set; }
    }

    /// JSON payload received in UpdateActions: flat per-brain lists of
    /// actions, memories and values, keyed by brain name.
    private class AgentMessage
    {
        public Dictionary<string, List<float>> action { get; set; }
        public Dictionary<string, List<float>> memory { get; set; }
        public Dictionary<string, List<float>> value { get; set; }
    }

    /// JSON payload received in GetResetParameters.
    private class ResetParametersMessage
    {
        public Dictionary<string, float> parameters { get; set; }
        public bool train_model { get; set; }
    }

    /// Constructor for the External Communicator. Only stores the academy
    /// and creates the empty bookkeeping collections; no socket is opened.
    public ExternalCommunicator(Academy aca)
    {
        academy = aca;
        brains = new List<Brain>();
        current_agents = new Dictionary<string, List<int>>();
        hasSentState = new Dictionary<string, bool>();
        storedActions = new Dictionary<string, Dictionary<int, float[]>>();
        storedMemories = new Dictionary<string, Dictionary<int, float[]>>();
        storedValues = new Dictionary<string, Dictionary<int, float>>();
    }

    /// Registers the brain with the communicator and marks its state as not
    /// yet sent for the current step.
    public void SubscribeBrain(Brain brain)
    {
        brains.Add(brain);
        hasSentState[brain.gameObject.name] = false;
    }

    /// Tries to read the communication port from the command line.
    /// Returns true when a valid port was found, false otherwise.
    public bool CommunicatorHandShake()
    {
        try
        {
            ReadArgs();
        }
        catch (System.Exception)
        {
            // Deliberate best effort: any failure to obtain the port simply
            // means that external communication is not available.
            return false;
        }
        return true;
    }

    /// Contains the logic for the initialization of the socket: allocates
    /// the receive buffer, connects to localhost on the configured port and
    /// sends the academy parameters.
    public void InitializeCommunicator()
    {
        messageHolder = new byte[messageLength];

        // Create a TCP/IP socket.
        sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sender.Connect("localhost", comPort);

        AcademyParameters academyParameters = new AcademyParameters();
        academyParameters.brainParameters = new List<BrainParameters>();
        academyParameters.brainNames = new List<string>();
        academyParameters.externalBrainNames = new List<string>();
        academyParameters.apiNumber = api;
        foreach (Brain b in brains)
        {
            string name = b.gameObject.name;
            academyParameters.brainParameters.Add(b.brainParameters);
            academyParameters.brainNames.Add(name);
            if (b.brainType == BrainType.External)
            {
                academyParameters.externalBrainNames.Add(name);
            }
        }
        academyParameters.AcademyName = academy.gameObject.name;
        academyParameters.resetParameters = academy.resetParameters;

        SendParameters(academyParameters);
    }

    /// Listens to the socket for a command and returns the corresponding
    /// External Command. Anything other than "STEP" or "RESET" (including an
    /// unknown or empty message) is treated as QUIT.
    public ExternalCommand GetCommand()
    {
        switch (Receive())
        {
            case "STEP":
                return ExternalCommand.STEP;
            case "RESET":
                return ExternalCommand.RESET;
            case "QUIT":
            default:
                return ExternalCommand.QUIT;
        }
    }

    /// Requests the new resetParameters from the socket, updates the
    /// academy's inference flag from the reply and returns the parameters.
    public Dictionary<string, float> GetResetParameters()
    {
        sender.Send(Encoding.ASCII.GetBytes("CONFIG_REQUEST"));
        ResetParametersMessage resetParams =
            JsonConvert.DeserializeObject<ResetParametersMessage>(Receive());
        academy.isInference = !resetParams.train_model;
        return resetParams.parameters;
    }

    /// Used to read Python-provided environment parameters. Looks for
    /// "--port VALUE" on the command line (the last occurrence wins) and
    /// stores the parsed port.
    /// Throws System.ArgumentException when no valid port is supplied.
    private void ReadArgs()
    {
        string[] args = System.Environment.GetCommandLineArgs();
        string inputPort = "";

        // Stop one short of the end so that a trailing "--port" without a
        // value cannot index past the array.
        for (int i = 0; i < args.Length - 1; i++)
        {
            if (args[i] == "--port")
            {
                inputPort = args[i + 1];
            }
        }

        int port;
        if (!int.TryParse(inputPort, out port))
        {
            throw new System.ArgumentException(
                "No valid --port argument was supplied on the command line.");
        }
        comPort = port;
    }

    /// Sends Academy parameters to external agent as indented JSON.
    private void SendParameters(AcademyParameters envParams)
    {
        string envMessage = JsonConvert.SerializeObject(envParams, Formatting.Indented);
        sender.Send(Encoding.ASCII.GetBytes(envMessage));
    }

    /// Receives one message from external agent and decodes it as ASCII.
    private string Receive()
    {
        int received = sender.Receive(messageHolder);
        return Encoding.ASCII.GetString(messageHolder, 0, received);
    }

    /// Ends connection and closes the socket.
    private void OnApplicationQuit()
    {
        if (sender == null)
        {
            // The communicator was never initialized; nothing to release.
            return;
        }

        try
        {
            // Shutdown has to come before Close: Close disposes the socket,
            // after which Shutdown would throw ObjectDisposedException.
            if (sender.Connected)
            {
                sender.Shutdown(SocketShutdown.Both);
            }
        }
        finally
        {
            sender.Close();
        }
    }

    /// Contains logic for converting texture into bytearray (PNG encoded) to
    /// send to external agent. The texture is destroyed afterwards.
    private byte[] TexToByteArray(Texture2D tex)
    {
        byte[] bytes = tex.EncodeToPNG();
        Object.DestroyImmediate(tex);
        Resources.UnloadUnusedAssets();
        return bytes;
    }

    /// Returns a copy of the input prefixed with its length as a 4-byte
    /// integer, in the byte order produced by System.BitConverter.
    private byte[] AppendLength(byte[] input)
    {
        byte[] framed = new byte[input.Length + 4];
        System.BitConverter.GetBytes(input.Length).CopyTo(framed, 0);
        input.CopyTo(framed, 4);
        return framed;
    }

    /// Collects the information from the brain and sends it across the
    /// socket: first the JSON step message, then one PNG per camera
    /// resolution and agent. Each send waits for a reply before continuing.
    /// Once every subscribed brain has sent its state, the academy done flag
    /// is sent and the per-brain "sent" markers are reset.
    public void giveBrainInfo(Brain brain)
    {
        string brainName = brain.gameObject.name;
        List<int> agentIds = new List<int>(brain.agents.Keys);
        current_agents[brainName] = agentIds;

        List<float> concatenatedStates = new List<float>();
        List<float> concatenatedRewards = new List<float>();
        List<float> concatenatedMemories = new List<float>();
        List<bool> concatenatedDones = new List<bool>();
        List<float> concatenatedActions = new List<float>();

        Dictionary<int, List<Camera>> collectedObservations = brain.CollectObservations();
        Dictionary<int, List<float>> collectedStates = brain.CollectStates();
        Dictionary<int, float> collectedRewards = brain.CollectRewards();
        Dictionary<int, float[]> collectedMemories = brain.CollectMemories();
        Dictionary<int, bool> collectedDones = brain.CollectDones();
        Dictionary<int, float[]> collectedActions = brain.CollectActions();

        // Append in place instead of rebuilding every list for each agent.
        foreach (int id in agentIds)
        {
            concatenatedStates.AddRange(collectedStates[id]);
            concatenatedRewards.Add(collectedRewards[id]);
            concatenatedMemories.AddRange(collectedMemories[id]);
            concatenatedDones.Add(collectedDones[id]);
            concatenatedActions.AddRange(collectedActions[id]);
        }

        StepMessage message = new StepMessage()
        {
            brain_name = brainName,
            agents = agentIds,
            states = concatenatedStates,
            rewards = concatenatedRewards,
            actions = concatenatedActions,
            memories = concatenatedMemories,
            dones = concatenatedDones
        };
        string envMessage = JsonConvert.SerializeObject(message, Formatting.Indented);
        sender.Send(AppendLength(Encoding.ASCII.GetBytes(envMessage)));
        // The reply content is ignored; it only paces the exchange.
        Receive();

        // Observations are sent grouped by camera index, then by agent.
        int cameraIndex = 0;
        foreach (resolution res in brain.brainParameters.cameraResolutions)
        {
            foreach (int id in agentIds)
            {
                Texture2D tex = brain.ObservationToTex(
                    collectedObservations[id][cameraIndex], res.width, res.height);
                sender.Send(AppendLength(TexToByteArray(tex)));
                Receive();
            }
            cameraIndex++;
        }

        hasSentState[brainName] = true;

        if (hasSentState.Values.All(x => x))
        {
            // if all the brains listed have sent their state
            sender.Send(Encoding.ASCII.GetBytes((academy.done ? "True" : "False")));

            // Copy the keys first: the dictionary cannot be modified while
            // its key collection is being enumerated.
            List<string> brainNames = hasSentState.Keys.ToList();
            foreach (string k in brainNames)
            {
                hasSentState[k] = false;
            }
        }
    }

    /// Listens for actions, memories, and values and stores them for the
    /// corresponding external brains, keyed by agent id.
    public void UpdateActions()
    {
        // TO MODIFY --------------------------------------------
        sender.Send(Encoding.ASCII.GetBytes("STEPPING"));
        string reply = Receive();
        AgentMessage agentMessage = JsonConvert.DeserializeObject<AgentMessage>(reply);

        foreach (Brain brain in brains)
        {
            if (brain.brainType != BrainType.External)
            {
                continue;
            }

            string brainName = brain.gameObject.name;
            List<int> agentIds = current_agents[brainName];

            // Continuous actions occupy actionSize floats per agent; any
            // other action space type uses a single float per agent.
            int actionStride =
                brain.brainParameters.actionSpaceType == StateType.continuous
                    ? brain.brainParameters.actionSize
                    : 1;

            storedActions[brainName] =
                SplitPerAgent(agentIds, agentMessage.action, brainName, actionStride);
            storedMemories[brainName] =
                SplitPerAgent(agentIds, agentMessage.memory, brainName,
                              brain.brainParameters.memorySize);

            Dictionary<int, float> valueDict = new Dictionary<int, float>();
            for (int i = 0; i < agentIds.Count; i++)
            {
                valueDict.Add(agentIds[i], agentMessage.value[brainName][i]);
            }
            storedValues[brainName] = valueDict;
        }
    }

    /// Cuts the flat list stored under brainName in payload into consecutive
    /// chunks of stride floats and assigns the i-th chunk to the i-th agent
    /// id. The payload is only looked up when there is at least one agent.
    private static Dictionary<int, float[]> SplitPerAgent(
        List<int> agentIds,
        Dictionary<string, List<float>> payload,
        string brainName,
        int stride)
    {
        Dictionary<int, float[]> perAgent = new Dictionary<int, float[]>();
        if (agentIds.Count == 0)
        {
            return perAgent;
        }

        List<float> flat = payload[brainName];
        for (int i = 0; i < agentIds.Count; i++)
        {
            perAgent.Add(agentIds[i], flat.GetRange(i * stride, stride).ToArray());
        }
        return perAgent;
    }

    /// Returns the actions corresponding to the brain called brainName that
    /// were received through the socket.
    public Dictionary<int, float[]> GetDecidedAction(string brainName)
    {
        return storedActions[brainName];
    }

    /// Returns the memories corresponding to the brain called brainName that
    /// were received through the socket.
    public Dictionary<int, float[]> GetMemories(string brainName)
    {
        return storedMemories[brainName];
    }

    /// Returns the values corresponding to the brain called brainName that
    /// were received through the socket.
    public Dictionary<int, float> GetValues(string brainName)
    {
        return storedValues[brainName];
    }
}