浏览代码

modified the socket to receive states and images of any size

/develop-generalizationTraining-TrainerController
vincentpierre 7 年前
当前提交
c16e0ac3
共有 2 个文件被更改,包括 22 次插入5 次删除
  1. 16
      python/unityagents/environment.py
  2. 11
      unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs

16
python/unityagents/environment.py


import os
import socket
import subprocess
import struct
from .brain import BrainInfo, BrainParameters
from .exception import UnityEnvironmentException, UnityActionException

atexit.register(self.close)
self.port = base_port + worker_id
self._buffer_size = 120000
self._buffer_size = 12000
self._loaded = False
self._open_socket = False

for k in self._resetParameters])) + '\n' + \
'\n'.join([str(self._brains[b]) for b in self._brains])
def _recv_bytes(self):
s = self._conn.recv(self._buffer_size)
message_length = struct.unpack("I", bytearray(s[:4]))[0]
s = s[4:]
while len(s) != message_length:
s += self._conn.recv(self._buffer_size)
return s
def _get_state_image(self, bw):
"""
Receives observation from socket, and confirms.

s = self._conn.recv(self._buffer_size)
s = self._recv_bytes()
s = self._process_pixels(image_bytes=s, bw=bw)
self._conn.send(b"RECEIVED")
return s

Receives dictionary of state information from socket, and confirms.
:return:
"""
state = self._conn.recv(self._buffer_size).decode('utf-8')
state = self._recv_bytes().decode('utf-8')
self._conn.send(b"RECEIVED")
state_dict = json.loads(state)
return state_dict

11
unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs


return bytes;
}
private byte[] AppendLength(byte[] input){
byte[] newArray = new byte[input.Length + 4];
input.CopyTo(newArray, 4);
System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0);
return newArray;
}
/// Collects the information from the brains and sends it accross the socket
public void giveBrainInfo(Brain brain)
{

dones = concatenatedDones
};
string envMessage = JsonConvert.SerializeObject(message, Formatting.Indented);
sender.Send(Encoding.ASCII.GetBytes(envMessage));
sender.Send(AppendLength(Encoding.ASCII.GetBytes(envMessage)));
Receive();
int i = 0;
foreach (resolution res in brain.brainParameters.cameraResolutions)

sender.Send(TexToByteArray(brain.ObservationToTex(collectedObservations[id][i], res.width, res.height)));
sender.Send(AppendLength(TexToByteArray(brain.ObservationToTex(collectedObservations[id][i], res.width, res.height))));
Receive();
}
i++;

正在加载...
取消
保存