using System.Collections;
using System.Collections.Generic;
using UnityEngine;

using Newtonsoft.Json;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.IO;


/// Responsible for communication with Python API.
public class ExternalCommunicator : Communicator
{

    Academy academy;

    Dictionary<string, List<int>> current_agents;

    List<Brain> brains;

    Dictionary<string, bool> hasSentState;

    Dictionary<string, Dictionary<int, float[]>> storedActions;
    Dictionary<string, Dictionary<int, float[]>> storedMemories;
    Dictionary<string, Dictionary<int, float>> storedValues;

    private int comPort;
    Socket sender;
    byte[] messageHolder;

    const int messageLength = 12000;

    StreamWriter logWriter;
    string logPath;

    const string api = "API-2";

    private class StepMessage
    {
        public string brain_name { get; set; }

        public List<int> agents { get; set; }

        public List<float> states { get; set; }

        public List<float> rewards { get; set; }

        public List<float> actions { get; set; }

        public List<float> memories { get; set; }

        public List<bool> dones { get; set; }
    }

    private class AgentMessage
    {
        public Dictionary<string, List<float>> action { get; set; }

        public Dictionary<string, List<float>> memory { get; set; }

        public Dictionary<string, List<float>> value { get; set; }

    }

    private class ResetParametersMessage
    {
        public Dictionary<string, float> parameters { get; set; }

        public bool train_model { get; set; }
    }

    /// Constructor for the External Communicator
    public ExternalCommunicator(Academy aca)
    {
        academy = aca;
        brains = new List<Brain>();
        current_agents = new Dictionary<string, List<int>>();

        hasSentState = new Dictionary<string, bool>();

        storedActions = new Dictionary<string, Dictionary<int, float[]>>();
        storedMemories = new Dictionary<string, Dictionary<int, float[]>>();
        storedValues = new Dictionary<string, Dictionary<int, float>>();
    }

    /// Registers the brain with the communicator and marks its state as
    /// not yet sent.
    public void SubscribeBrain(Brain brain)
    {
        brains.Add(brain);
        hasSentState[brain.gameObject.name] = false;
    }


    /// Reads the command-line arguments. Returns false if no usable port
    /// was supplied.
    public bool CommunicatorHandShake(){
        try
        {
            ReadArgs();
        }
        catch
        {
            return false;
        }
        return true;
    }

    /// Contains the logic for the initialization of the socket.
    public void InitializeCommunicator()
    {
        Application.logMessageReceived += HandleLog;
        logPath = Path.GetFullPath(".") + "/unity-environment.log";
        logWriter = new StreamWriter(logPath, false);
        logWriter.WriteLine(System.DateTime.Now.ToString());
        logWriter.WriteLine(" ");
        logWriter.Close();
        messageHolder = new byte[messageLength];

        // Create a TCP/IP socket.
        sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sender.Connect("localhost", comPort);

        AcademyParameters accParameters = new AcademyParameters();
        accParameters.brainParameters = new List<BrainParameters>();
        accParameters.brainNames = new List<string>();
        accParameters.externalBrainNames = new List<string>();
        accParameters.apiNumber = api;
        accParameters.logPath = logPath;
        foreach (Brain b in brains)
        {
            accParameters.brainParameters.Add(b.brainParameters);
            accParameters.brainNames.Add(b.gameObject.name);
            if (b.brainType == BrainType.External)
            {
                accParameters.externalBrainNames.Add(b.gameObject.name);
            }
        }
        accParameters.AcademyName = academy.gameObject.name;
        accParameters.resetParameters = academy.resetParameters;

        SendParameters(accParameters);
    }

    /// Appends each Unity log message (type, text, stack trace) to the log file.
    void HandleLog(string logString, string stackTrace, LogType type)
    {
        logWriter = new StreamWriter(logPath, true);
        logWriter.WriteLine(type.ToString());
        logWriter.WriteLine(logString);
        logWriter.WriteLine(stackTrace);
        logWriter.Close();
    }

    /// Listens to the socket for a command and returns the corresponding
    /// External Command.
    public ExternalCommand GetCommand()
    {
        int location = sender.Receive(messageHolder);
        string message = Encoding.ASCII.GetString(messageHolder, 0, location);
        switch (message)
        {
            case "STEP":
                return ExternalCommand.STEP;
            case "RESET":
                return ExternalCommand.RESET;
            case "QUIT":
                return ExternalCommand.QUIT;
            default:
                return ExternalCommand.QUIT;
        }
    }

    /// Listens to the socket for the new resetParameters
    public Dictionary<string, float> GetResetParameters()
    {
        sender.Send(Encoding.ASCII.GetBytes("CONFIG_REQUEST"));
        ResetParametersMessage resetParams = JsonConvert.DeserializeObject<ResetParametersMessage>(Receive());
        academy.isInference = !resetParams.train_model;
        return resetParams.parameters;
    }


    /// Used to read Python-provided environment parameters
    private void ReadArgs()
    {
        string[] args = System.Environment.GetCommandLineArgs();
        string inputPort = "";
        for (int i = 0; i < args.Length; i++)
        {
            if (args[i] == "--port")
            {
                inputPort = args[i + 1];
            }
        }

        comPort = int.Parse(inputPort);
    }

    /// Sends Academy parameters to external agent
    private void SendParameters(AcademyParameters envParams)
    {
        string envMessage = JsonConvert.SerializeObject(envParams, Formatting.Indented);
        sender.Send(Encoding.ASCII.GetBytes(envMessage));
    }

    /// Receives messages from external agent
    private string Receive()
    {
        int location = sender.Receive(messageHolder);
        string message = Encoding.ASCII.GetString(messageHolder, 0, location);
        return message;
    }


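    // Illustrative sketch, not part of the original file. TCP is a byte
    // stream, so a single Socket.Receive call (as used in Receive above) may
    // return fewer bytes than the peer sent. The hypothetical helper below
    // loops until exactly `count` bytes have arrived. Using it would require
    // the peer to announce the message length first, which is an assumption:
    // the protocol shown in this file does not do so for incoming messages.
    private static byte[] ReceiveExactly(Socket socket, int count)
    {
        byte[] buffer = new byte[count];
        int offset = 0;
        while (offset < count)
        {
            // Receive returns the number of bytes read, or 0 once the peer
            // has closed the connection.
            int read = socket.Receive(buffer, offset, count - offset, SocketFlags.None);
            if (read == 0)
            {
                throw new IOException("Socket closed before the full message arrived.");
            }
            offset += read;
        }
        return buffer;
    }
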
    /// Ends connection and closes environment
    private void OnApplicationQuit()
    {
        // Shutdown must run before Close: Close disposes the socket, and
        // calling Shutdown on a disposed socket throws.
        sender.Shutdown(SocketShutdown.Both);
        sender.Close();
    }

    /// Contains logic for converting texture into bytearray to send to
    /// external agent.
    private byte[] TexToByteArray(Texture2D tex)
    {
        byte[] bytes = tex.EncodeToPNG();
        Object.DestroyImmediate(tex);
        Resources.UnloadUnusedAssets();
        return bytes;
    }

    /// Prefixes the payload with its length as a 4-byte integer so the
    /// receiver knows how many bytes to read.
    private byte[] AppendLength(byte[] input){
        byte[] newArray = new byte[input.Length + 4];
        input.CopyTo(newArray, 4);
        System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0);
        return newArray;
    }

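    // Illustrative sketch, not part of the original file: the inverse of
    // AppendLength, shown to make the wire format explicit (4 length bytes
    // followed by the payload). The name StripLength is hypothetical.
    // BitConverter uses the platform's byte order, so this assumes the reader
    // and writer agree on endianness.
    private static byte[] StripLength(byte[] framed)
    {
        // The first 4 bytes hold the payload length.
        int length = System.BitConverter.ToInt32(framed, 0);
        byte[] payload = new byte[length];
        System.Array.Copy(framed, 4, payload, 0, length);
        return payload;
    }
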
    /// Collects the information from the brains and sends it across the socket
    public void giveBrainInfo(Brain brain)
    {
        string brainName = brain.gameObject.name;
        current_agents[brainName] = new List<int>(brain.agents.Keys);
        List<float> concatenatedStates = new List<float>();
        List<float> concatenatedRewards = new List<float>();
        List<float> concatenatedMemories = new List<float>();
        List<bool> concatenatedDones = new List<bool>();
        List<float> concatenatedActions = new List<float>();
        Dictionary<int, List<Camera>> collectedObservations = brain.CollectObservations();
        Dictionary<int, List<float>> collectedStates = brain.CollectStates();
        Dictionary<int, float> collectedRewards = brain.CollectRewards();
        Dictionary<int, float[]> collectedMemories = brain.CollectMemories();
        Dictionary<int, bool> collectedDones = brain.CollectDones();
        Dictionary<int, float[]> collectedActions = brain.CollectActions();

        // Flatten the per-agent data into single lists, in agent order.
        foreach (int id in current_agents[brainName])
        {
            concatenatedStates.AddRange(collectedStates[id]);
            concatenatedRewards.Add(collectedRewards[id]);
            concatenatedMemories.AddRange(collectedMemories[id]);
            concatenatedDones.Add(collectedDones[id]);
            concatenatedActions.AddRange(collectedActions[id]);
        }
        StepMessage message = new StepMessage()
        {
            brain_name = brainName,
            agents = current_agents[brainName],
            states = concatenatedStates,
            rewards = concatenatedRewards,
            actions = concatenatedActions,
            memories = concatenatedMemories,
            dones = concatenatedDones
        };
        string envMessage = JsonConvert.SerializeObject(message, Formatting.Indented);
        sender.Send(AppendLength(Encoding.ASCII.GetBytes(envMessage)));
        // Block until the peer replies; the reply content is ignored.
        Receive();
        int i = 0;
        foreach (resolution res in brain.brainParameters.cameraResolutions)
        {
            foreach (int id in current_agents[brainName])
            {
                sender.Send(AppendLength(TexToByteArray(brain.ObservationToTex(collectedObservations[id][i], res.width, res.height))));
                Receive();
            }
            i++;
        }

        hasSentState[brainName] = true;

        if (hasSentState.Values.All(x => x))
        {
            // if all the brains listed have sent their state
            sender.Send(Encoding.ASCII.GetBytes((academy.done ? "True" : "False")));
            List<string> brainNames = hasSentState.Keys.ToList();
            foreach (string k in brainNames)
            {
                hasSentState[k] = false;
            }
        }

    }

    /// Listens for actions, memories, and values and sends them
    /// to the corresponding brains.
    public void UpdateActions()
    {
        // TO MODIFY --------------------------------------------
        sender.Send(Encoding.ASCII.GetBytes("STEPPING"));
        string a = Receive();
        AgentMessage agentMessage = JsonConvert.DeserializeObject<AgentMessage>(a);

        foreach (Brain brain in brains)
        {
            if (brain.brainType == BrainType.External)
            {
                string brainName = brain.gameObject.name;

                // Continuous actions use actionSize floats per agent;
                // otherwise each agent gets a single value.
                Dictionary<int, float[]> actionDict = new Dictionary<int, float[]>();
                for (int i = 0; i < current_agents[brainName].Count; i++)
                {
                    if (brain.brainParameters.actionSpaceType == StateType.continuous)
                    {
                        actionDict.Add(current_agents[brainName][i],
                            agentMessage.action[brainName].GetRange(i * brain.brainParameters.actionSize, brain.brainParameters.actionSize).ToArray());
                    }
                    else
                    {
                        actionDict.Add(current_agents[brainName][i],
                            agentMessage.action[brainName].GetRange(i, 1).ToArray());
                    }
                }
                storedActions[brainName] = actionDict;

                Dictionary<int, float[]> memoryDict = new Dictionary<int, float[]>();
                for (int i = 0; i < current_agents[brainName].Count; i++)
                {
                    memoryDict.Add(current_agents[brainName][i],
                        agentMessage.memory[brainName].GetRange(i * brain.brainParameters.memorySize, brain.brainParameters.memorySize).ToArray());
                }
                storedMemories[brainName] = memoryDict;

                Dictionary<int, float> valueDict = new Dictionary<int, float>();
                for (int i = 0; i < current_agents[brainName].Count; i++)
                {
                    valueDict.Add(current_agents[brainName][i],
                        agentMessage.value[brainName][i]);
                }
                storedValues[brainName] = valueDict;
            }

        }
    }

    /// Returns the actions corresponding to the brain called brainName that
    /// were received through the socket.
    public Dictionary<int, float[]> GetDecidedAction(string brainName)
    {
        return storedActions[brainName];
    }

    /// Returns the memories corresponding to the brain called brainName that
    /// were received through the socket.
    public Dictionary<int, float[]> GetMemories(string brainName)
    {
        return storedMemories[brainName];
    }

    /// Returns the values corresponding to the brain called brainName that
    /// were received through the socket.
    public Dictionary<int, float> GetValues(string brainName)
    {
        return storedValues[brainName];
    }

}