
Add file check & reuse protobuf conversion functions (#1316)

/develop-generalizationTraining-TrainerController
GitHub · 6 years ago
Current commit: c4fa3893
6 files changed, 228 insertions and 227 deletions
  1. UnitySDK/Assets/ML-Agents/Examples/Pyramids/Brains/VisualPyramidsLearning.asset (9 changes)
  2. UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scenes/VisualPyramids.unity (9 changes)
  3. ml-agents/mlagents/envs/brain.py (113 changes)
  4. ml-agents/mlagents/envs/environment.py (232 changes)
  5. ml-agents/mlagents/envs/utilities.py (19 changes)
  6. ml-agents/mlagents/trainers/demo_loader.py (73 changes)

UnitySDK/Assets/ML-Agents/Examples/Pyramids/Brains/VisualPyramidsLearning.asset (9 changes)


 m_Name: VisualPyramidsLearning
 m_EditorClassIdentifier:
 brainParameters:
-  vectorObservationSize: 1
+  vectorObservationSize: 0
-  vectorActionSize: 01000000
-  cameraResolutions: []
+  vectorActionSize: 05000000
+  cameraResolutions:
+  - width: 84
+    height: 84
+    blackAndWhite: 0
   vectorActionDescriptions:
   -
   vectorActionSpaceType: 0
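
Note on the asset values: Unity serializes int arrays in these .asset files as a hex string, one little-endian int32 per element, so 01000000 reads as [1] and 05000000 as [5]; the change gives the visual brain a single discrete branch of five actions. A quick sketch to decode such a field (the helper name is mine):

    import struct

    def decode_int_array(hex_str):
        """Decode a Unity hex-serialized int array: 4 bytes per little-endian int32."""
        raw = bytes.fromhex(hex_str)
        return [struct.unpack_from('<i', raw, i)[0] for i in range(0, len(raw), 4)]

    print(decode_int_array('01000000'))  # [1]
    print(decode_int_array('05000000'))  # [5]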

UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scenes/VisualPyramids.unity (9 changes)


 broadcastHub:
   broadcastingBrains:
   - {fileID: 11400000, guid: 60f0ffcd08c3b43a6bdc746cfc0c4059, type: 2}
-  _brainsToControl: []
+  _brainsToControl:
+  - {fileID: 11400000, guid: 60f0ffcd08c3b43a6bdc746cfc0c4059, type: 2}
 maxSteps: 0
 trainingConfiguration:
   width: 80

     propertyPath: m_RootOrder
     value: 4
     objectReference: {fileID: 0}
   - target: {fileID: 114021675794892514, guid: 85206c30964c245ee92be7d0ed13b3b8,
       type: 2}
     propertyPath: brain
     value:
     objectReference: {fileID: 11400000, guid: 60f0ffcd08c3b43a6bdc746cfc0c4059,
       type: 2}
 m_RemovedComponents: []
 m_ParentPrefab: {fileID: 100100000, guid: 85206c30964c245ee92be7d0ed13b3b8, type: 2}
 m_IsPrefabParent: 0

ml-agents/mlagents/envs/brain.py (113 changes)


+import logging
 import numpy as np
+import io
+
+from PIL import Image
+
+logger = logging.getLogger("mlagents.envs")


 class BrainInfo:

         self.previous_text_actions = text_action
         self.action_masks = action_mask
+    @staticmethod
+    def process_pixels(image_bytes, gray_scale):
+        """
+        Converts byte array observation image into numpy array, re-sizes it,
+        and optionally converts it to grey scale
+        :param gray_scale: Whether to convert the image to grayscale.
+        :param image_bytes: input byte array corresponding to image
+        :return: processed numpy array of observation from environment
+        """
+        s = bytearray(image_bytes)
+        image = Image.open(io.BytesIO(s))
+        s = np.array(image) / 255.0
+        if gray_scale:
+            s = np.mean(s, axis=2)
+            s = np.reshape(s, [s.shape[0], s.shape[1], 1])
+        return s
+
+    @staticmethod
+    def from_agent_proto(agent_info_list, brain_params):
+        """
+        Converts list of agent infos to BrainInfo.
+        """
+        vis_obs = []
+        for i in range(brain_params.number_visual_observations):
+            obs = [BrainInfo.process_pixels(x.visual_observations[i],
+                                            brain_params.camera_resolutions[i]['blackAndWhite'])
+                   for x in agent_info_list]
+            vis_obs += [np.array(obs)]
+        if len(agent_info_list) == 0:
+            memory_size = 0
+        else:
+            memory_size = max([len(x.memories) for x in agent_info_list])
+        if memory_size == 0:
+            memory = np.zeros((0, 0))
+        else:
+            [x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list]
+            memory = np.array([x.memories for x in agent_info_list])
+        total_num_actions = sum(brain_params.vector_action_space_size)
+        mask_actions = np.ones((len(agent_info_list), total_num_actions))
+        for agent_index, agent_info in enumerate(agent_info_list):
+            if agent_info.action_mask is not None:
+                if len(agent_info.action_mask) == total_num_actions:
+                    mask_actions[agent_index, :] = [
+                        0 if agent_info.action_mask[k] else 1 for k in range(total_num_actions)]
+        if any([np.isnan(x.reward) for x in agent_info_list]):
+            logger.warning("An agent had a NaN reward for brain " + brain_params.brain_name)
+        if any([np.isnan(x.stacked_vector_observation).any() for x in agent_info_list]):
+            logger.warning("An agent had a NaN observation for brain " + brain_params.brain_name)
+        brain_info = BrainInfo(
+            visual_observation=vis_obs,
+            vector_observation=np.nan_to_num(
+                np.array([x.stacked_vector_observation for x in agent_info_list])),
+            text_observations=[x.text_observation for x in agent_info_list],
+            memory=memory,
+            reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list],
+            agents=[x.id for x in agent_info_list],
+            local_done=[x.done for x in agent_info_list],
+            vector_action=np.array([x.stored_vector_actions for x in agent_info_list]),
+            text_action=[x.stored_text_actions for x in agent_info_list],
+            max_reached=[x.max_step_reached for x in agent_info_list],
+            action_mask=mask_actions
+        )
+        return brain_info
 # Renaming of dictionary of brain name to BrainInfo for clarity

-    def __init__(self, brain_name, brain_param):
+    def __init__(self, brain_name, vector_observation_space_size, num_stacked_vector_observations,
+                 camera_resolutions, vector_action_space_size,
+                 vector_action_descriptions, vector_action_space_type):
         :param brain_name: Name of brain.
-        :param brain_param: Dictionary of brain parameters.
-        self.vector_observation_space_size = brain_param["vectorObservationSize"]
-        self.num_stacked_vector_observations = brain_param["numStackedVectorObservations"]
-        self.number_visual_observations = len(brain_param["cameraResolutions"])
-        self.camera_resolutions = brain_param["cameraResolutions"]
-        self.vector_action_space_size = brain_param["vectorActionSize"]
-        self.vector_action_descriptions = brain_param["vectorActionDescriptions"]
-        self.vector_action_space_type = ["discrete", "continuous"][brain_param["vectorActionSpaceType"]]
+        self.vector_observation_space_size = vector_observation_space_size
+        self.num_stacked_vector_observations = num_stacked_vector_observations
+        self.number_visual_observations = len(camera_resolutions)
+        self.camera_resolutions = camera_resolutions
+        self.vector_action_space_size = vector_action_space_size
+        self.vector_action_descriptions = vector_action_descriptions
+        self.vector_action_space_type = ["discrete", "continuous"][vector_action_space_type]
     def __str__(self):
         return '''Unity brain name: {}

             self.vector_action_space_type,
             str(self.vector_action_space_size),
             ', '.join(self.vector_action_descriptions))
+    @staticmethod
+    def from_proto(brain_param_proto):
+        """
+        Converts brain parameter proto to BrainParameter object.
+        :param brain_param_proto: protobuf object.
+        :return: BrainParameter object.
+        """
+        resolution = [{
+            "height": x.height,
+            "width": x.width,
+            "blackAndWhite": x.gray_scale
+        } for x in brain_param_proto.camera_resolutions]
+        brain_params = BrainParameters(brain_param_proto.brain_name,
+                                       brain_param_proto.vector_observation_size,
+                                       brain_param_proto.num_stacked_vector_observations,
+                                       resolution,
+                                       brain_param_proto.vector_action_size,
+                                       brain_param_proto.vector_action_descriptions,
+                                       brain_param_proto.vector_action_space_type)
+        return brain_params
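
For orientation, a minimal sketch of the new conversion entry point, with an illustrative proto built by hand (real BrainParametersProto messages arrive over the communicator or from a .demo file; the field values here are made up):

    from mlagents.envs.communicator_objects import BrainParametersProto
    from mlagents.envs.brain import BrainParameters

    proto = BrainParametersProto(
        brain_name="VisualPyramidsLearning",
        vector_observation_size=0,
        num_stacked_vector_observations=1,
        vector_action_size=[5],            # one discrete branch with 5 actions
        vector_action_descriptions=[""],
        vector_action_space_type=0,        # 0 = discrete, 1 = continuous
    )
    params = BrainParameters.from_proto(proto)
    print(params.vector_action_space_type)  # "discrete"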

ml-agents/mlagents/envs/environment.py (232 changes)


 import atexit
 import glob
 import io
-from .brain import BrainInfo, BrainParameters, AllBrainInfo
-from .utilities import process_pixels
+from .brain import AllBrainInfo, BrainInfo, BrainParameters
-from .communicator_objects import UnityRLInput, UnityRLOutput, AgentActionProto,\
-    EnvironmentParametersProto, UnityRLInitializationInput, UnityRLInitializationOutput,\
+from .communicator_objects import UnityRLInput, UnityRLOutput, AgentActionProto, \
+    EnvironmentParametersProto, UnityRLInitializationInput, UnityRLInitializationOutput, \
 from .socket_communicator import SocketCommunicator
-from PIL import Image

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("mlagents.envs")

         self.port = base_port + worker_id
         self._buffer_size = 12000
         self._version_ = "API-5"
-        self._loaded = False # If true, this means the environment was successfully loaded
-        self.proc1 = None # The process that is started. If None, no process was started
+        self._loaded = False  # If true, this means the environment was successfully loaded
+        self.proc1 = None  # The process that is started. If None, no process was started
-        if file_name is None and worker_id!=0:
+        if file_name is None and worker_id != 0:
             raise UnityEnvironmentException(
-                "If the environment name is None, the worker-id must be 0 in order to connect with the Editor.")
+                "If the environment name is None, "
+                "the worker-id must be 0 in order to connect with the Editor.")
         if file_name is not None:
             self.executable_launcher(file_name, docker_training, no_graphics)
         else:

         self._external_brain_names = []
         for brain_param in aca_params.brain_parameters:
             self._brain_names += [brain_param.brain_name]
-            resolution = [{
-                "height": x.height,
-                "width": x.width,
-                "blackAndWhite": x.gray_scale
-            } for x in brain_param.camera_resolutions]
-            self._brains[brain_param.brain_name] = \
-                BrainParameters(brain_param.brain_name, {
-                    "vectorObservationSize": brain_param.vector_observation_size,
-                    "numStackedVectorObservations": brain_param.num_stacked_vector_observations,
-                    "cameraResolutions": resolution,
-                    "vectorActionSize": brain_param.vector_action_size,
-                    "vectorActionDescriptions": brain_param.vector_action_descriptions,
-                    "vectorActionSpaceType": brain_param.vector_action_space_type
-                })
+            self._brains[brain_param.brain_name] = BrainParameters.from_proto(brain_param)
-        self._resetParameters = dict(aca_params.environment_parameters.float_parameters) # TODO
+        self._resetParameters = dict(aca_params.environment_parameters.float_parameters)
         logger.info("\n'{0}' started successfully!\n{1}".format(self._academy_name, str(self)))
         if self._num_external_brains == 0:
             logger.warning(" No Learning Brains set to train found in the Unity Environment. "

     def executable_launcher(self, file_name, docker_training, no_graphics):
         cwd = os.getcwd()
         file_name = (file_name.strip()
-                     .replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))
+                     .replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86',
+                                                                                             ''))
         true_filename = os.path.basename(os.path.normpath(file_name))
         logger.debug('The true file name is {}'.format(true_filename))
         launch_string = None

                 launch_string = candidates[0]
         elif platform == 'darwin':
-            candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', true_filename))
+            candidates = glob.glob(
+                os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', true_filename))
-                candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', true_filename))
+                candidates = glob.glob(
+                    os.path.join(file_name + '.app', 'Contents', 'MacOS', true_filename))
-                candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', '*'))
+                candidates = glob.glob(
+                    os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', '*'))
             if len(candidates) == 0:
                 candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', '*'))
             if len(candidates) > 0:

         if not docker_training:
             if no_graphics:
                 self.proc1 = subprocess.Popen(
-                    [launch_string,'-nographics', '-batchmode',
+                    [launch_string, '-nographics', '-batchmode',
                      '--port', str(self.port)])
             else:
                 self.proc1 = subprocess.Popen(

         Number of Brains: {1}
         Number of Training Brains : {2}
         Reset Parameters :\n\t\t{3}'''.format(self._academy_name, str(self._num_brains),
-                                              str(self._num_external_brains),
-                                              "\n\t\t".join([str(k) + " -> " + str(self._resetParameters[k])
-                                                             for k in self._resetParameters])) + '\n' + \
+                                              str(self._num_external_brains),
+                                              "\n\t\t".join(
+                                                  [str(k) + " -> " + str(self._resetParameters[k])
+                                                   for k in self._resetParameters])) + '\n' + \

-        :return: AllBrainInfo : A Data structure corresponding to the initial reset state of the environment.
+        :return: AllBrainInfo : A data structure corresponding to the initial reset state of the environment.
"""
if config is None:
config = self._resetParameters

raise UnityEnvironmentException(
"The value for parameter '{0}'' must be an Integer or a Float.".format(k))
else:
raise UnityEnvironmentException("The parameter '{0}' is not a valid parameter.".format(k))
raise UnityEnvironmentException(
"The parameter '{0}' is not a valid parameter.".format(k))
if self._loaded:
outputs = self.communicator.exchange(

         else:
             raise UnityEnvironmentException("No Unity environment is loaded.")

     def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo:
-        Provides the environment with an action, moves the environment dynamics forward accordingly, and returns
-        observation, state, and reward information to the agent.
-        :param vector_action: Agent's vector action to send to environment. Can be a scalar or vector of int/floats.
-        :param memory: Vector corresponding to memory used for RNNs, frame-stacking, or other auto-regressive process.
+        Provides the environment with an action, moves the environment dynamics forward accordingly,
+        and returns observation, state, and reward information to the agent.
+        :param value: Value estimates provided by agents.
+        :param vector_action: Agent's vector action. Can be a scalar or vector of int/floats.
+        :param memory: Vector corresponding to memory used for recurrent policies.
         :param text_action: Text action to send to environment for.
         :return: AllBrainInfo : A Data structure corresponding to the new state of the environment.
         """

         value = {} if value is None else value

         # Check that environment is loaded, and episode is currently running.
         if self._loaded and not self._global_done and self._global_done is not None:
             if isinstance(vector_action, (int, np.int_, float, np.float_, list, np.ndarray)):
                 if self._num_external_brains == 1:

                     raise UnityActionException(
                         "There are no external brains in the environment, "
                         "step cannot take a memory input")
             if isinstance(text_action, (str, list, np.ndarray)):
                 if self._num_external_brains == 1:
                     text_action = {self._external_brain_names[0]: text_action}

                     raise UnityActionException(
                         "There are no external brains in the environment, "
                         "step cannot take a value input")
             if isinstance(value, (int, np.int_, float, np.float_, list, np.ndarray)):
                 if self._num_external_brains == 1:
                     value = {self._external_brain_names[0]: value}

                         "There are no external brains in the environment, "
                         "step cannot take a value input")
-        for brain_name in list(vector_action.keys()) + list(memory.keys()) + list(text_action.keys()):
+        for brain_name in list(vector_action.keys()) + list(memory.keys()) + list(
+                text_action.keys()):

-            for b in self._external_brain_names:
-                n_agent = self._n_agents[b]
-                if b not in vector_action:
-                    # raise UnityActionException("You need to input an action for the brain {0}".format(b))
-                    if self._brains[b].vector_action_space_type == "discrete":
-                        vector_action[b] = [0.0] * n_agent * len(self._brains[b].vector_action_space_size)
+            for brain_name in self._external_brain_names:
+                n_agent = self._n_agents[brain_name]
+                if brain_name not in vector_action:
+                    if self._brains[brain_name].vector_action_space_type == "discrete":
+                        vector_action[brain_name] = [0.0] * n_agent * len(
+                            self._brains[brain_name].vector_action_space_size)
-                        vector_action[b] = [0.0] * n_agent * self._brains[b].vector_action_space_size[0]
+                        vector_action[brain_name] = [0.0] * n_agent * \
+                                                    self._brains[
+                                                        brain_name].vector_action_space_size[0]
-                vector_action[b] = self._flatten(vector_action[b])
-                if b not in memory:
-                    memory[b] = []
+                vector_action[brain_name] = self._flatten(vector_action[brain_name])
+                if brain_name not in memory:
+                    memory[brain_name] = []
-                    if memory[b] is None:
-                        memory[b] = []
+                    if memory[brain_name] is None:
+                        memory[brain_name] = []
-                    memory[b] = self._flatten(memory[b])
-                if b not in text_action:
-                    text_action[b] = [""] * n_agent
+                    memory[brain_name] = self._flatten(memory[brain_name])
+                if brain_name not in text_action:
+                    text_action[brain_name] = [""] * n_agent
-                    if text_action[b] is None:
-                        text_action[b] = [""] * n_agent
-                    if isinstance(text_action[b], str):
-                        text_action[b] = [text_action[b]] * n_agent
-                if not ((len(text_action[b]) == n_agent) or len(text_action[b]) == 0):
+                    if text_action[brain_name] is None:
+                        text_action[brain_name] = [""] * n_agent
+                    if isinstance(text_action[brain_name], str):
+                        text_action[brain_name] = [text_action[brain_name]] * n_agent
+                number_text_actions = len(text_action[brain_name])
+                if not ((number_text_actions == n_agent) or number_text_actions == 0):
-                        "There was a mismatch between the provided text_action and environment's expectation: "
+                        "There was a mismatch between the provided text_action and "
+                        "the environment's expectation: "
-                        b, n_agent, len(text_action[b])))
+                        brain_name, n_agent, number_text_actions))
-                if not ((self._brains[b].vector_action_space_type == "discrete" and len(
-                        vector_action[b]) == n_agent * len(self._brains[b].vector_action_space_size)) or
-                        (self._brains[b].vector_action_space_type == "continuous" and len(
-                            vector_action[b]) == self._brains[b].vector_action_space_size[0] * n_agent)):
+                discrete_check = self._brains[brain_name].vector_action_space_type == "discrete"
+                expected_discrete_size = n_agent * len(
+                    self._brains[brain_name].vector_action_space_size)
+                continuous_check = self._brains[brain_name].vector_action_space_type == "continuous"
+                expected_continuous_size = self._brains[brain_name].vector_action_space_size[
+                                               0] * n_agent
+                if not ((discrete_check and len(
+                        vector_action[brain_name]) == expected_discrete_size) or
+                        (continuous_check and len(
+                            vector_action[brain_name]) == expected_continuous_size)):
-                        "There was a mismatch between the provided action and environment's expectation: "
+                        "There was a mismatch between the provided action and "
+                        "the environment's expectation: "
-                        .format(b, str(len(self._brains[b].vector_action_space_size) * n_agent)
-                                if self._brains[b].vector_action_space_type == "discrete"
-                                else str(self._brains[b].vector_action_space_size[0] * n_agent),
-                                self._brains[b].vector_action_space_type,
-                                str(vector_action[b])))
+                        .format(brain_name, str(expected_discrete_size)
+                                if discrete_check
+                                else str(expected_continuous_size),
+                                self._brains[brain_name].vector_action_space_type,
+                                str(vector_action[brain_name])))

-                self._generate_step_input(vector_action, memory, text_action, value)
-            )
+                self._generate_step_input(vector_action, memory, text_action, value))
-            s = self._get_state(rl_output)
-            self._global_done = s[1]
+            state = self._get_state(rl_output)
+            self._global_done = state[1]
-                self._n_agents[_b] = len(s[0][_b].agents)
-            return s[0]
+                self._n_agents[_b] = len(state[0][_b].agents)
+            return state[0]
-            raise UnityActionException("The episode is completed. Reset the environment with 'reset()'")
+            raise UnityActionException(
+                "The episode is completed. Reset the environment with 'reset()'")
-                "You cannot conduct step without first calling reset. Reset the environment with 'reset()'")
+                "You cannot conduct step without first calling reset. "
+                "Reset the environment with 'reset()'")
     def close(self):
         """

         """
         _data = {}
         global_done = output.global_done
-        for b in output.agentInfos:
-            agent_info_list = output.agentInfos[b].value
-            vis_obs = []
-            for i in range(self.brains[b].number_visual_observations):
-                obs = [process_pixels(x.visual_observations[i],
-                                      self.brains[b].camera_resolutions[i]['blackAndWhite'])
-                       for x in agent_info_list]
-                vis_obs += [np.array(obs)]
-            if len(agent_info_list) == 0:
-                memory_size = 0
-            else:
-                memory_size = max([len(x.memories) for x in agent_info_list])
-            if memory_size == 0:
-                memory = np.zeros((0, 0))
-            else:
-                [x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list]
-                memory = np.array([x.memories for x in agent_info_list])
-            total_num_actions = sum(self.brains[b].vector_action_space_size)
-            mask_actions = np.ones((len(agent_info_list), total_num_actions))
-            for agent_index, agent_info in enumerate(agent_info_list):
-                if agent_info.action_mask is not None:
-                    if len(agent_info.action_mask) == total_num_actions:
-                        mask_actions[agent_index, :] = [
-                            0 if agent_info.action_mask[k] else 1 for k in range(total_num_actions)]
-            if any([np.isnan(x.reward) for x in agent_info_list]):
-                logger.warning("An agent had a NaN reward for brain "+b)
-            if any([np.isnan(x.stacked_vector_observation).any() for x in agent_info_list]):
-                logger.warning("An agent had a NaN observation for brain " + b)
-            _data[b] = BrainInfo(
-                visual_observation=vis_obs,
-                vector_observation=np.nan_to_num(np.array([x.stacked_vector_observation for x in agent_info_list])),
-                text_observations=[x.text_observation for x in agent_info_list],
-                memory=memory,
-                reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list],
-                agents=[x.id for x in agent_info_list],
-                local_done=[x.done for x in agent_info_list],
-                vector_action=np.array([x.stored_vector_actions for x in agent_info_list]),
-                text_action=[x.stored_text_actions for x in agent_info_list],
-                max_reached=[x.max_step_reached for x in agent_info_list],
-                action_mask=mask_actions
-            )
+        for brain_name in output.agentInfos:
+            agent_info_list = output.agentInfos[brain_name].value
+            _data[brain_name] = BrainInfo.from_agent_proto(agent_info_list,
+                                                           self.brains[brain_name])
         return _data, global_done
     def _generate_step_input(self, vector_action, memory, text_action, value) -> UnityRLInput:

             _m_s = len(memory[b]) // n_agents
             for i in range(n_agents):
                 action = AgentActionProto(
-                    vector_actions=vector_action[b][i*_a_s: (i+1)*_a_s],
-                    memories=memory[b][i*_m_s: (i+1)*_m_s],
+                    vector_actions=vector_action[b][i * _a_s: (i + 1) * _a_s],
+                    memories=memory[b][i * _m_s: (i + 1) * _m_s],
                     text_actions=text_action[b][i],
                 )
                 if b in value:

         rl_in.command = 1
         return self.wrap_unity_input(rl_in)

-    def send_academy_parameters(self, init_parameters: UnityRLInitializationInput) -> UnityRLInitializationOutput:
+    def send_academy_parameters(self,
+                                init_parameters: UnityRLInitializationInput) -> UnityRLInitializationOutput:
         inputs = UnityInput()
         inputs.rl_initialization_input.CopyFrom(init_parameters)
         return self.communicator.initialize(inputs).rl_initialization_output
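
Taken together, a minimal driving loop against this 0.5-era API looks roughly like the following (a sketch: file_name=None attaches to the Editor, and the single-branch discrete action shape follows the size contract shown earlier):

    from mlagents.envs.environment import UnityEnvironment

    env = UnityEnvironment(file_name=None, worker_id=0)
    brain_name = env.external_brain_names[0]
    info = env.reset(train_mode=True)[brain_name]
    for _ in range(10):
        # One discrete action per agent, matching the expected-size check in step().
        actions = [0] * len(info.agents)
        info = env.step(vector_action={brain_name: actions})[brain_name]
    env.close()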

ml-agents/mlagents/envs/utilities.py (19 changes)


-from PIL import Image
-import numpy as np
-import io

-def process_pixels(image_bytes, gray_scale):
-    """
-    Converts byte array observation image into numpy array, re-sizes it,
-    and optionally converts it to grey scale
-    :param image_bytes: input byte array corresponding to image
-    :return: processed numpy array of observation from environment
-    """
-    s = bytearray(image_bytes)
-    image = Image.open(io.BytesIO(s))
-    s = np.array(image) / 255.0
-    if gray_scale:
-        s = np.mean(s, axis=2)
-        s = np.reshape(s, [s.shape[0], s.shape[1], 1])
-    return s
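
This module is deleted outright; its only function now lives on as BrainInfo.process_pixels (see brain.py above). A quick round-trip check of that behavior, with a synthetic PNG standing in for the bytes the engine sends (shapes are illustrative):

    import io
    import numpy as np
    from PIL import Image
    from mlagents.envs.brain import BrainInfo

    # Encode a random 84x84 RGB image to PNG bytes, as the engine would send.
    rgb = (np.random.rand(84, 84, 3) * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(rgb).save(buf, format='PNG')

    color = BrainInfo.process_pixels(buf.getvalue(), gray_scale=False)
    gray = BrainInfo.process_pixels(buf.getvalue(), gray_scale=True)
    print(color.shape, gray.shape)  # (84, 84, 3) (84, 84, 1)
    assert color.max() <= 1.0       # pixel values rescaled to [0, 1]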

ml-agents/mlagents/trainers/demo_loader.py (73 changes)


 import numpy as np
 import os
-from mlagents.envs.utilities import process_pixels
 from mlagents.envs.communicator_objects import *
 from google.protobuf.internal.decoder import _DecodeVarint32

-def brain_param_proto_to_obj(brain_param_proto):
-    resolution = [{
-        "height": x.height,
-        "width": x.width,
-        "blackAndWhite": x.gray_scale
-    } for x in brain_param_proto.camera_resolutions]
-    brain_params = BrainParameters(brain_param_proto.brain_name, {
-        "vectorObservationSize": brain_param_proto.vector_observation_size,
-        "numStackedVectorObservations": brain_param_proto.num_stacked_vector_observations,
-        "cameraResolutions": resolution,
-        "vectorActionSize": brain_param_proto.vector_action_size,
-        "vectorActionDescriptions": brain_param_proto.vector_action_descriptions,
-        "vectorActionSpaceType": brain_param_proto.vector_action_space_type
-    })
-    return brain_params
-
-
-def agent_info_proto_to_brain_info(agent_info, brain_params):
-    vis_obs = []
-    agent_info_list = [agent_info]
-    for i in range(brain_params.number_visual_observations):
-        obs = [process_pixels(x.visual_observations[i],
-                              brain_params.camera_resolutions[i]['blackAndWhite'])
-               for x in agent_info_list]
-        vis_obs += [np.array(obs)]
-    if len(agent_info_list) == 0:
-        memory_size = 0
-    else:
-        memory_size = max([len(x.memories) for x in agent_info_list])
-    if memory_size == 0:
-        memory = np.zeros((0, 0))
-    else:
-        [x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list]
-        memory = np.array([x.memories for x in agent_info_list])
-    total_num_actions = sum(brain_params.vector_action_space_size)
-    mask_actions = np.ones((len(agent_info_list), total_num_actions))
-    for agent_index, agent_info in enumerate(agent_info_list):
-        if agent_info.action_mask is not None:
-            if len(agent_info.action_mask) == total_num_actions:
-                mask_actions[agent_index, :] = [
-                    0 if agent_info.action_mask[k] else 1 for k in range(total_num_actions)]
-    if any([np.isnan(x.reward) for x in agent_info_list]):
-        logger.warning("An agent had a NaN reward.")
-    if any([np.isnan(x.stacked_vector_observation).any() for x in agent_info_list]):
-        logger.warning("An agent had a NaN observation.")
-    brain_info = BrainInfo(
-        visual_observation=vis_obs,
-        vector_observation=np.nan_to_num(
-            np.array([x.stacked_vector_observation for x in agent_info_list])),
-        text_observations=[x.text_observation for x in agent_info_list],
-        memory=memory,
-        reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list],
-        agents=[x.id for x in agent_info_list],
-        local_done=[x.done for x in agent_info_list],
-        vector_action=np.array([x.stored_vector_actions for x in agent_info_list]),
-        text_action=[x.stored_text_actions for x in agent_info_list],
-        max_reached=[x.max_step_reached for x in agent_info_list],
-        action_mask=mask_actions
-    )
-    return brain_info
 def make_demo_buffer(brain_infos, brain_params, sequence_length):
     # Create and populate buffer using experiences
     demo_buffer = Buffer()

     :param file_path: Location of demonstration file (.demo).
     :return: BrainParameter and list of BrainInfos containing demonstration data.
     """
     # First 32 bytes of file dedicated to meta-data.
+    if not os.path.isfile(file_path):
+        raise FileNotFoundError("The demonstration file {} does not exist.".format(file_path))
+    file_extension = pathlib.Path(file_path).suffix
+    if file_extension != '.demo':
+        raise ValueError("The file is not a '.demo' file. Please provide a file with the "

         if obs_decoded == 1:
             brain_param_proto = BrainParametersProto()
             brain_param_proto.ParseFromString(data[pos:pos + next_pos])
-            brain_params = brain_param_proto_to_obj(brain_param_proto)
+            brain_params = BrainParameters.from_proto(brain_param_proto)

-            brain_info = agent_info_proto_to_brain_info(agent_info, brain_params)
+            brain_info = BrainInfo.from_agent_proto([agent_info], brain_params)
             brain_infos.append(brain_info)
             if len(brain_infos) == total_expected:
                 break
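
The loader walks a byte buffer of varint-length-prefixed protobuf messages (meta-data, then brain parameters, then agent infos). A minimal sketch of that framing, reusing the _DecodeVarint32 helper the file already imports (the generator name is mine):

    from google.protobuf.internal.decoder import _DecodeVarint32

    def iter_delimited(data, pos=0):
        """Yield (start, end) byte ranges of varint-length-prefixed messages."""
        while pos < len(data):
            length, pos = _DecodeVarint32(data, pos)  # payload length, payload start
            yield pos, pos + length
            pos += length

    # Each slice can then be parsed with proto.ParseFromString(data[start:end]),
    # as load_demonstration does above for BrainParametersProto and AgentInfoProto.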
