
[cleanup] Add a new type hint that refers to a dictionary of BrainInfo objects as an AllBrainInfo. Propagate this hint to all methods. Some pep8 cleanups.

/develop-generalizationTraining-TrainerController
eshvk, 7 years ago
Current commit
030ac5c5
6 files changed, with 76 insertions and 68 deletions
  1. python/unityagents/brain.py (6 lines changed)
  2. python/unityagents/environment.py (40 lines changed)
  3. python/unitytrainers/bc/trainer.py (21 lines changed)
  4. python/unitytrainers/ppo/trainer.py (47 lines changed)
  5. python/unitytrainers/trainer.py (18 lines changed)
  6. python/unitytrainers/trainer_controller.py (12 lines changed)

python/unityagents/brain.py (6 lines changed)


+ from typing import Dict
class BrainInfo:
def __init__(self, observation, state, memory=None, reward=None, agents=None, local_done=None,
action=None, max_reached=None):

self.max_reached = max_reached
self.agents = agents
self.previous_actions = action
+ AllBrainInfo = Dict[str, BrainInfo]
class BrainParameters:
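
For orientation, a minimal sketch (not part of this diff) of how the new alias reads in practice; count_agents is a hypothetical helper:

from unityagents import AllBrainInfo

def count_agents(all_brain_info: AllBrainInfo) -> int:
    # the dictionary maps each brain name to the BrainInfo for that brain's agents
    return sum(len(brain_info.agents) for brain_info in all_brain_info.values())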

python/unityagents/environment.py (40 lines changed)


import subprocess
import struct
- from .brain import BrainInfo, BrainParameters
+ from .brain import BrainInfo, BrainParameters, AllBrainInfo
from .exception import UnityEnvironmentException, UnityActionException, UnityTimeOutException
from .curriculum import Curriculum

self._global_done = None
self._academy_name = p["AcademyName"]
self._log_path = p["logPath"]
- self._brains = {}
+ self._brains = AllBrainInfo()
self._brain_names = p["brainNames"]
self._external_brain_names = p["externalBrainNames"]
self._external_brain_names = [] if self._external_brain_names is None else self._external_brain_names

self._brains[self._brain_names[i]] = BrainParameters(self._brain_names[i], p["brainParameters"][i])
self._loaded = True
logger.info("\n'{}' started successfully!".format(self._academy_name))
- if (self._num_external_brains == 0):
+ if self._num_external_brains == 0:
logger.warning(" No External Brains found in the Unity Environment. "
"You will not be able to pass actions to your agent(s).")
except UnityEnvironmentException:

@property
def curriculum(self):
return self._curriculum

@staticmethod
def _process_pixels(image_bytes=None, bw=False):
"""
- Converts bytearray observation image into numpy array, resizes it, and optionally converts it to greyscale
- :param image_bytes: input bytearray corresponding to image
+ Converts byte array observation image into numpy array, re-sizes it, and optionally converts it to grey scale
+ :param image_bytes: input byte array corresponding to image
:return: processed numpy array of observation from environment
"""
s = bytearray(image_bytes)

state_dict = json.loads(state)
return state_dict
- def reset(self, train_mode=True, config=None, lesson=None):
+ def reset(self, train_mode=True, config=None, lesson=None) -> AllBrainInfo:
- :return: A Data structure corresponding to the initial reset state of the environment.
+ :return: AllBrainInfo : A Data structure corresponding to the initial reset state of the environment.
elif config != {}:
logger.info("\nAcademy Reset with parameters : \t{0}"
.format(', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))

else:
raise UnityEnvironmentException("No Unity environment is loaded.")
- def _get_state(self):
+ def _get_state(self) -> AllBrainInfo:
- :return: a dictionary BrainInfo objects.
+ :return: a dictionary of BrainInfo objects.
- self._data = {}
+ self._data = AllBrainInfo()
for index in range(self._num_brains):
state_dict = self._get_state_dict()
b = state_dict["brain_name"]

observations.append(np.array(obs_n))
- self._data[b] = BrainInfo(observations, states, memories, rewards, agents, dones, actions, max_reached=maxes)
+ self._data[b] = BrainInfo(observations, states, memories, rewards, agents,
+ dones, actions, max_reached=maxes)
try:
self._global_done = self._conn.recv(self._buffer_size).decode('utf-8') == 'True'

arr = [float(x) for x in arr]
return arr
- def step(self, action=None, memory=None, value=None):
+ def step(self, action=None, memory=None, value=None) -> AllBrainInfo:
"""
Provides the environment with an action, moves the environment dynamics forward accordingly, and returns
observation, state, and reward information to the agent.

- :return: A Data structure corresponding to the new state of the environment.
+ :return: AllBrainInfo : A Data structure corresponding to the new state of the environment.
"""
action = {} if action is None else action
memory = {} if memory is None else memory

raise UnityActionException(
"There was a mismatch between the provided memory and environment's expectation: "
"The brain {0} expected {1} memories but was given {2}"
.format(b, self._brains[b].memory_space_size * n_agent, len(memory[b])))
(self._brains[b].action_space_type == "continuous" and len(
.format(b, n_agent if self._brains[b].action_space_type == "discrete" else
str(self._brains[b].action_space_size * n_agent), self._brains[b].action_space_type,
str(action[b])))
self._conn.send(b"STEP")
self._send_action(action, memory, value)
return self._get_state()
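
A usage sketch (not from the repo) of the newly annotated reset and step, which both return an AllBrainInfo keyed by brain name; the build name "3DBall", the brain name "Ball3DBrain", and the zero actions are placeholders:

from unityagents import UnityEnvironment

env = UnityEnvironment(file_name="3DBall")      # placeholder build name
all_brain_info = env.reset(train_mode=True)     # AllBrainInfo: brain name -> BrainInfo
brain_info = all_brain_info["Ball3DBrain"]      # placeholder brain name
# flattened zero action per agent; the exact shape depends on the brain's action space
actions = {"Ball3DBrain": [0.0] * 2 * len(brain_info.agents)}
all_brain_info = env.step(action=actions)
env.close()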

python/unitytrainers/bc/trainer.py (21 lines changed)


import numpy as np
import tensorflow as tf
+ from unityagents import AllBrainInfo
from unitytrainers.bc.models import BehavioralCloningModel
from unitytrainers.buffer import Buffer
from unitytrainers.trainer import UnityTrainerException, Trainer

"""
return
- def take_action(self, all_brain_info):
+ def take_action(self, all_brain_info: AllBrainInfo):
- :param info: Current BrainInfo from environment.
- :return: a tupple containing action, memories, values and an object
+ :param all_brain_info: AllBrainInfo from environment.
+ :return: a tuple containing action, memories, values and an object
to be passed to add experiences
"""
agent_brain = all_brain_info[self.brain_name]

agent_action = self.sess.run(run_list, feed_dict)
return agent_action, None, None, None
- def add_experiences(self, info, next_info, take_action_outputs):
+ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs):
- :param info: Current BrainInfo.
- :param next_info: Next BrainInfo.
+ :param curr_info: Current AllBrainInfo (Dictionary of all current brains and corresponding BrainInfo).
+ :param next_info: Next AllBrainInfo (Dictionary of all current brains and corresponding BrainInfo).
- info_expert = info[self.brain_to_imitate]
+ info_expert = curr_info[self.brain_to_imitate]
next_info_expert = next_info[self.brain_to_imitate]
for agent_id in info_expert.agents:
if agent_id in next_info_expert.agents:

if self.use_observations:
- for i, _ in enumerate(info.observations):
+ for i, _ in enumerate(curr_info.observations):
self.training_buffer[agent_id]['observations%d' % i].append(info_expert.observations[i][idx])
if self.use_states:
self.training_buffer[agent_id]['states'].append(info_expert.states[idx])

self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] += 1
- def process_experiences(self, info):
+ def process_experiences(self, info: AllBrainInfo):
- :param info: Current BrainInfo
+ :param info: Current AllBrainInfo
"""
info_expert = info[self.brain_to_imitate]
for l in range(len(info_expert.agents)):

python/unitytrainers/ppo/trainer.py (47 lines changed)


import numpy as np
import tensorflow as tf
+ from unityagents import AllBrainInfo
from unitytrainers.buffer import Buffer
from unitytrainers.ppo.models import PPOModel
from unitytrainers.trainer import UnityTrainerException, Trainer

new_variance = var + (current_x - new_mean) * (current_x - mean)
return new_mean, new_variance
- def take_action(self, info):
+ def take_action(self, all_brain_info: AllBrainInfo):
- :param info: Current BrainInfo from environment.
- :return: a tupple containing action, memories, values and an object
+ :param all_brain_info: A dictionary of brain names and BrainInfo from environment.
+ :return: a tuple containing action, memories, values and an object
- info = info[self.brain_name]
- feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1}
+ curr_brain_info = all_brain_info[self.brain_name]
+ feed_dict = {self.model.batch_size: len(curr_brain_info.states), self.model.sequence_length: 1}
- for i, _ in enumerate(info.observations):
- feed_dict[self.model.observation_in[i]] = info.observations[i]
+ for i, _ in enumerate(curr_brain_info.observations):
+ feed_dict[self.model.observation_in[i]] = curr_brain_info.observations[i]
- feed_dict[self.model.state_in] = info.states
+ feed_dict[self.model.state_in] = curr_brain_info.states
- feed_dict[self.model.memory_in] = info.memories
+ feed_dict[self.model.memory_in] = curr_brain_info.memories
- new_mean, new_variance = self.running_average(info.states, steps, self.model.running_mean,
+ new_mean, new_variance = self.running_average(curr_brain_info.states, steps, self.model.running_mean,
self.model.running_variance)
feed_dict[self.model.new_mean] = new_mean
feed_dict[self.model.new_variance] = new_variance

else:
return run_out[self.model.output], None, run_out[self.model.value], run_out
- def add_experiences(self, info, next_info, take_action_outputs):
+ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs):
- :param info: Current BrainInfo.
- :param next_info: Next BrainInfo.
+ :param curr_info: Dictionary of all current brains and corresponding BrainInfo.
+ :param next_info: Dictionary of all current brains and corresponding BrainInfo.
- info = info[self.brain_name]
+ curr_info = curr_info[self.brain_name]
next_info = next_info[self.brain_name]
actions = take_action_outputs[self.model.output]
epsi = 0

value = take_action_outputs[self.model.value]
- for agent_id in info.agents:
+ for agent_id in curr_info.agents:
- idx = info.agents.index(agent_id)
+ idx = curr_info.agents.index(agent_id)
- if not info.local_done[idx]:
+ if not curr_info.local_done[idx]:
- for i, _ in enumerate(info.observations):
- self.training_buffer[agent_id]['observations%d' % i].append(info.observations[i][idx])
+ for i, _ in enumerate(curr_info.observations):
+ self.training_buffer[agent_id]['observations%d' % i].append(curr_info.observations[i][idx])
- self.training_buffer[agent_id]['states'].append(info.states[idx])
+ self.training_buffer[agent_id]['states'].append(curr_info.states[idx])
- self.training_buffer[agent_id]['memory'].append(info.memories[idx])
+ self.training_buffer[agent_id]['memory'].append(curr_info.memories[idx])
if self.is_continuous:
self.training_buffer[agent_id]['epsilons'].append(epsi[idx])
self.training_buffer[agent_id]['actions'].append(actions[idx])

self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] += 1
- def process_experiences(self, info):
+ def process_experiences(self, info: AllBrainInfo):
- :param info: Current BrainInfo
+ :param info: Dictionary of all current brains and corresponding BrainInfo.
"""
info = info[self.brain_name]

python/unitytrainers/trainer.py (18 lines changed)


import tensorflow as tf
- from unityagents import UnityException
+ from unityagents import UnityException, AllBrainInfo
logger = logging.getLogger("unityagents")

"""
raise UnityTrainerException("The update_last_reward method was not implemented.")
- def take_action(self, info):
+ def take_action(self, all_brain_info: AllBrainInfo):
- :param info: Current BrainInfo from environment.
- :return: a tupple containing action, memories, values and an object
+ :param all_brain_info: A dictionary of brain names and BrainInfo from environment.
+ :return: a tuple containing action, memories, values and an object
- def add_experiences(self, info, next_info, take_action_outputs):
+ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs):
- :param info: Current BrainInfo.
- :param next_info: Next BrainInfo.
+ :param curr_info: Current AllBrainInfo.
+ :param next_info: Next AllBrainInfo.
- def process_experiences(self, info):
+ def process_experiences(self, info: AllBrainInfo):
- :param info: Current BrainInfo
+ :param info: Dictionary of all current brains and corresponding BrainInfo.
"""
raise UnityTrainerException("The process_experiences method was not implemented.")

python/unitytrainers/trainer_controller.py (12 lines changed)


sess.run(init)
global_step = 0 # This is only for saving the model
self.env.curriculum.increment_lesson(self._get_progress())
- info = self.env.reset(train_mode=self.fast_simulation)
+ curr_info = self.env.reset(train_mode=self.fast_simulation)
if self.train_model:
for brain_name, trainer in self.trainers.items():
trainer.write_tensorboard_text('Hyperparameters', trainer.parameters)

self.env.curriculum.increment_lesson(self._get_progress())
- info = self.env.reset(train_mode=self.fast_simulation)
+ curr_info = self.env.reset(train_mode=self.fast_simulation)
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
# Decide and take an action

take_action_memories[brain_name],
take_action_values[brain_name],
- take_action_outputs[brain_name]) = trainer.take_action(info)
+ take_action_outputs[brain_name]) = trainer.take_action(curr_info)
- trainer.add_experiences(info, new_info, take_action_outputs[brain_name])
- info = new_info
+ trainer.add_experiences(curr_info, new_info, take_action_outputs[brain_name])
+ curr_info = new_info
- trainer.process_experiences(info)
+ trainer.process_experiences(curr_info)
if trainer.is_ready_update() and self.train_model and trainer.get_step <= trainer.get_max_steps:
# Perform gradient descent with experience buffer
trainer.update_model()
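
Condensed into a standalone sketch (a hypothetical helper, with curriculum handling, summary writing, and model updates omitted), the renamed curr_info flows through one controller step like this:

def run_one_step(env, trainers, curr_info):
    """Advance every trainer and the environment by one step; returns the next AllBrainInfo."""
    actions, memories, values, outputs = {}, {}, {}, {}
    for brain_name, trainer in trainers.items():
        (actions[brain_name],
         memories[brain_name],
         values[brain_name],
         outputs[brain_name]) = trainer.take_action(curr_info)
    new_info = env.step(actions, memories, values)
    for brain_name, trainer in trainers.items():
        trainer.add_experiences(curr_info, new_info, outputs[brain_name])
        trainer.process_experiences(new_info)
    return new_info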
