using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Newtonsoft.Json;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.IO;

/// <summary>
/// Responsible for communication with the Python API over a TCP socket.
/// </summary>
public class ExternalCommunicator : Communicator
{
    Academy academy;

    /// <summary>
    /// Agent ids per brain name, in the order they were last sent to Python by
    /// <see cref="giveBrainInfo"/>. Vectors received in <see cref="UpdateActions"/>
    /// are indexed in this same order.
    /// </summary>
    Dictionary<string, List<int>> current_agents;

    List<Brain> brains;

    /// <summary>Per brain name: whether it has sent its state in the current round.</summary>
    Dictionary<string, bool> hasSentState;

    Dictionary<string, Dictionary<int, float[]>> storedActions;
    Dictionary<string, Dictionary<int, float[]>> storedMemories;
    Dictionary<string, Dictionary<int, float>> storedValues;

    private int comPort;
    Socket sender;
    byte[] messageHolder;
    const int messageLength = 12000;

    string logPath;

    const string api = "API-2";

    /// <summary>
    /// Per-brain step payload serialized to JSON for the Python side.
    /// Property names are the wire format and must not be renamed.
    /// </summary>
    private class StepMessage
    {
        public string brain_name { get; set; }
        public List<int> agents { get; set; }
        public List<float> states { get; set; }
        public List<float> rewards { get; set; }
        public List<float> actions { get; set; }
        public List<float> memories { get; set; }
        public List<bool> dones { get; set; }
    }

    /// <summary>
    /// Actions, memories and values received from Python, keyed by brain name.
    /// Each list holds the flattened values for all agents of that brain.
    /// </summary>
    private class AgentMessage
    {
        public Dictionary<string, List<float>> action { get; set; }
        public Dictionary<string, List<float>> memory { get; set; }
        public Dictionary<string, List<float>> value { get; set; }
    }

    /// <summary>Reply to a "CONFIG_REQUEST" message.</summary>
    private class ResetParametersMessage
    {
        public Dictionary<string, float> parameters { get; set; }
        public bool train_model { get; set; }
    }

    /// <summary>
    /// Constructor for the External Communicator.
    /// </summary>
    /// <param name="aca">The academy whose brains this communicator serves.</param>
    public ExternalCommunicator(Academy aca)
    {
        academy = aca;
        brains = new List<Brain>();
        current_agents = new Dictionary<string, List<int>>();
        hasSentState = new Dictionary<string, bool>();
        storedActions = new Dictionary<string, Dictionary<int, float[]>>();
        storedMemories = new Dictionary<string, Dictionary<int, float[]>>();
        storedValues = new Dictionary<string, Dictionary<int, float>>();
    }

    /// <summary>
    /// Registers a brain with this communicator and marks it as not yet having
    /// sent its state for the current round.
    /// </summary>
    public void SubscribeBrain(Brain brain)
    {
        brains.Add(brain);
        hasSentState[brain.gameObject.name] = false;
    }

    /// <summary>
    /// Reads the communication port from the command line.
    /// </summary>
    /// <returns>
    /// True if a port was read; false (rather than throwing) if it could not be.
    /// </returns>
    public bool CommunicatorHandShake()
    {
        try
        {
            ReadArgs();
        }
        catch
        {
            // Deliberate: any failure to read the port is reported as "no handshake".
            return false;
        }
        return true;
    }

    /// <summary>
    /// Contains the logic for the initialization of the socket: truncates the
    /// log file, starts mirroring Unity log messages into it, connects to the
    /// Python side on <c>localhost:comPort</c> and sends the academy parameters.
    /// </summary>
    public void InitializeCommunicator()
    {
        logPath = Path.Combine(Path.GetFullPath("."), "unity-environment.log");
        // Truncate the log and stamp it with the session start time.
        using (StreamWriter writer = new StreamWriter(logPath, false))
        {
            writer.WriteLine(System.DateTime.Now.ToString());
            writer.WriteLine(" ");
        }
        // Subscribe only once logPath is valid, so HandleLog never sees a null path.
        Application.logMessageReceived += HandleLog;

        messageHolder = new byte[messageLength];

        // Create a TCP/IP socket.
        sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sender.Connect("localhost", comPort);

        AcademyParameters academyParameters = new AcademyParameters();
        academyParameters.brainParameters = new List<BrainParameters>();
        academyParameters.brainNames = new List<string>();
        academyParameters.externalBrainNames = new List<string>();
        academyParameters.apiNumber = api;
        academyParameters.logPath = logPath;
        foreach (Brain b in brains)
        {
            academyParameters.brainParameters.Add(b.brainParameters);
            academyParameters.brainNames.Add(b.gameObject.name);
            if (b.brainType == BrainType.External)
            {
                academyParameters.externalBrainNames.Add(b.gameObject.name);
            }
        }
        academyParameters.AcademyName = academy.gameObject.name;
        academyParameters.resetParameters = academy.resetParameters;

        SendParameters(academyParameters);
    }

    /// <summary>
    /// Appends a Unity log message (type, message, stack trace) to the log file.
    /// </summary>
    void HandleLog(string logString, string stackTrace, LogType type)
    {
        // 'using' guarantees the file handle is released even if a write fails.
        using (StreamWriter writer = new StreamWriter(logPath, true))
        {
            writer.WriteLine(type.ToString());
            writer.WriteLine(logString);
            writer.WriteLine(stackTrace);
        }
    }

    /// <summary>
    /// Listens to the socket for a command and returns the corresponding
    /// External Command.
    /// </summary>
    /// <returns>STEP or RESET for those messages; QUIT for "QUIT" and for anything unrecognized.</returns>
    public ExternalCommand GetCommand()
    {
        switch (Receive())
        {
            case "STEP":
                return ExternalCommand.STEP;
            case "RESET":
                return ExternalCommand.RESET;
            default:
                return ExternalCommand.QUIT;
        }
    }

    /// <summary>
    /// Requests the new reset parameters from Python, updates
    /// <c>academy.isInference</c> from the <c>train_model</c> flag, and returns them.
    /// </summary>
    public Dictionary<string, float> GetResetParameters()
    {
        sender.Send(Encoding.ASCII.GetBytes("CONFIG_REQUEST"));
        ResetParametersMessage resetParams = JsonConvert.DeserializeObject<ResetParametersMessage>(Receive());
        academy.isInference = !resetParams.train_model;
        return resetParams.parameters;
    }

    /// <summary>
    /// Reads the Python-provided <c>--port</c> command-line argument into <c>comPort</c>.
    /// </summary>
    /// <exception cref="System.FormatException">No <c>--port</c> value was given, or it is not an integer.</exception>
    /// <exception cref="System.OverflowException">The port value does not fit in an int.</exception>
    private void ReadArgs()
    {
        string[] args = System.Environment.GetCommandLineArgs();
        string inputPort = "";
        // Stop one short of the end so a trailing "--port" with no value cannot
        // index past the array; it simply leaves inputPort empty.
        for (int i = 0; i < args.Length - 1; i++)
        {
            if (args[i] == "--port")
            {
                inputPort = args[i + 1];
            }
        }
        comPort = int.Parse(inputPort, System.Globalization.CultureInfo.InvariantCulture);
    }

    /// <summary>
    /// Sends the Academy parameters to the external agent as indented JSON.
    /// </summary>
    private void SendParameters(AcademyParameters envParams)
    {
        string envMessage = JsonConvert.SerializeObject(envParams, Formatting.Indented);
        sender.Send(Encoding.ASCII.GetBytes(envMessage));
    }

    /// <summary>
    /// Receives one message from the external agent. A single socket read of at
    /// most <c>messageLength</c> bytes is decoded as ASCII.
    /// </summary>
    private string Receive()
    {
        int location = sender.Receive(messageHolder);
        string message = Encoding.ASCII.GetString(messageHolder, 0, location);
        return message;
    }

    /// <summary>
    /// Ends the connection and stops mirroring Unity log messages.
    /// </summary>
    private void OnApplicationQuit()
    {
        Application.logMessageReceived -= HandleLog;
        if (sender == null)
        {
            return;
        }
        try
        {
            // Shutdown must precede Close: Close disposes the socket, after which
            // Shutdown would throw ObjectDisposedException.
            sender.Shutdown(SocketShutdown.Both);
        }
        catch (SocketException)
        {
            // Shutdown fails if the socket is no longer connected; it is closed below regardless.
        }
        finally
        {
            sender.Close();
        }
    }

    /// <summary>
    /// Encodes a texture as PNG for sending to the external agent. The texture
    /// is destroyed and unused assets are unloaded afterwards.
    /// </summary>
    private byte[] TexToByteArray(Texture2D tex)
    {
        byte[] bytes = tex.EncodeToPNG();
        Object.DestroyImmediate(tex);
        Resources.UnloadUnusedAssets();
        return bytes;
    }

    /// <summary>
    /// Prefixes <paramref name="input"/> with its length as a 4-byte int in
    /// <see cref="System.BitConverter"/> (machine) byte order.
    /// </summary>
    private byte[] AppendLength(byte[] input)
    {
        byte[] newArray = new byte[input.Length + 4];
        input.CopyTo(newArray, 4);
        System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0);
        return newArray;
    }

    /// <summary>
    /// Collects the information from the brain and sends it across the socket:
    /// one length-prefixed JSON step message, then one length-prefixed PNG per
    /// (camera resolution, agent), reading a reply after each send. When every
    /// subscribed brain has sent its state, also sends whether the academy is done.
    /// </summary>
    public void giveBrainInfo(Brain brain)
    {
        string brainName = brain.gameObject.name;
        List<int> agentIds = new List<int>(brain.agents.Keys);
        current_agents[brainName] = agentIds;

        var collectedObservations = brain.CollectObservations();
        var collectedStates = brain.CollectStates();
        var collectedRewards = brain.CollectRewards();
        var collectedMemories = brain.CollectMemories();
        var collectedDones = brain.CollectDones();
        var collectedActions = brain.CollectActions();

        List<float> concatenatedStates = new List<float>();
        List<float> concatenatedRewards = new List<float>(agentIds.Count);
        List<float> concatenatedMemories = new List<float>();
        List<bool> concatenatedDones = new List<bool>(agentIds.Count);
        List<float> concatenatedActions = new List<float>();

        // Flatten per-agent data in agent-id order; AddRange keeps this linear
        // in the total number of values.
        foreach (int id in agentIds)
        {
            concatenatedStates.AddRange(collectedStates[id]);
            concatenatedRewards.Add(collectedRewards[id]);
            concatenatedMemories.AddRange(collectedMemories[id]);
            concatenatedDones.Add(collectedDones[id]);
            concatenatedActions.AddRange(collectedActions[id]);
        }

        StepMessage message = new StepMessage()
        {
            brain_name = brainName,
            agents = agentIds,
            states = concatenatedStates,
            rewards = concatenatedRewards,
            actions = concatenatedActions,
            memories = concatenatedMemories,
            dones = concatenatedDones
        };
        string envMessage = JsonConvert.SerializeObject(message, Formatting.Indented);
        sender.Send(AppendLength(Encoding.ASCII.GetBytes(envMessage)));
        Receive();

        int i = 0;
        foreach (resolution res in brain.brainParameters.cameraResolutions)
        {
            foreach (int id in agentIds)
            {
                sender.Send(AppendLength(TexToByteArray(brain.ObservationToTex(collectedObservations[id][i], res.width, res.height))));
                Receive();
            }
            i++;
        }

        hasSentState[brainName] = true;

        if (hasSentState.Values.All(x => x))
        {
            // All subscribed brains have sent their state: report academy.done
            // and reset the flags for the next round.
            sender.Send(Encoding.ASCII.GetBytes(academy.done ? "True" : "False"));
            foreach (string k in hasSentState.Keys.ToList())
            {
                hasSentState[k] = false;
            }
        }
    }

    /// <summary>
    /// Listens for actions, memories, and values and stores them for the
    /// corresponding external brains, keyed by agent id.
    /// </summary>
    public void UpdateActions()
    {
        sender.Send(Encoding.ASCII.GetBytes("STEPPING"));
        AgentMessage agentMessage = JsonConvert.DeserializeObject<AgentMessage>(Receive());

        foreach (Brain brain in brains)
        {
            if (brain.brainType != BrainType.External)
            {
                continue;
            }

            string brainName = brain.gameObject.name;
            List<int> agentIds = current_agents[brainName];
            int actionSize = brain.brainParameters.actionSize;
            int memorySize = brain.brainParameters.memorySize;
            bool continuous = brain.brainParameters.actionSpaceType == StateType.continuous;

            Dictionary<int, float[]> actionDict = new Dictionary<int, float[]>();
            Dictionary<int, float[]> memoryDict = new Dictionary<int, float[]>();
            Dictionary<int, float> valueDict = new Dictionary<int, float>();

            // Message lookups stay inside the loop so a brain with no agents never
            // touches agentMessage.
            for (int i = 0; i < agentIds.Count; i++)
            {
                int id = agentIds[i];
                // Continuous: actionSize floats per agent. Discrete: one float per agent.
                actionDict.Add(id, continuous
                    ? agentMessage.action[brainName].GetRange(i * actionSize, actionSize).ToArray()
                    : new float[] { agentMessage.action[brainName][i] });
                memoryDict.Add(id, agentMessage.memory[brainName].GetRange(i * memorySize, memorySize).ToArray());
                valueDict.Add(id, agentMessage.value[brainName][i]);
            }

            storedActions[brainName] = actionDict;
            storedMemories[brainName] = memoryDict;
            storedValues[brainName] = valueDict;
        }
    }

    /// <summary>
    /// Returns the actions corresponding to the brain called brainName that
    /// were received through the socket.
    /// </summary>
    public Dictionary<int, float[]> GetDecidedAction(string brainName)
    {
        return storedActions[brainName];
    }

    /// <summary>
    /// Returns the memories corresponding to the brain called brainName that
    /// were received through the socket.
    /// </summary>
    public Dictionary<int, float[]> GetMemories(string brainName)
    {
        return storedMemories[brainName];
    }

    /// <summary>
    /// Returns the values corresponding to the brain called brainName that
    /// were received through the socket.
    /// </summary>
    public Dictionary<int, float> GetValues(string brainName)
    {
        return storedValues[brainName];
    }
}