
Minor Optimizations (#836)

/develop-generalizationTraining-TrainerController
Arthur Juliani, 6 years ago
Current commit
5d402be9
9 files changed, 170 insertions and 182 deletions
  1. docs/Training-ML-Agents.md (1 change)
  2. python/learn.py (7 changes)
  3. python/unityagents/environment.py (24 changes)
  4. python/unitytrainers/bc/trainer.py (19 changes)
  5. python/unitytrainers/buffer.py (14 changes)
  6. python/unitytrainers/models.py (15 changes)
  7. python/unitytrainers/ppo/trainer.py (234 changes)
  8. python/unitytrainers/trainer.py (23 changes)
  9. python/unitytrainers/trainer_controller.py (15 changes)

docs/Training-ML-Agents.md (1 change)


* `--train` – Specifies whether to train the model or only run it in inference mode. When training, **always** use the `--train` option.
* `--worker-id=<n>` – When you are running more than one training environment at the same time, assign each a unique worker-id number. The worker-id is added to the communication port opened between the current instance of learn.py and the ExternalCommunicator object in the Unity environment. Defaults to 0.
* `--docker-target-name=<dt>` – The Docker Volume on which to store curriculum, executable and model files. See [Using Docker](Using-Docker.md).
* `--no-graphics` - Specify this option to run the Unity executable in `-batchmode` without initializing the graphics driver. Use this only if your training doesn't involve visual observations (reading from pixels). See [here](https://docs.unity3d.com/Manual/CommandLineArguments.html) for more details.
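For example, to launch a second concurrent training run without rendering, an invocation might look like `python3 learn.py <env_name> --train --worker-id=1 --no-graphics` (assuming the usual `learn.py` entry point described below).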
### Training config file

python/learn.py (7 changes)


--slow Whether to run the game at training speed [default: False].
--train Whether to train model, or only run inference [default: False].
--worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0].
--docker-target-name=<dt> Docker Volume to store curriculum, executable and model files [default: Empty].
--no-graphics Whether to run the Unity simulator in no-graphics mode [default: False].
'''
options = docopt(_USAGE)

curriculum_file = None
lesson = int(options['--lesson'])
fast_simulation = not bool(options['--slow'])
no_graphics = options['--no-graphics']
# Constants
# Assumption that this yaml is present in same dir as this file

tc = TrainerController(env_path, run_id, save_freq, curriculum_file, fast_simulation, load_model, train_model,
worker_id, keep_checkpoints, lesson, seed, docker_target_name, TRAINER_CONFIG_PATH)
worker_id, keep_checkpoints, lesson, seed, docker_target_name, TRAINER_CONFIG_PATH,
no_graphics)
tc.start_learning()

python/unityagents/environment.py (24 changes)


class UnityEnvironment(object):
def __init__(self, file_name=None, worker_id=0,
base_port=5005, curriculum=None,
seed=0, docker_training=False):
seed=0, docker_training=False, no_graphics=False):
"""
Starts a new unity environment and establishes a connection with the environment.
Notice: Currently communication between Unity and Python takes place over an open socket without authentication.

:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
:param docker_training: Informs this class whether the process is being run within a container.
:param no_graphics: Whether to run the Unity simulator in no-graphics mode
"""
atexit.register(self._close)

self._loaded = False # If true, this means the environment was successfully loaded
self.proc1 = None # The process that is started. If None, no process was started
self.executable_launcher(file_name, docker_training)
self.executable_launcher(file_name, docker_training, no_graphics)
else:
logger.info("Ready to connect with the Editor.")
self._loaded = True

def external_brain_names(self):
return self._external_brain_names
def executable_launcher(self, file_name, docker_training):
def executable_launcher(self, file_name, docker_training, no_graphics):
cwd = os.getcwd()
file_name = (file_name.strip()
.replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))

logger.debug("This is the launch string {}".format(launch_string))
# Launch Unity environment
if not docker_training:
self.proc1 = subprocess.Popen(
[launch_string,
'--port', str(self.port)])
if no_graphics:
self.proc1 = subprocess.Popen(
[launch_string, '-nographics', '-batchmode',
'--port', str(self.port)])
else:
self.proc1 = subprocess.Popen(
[launch_string, '--port', str(self.port)])
else:
"""
Comments for future maintenance:

agent_info_list = output.agentInfos[b].value
vis_obs = []
for i in range(self.brains[b].number_visual_observations):
obs = [
self._process_pixels(x.visual_observations[i], self.brains[b].camera_resolutions[i]['blackAndWhite'])
obs = [self._process_pixels(x.visual_observations[i],
self.brains[b].camera_resolutions[i]['blackAndWhite'])
for x in agent_info_list]
vis_obs += [np.array(obs)]
if len(agent_info_list) == 0:

if memory_size == 0:
memory = np.zeros((0,0))
memory = np.zeros((0, 0))
else:
[x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list]
memory = np.array([x.memories for x in agent_info_list])
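As a usage sketch of the new flag (the executable name below is a placeholder, not from this PR), constructing the environment with `no_graphics=True` makes `executable_launcher` start the player with Unity's `-nographics -batchmode` arguments in addition to `--port`:

from unityagents import UnityEnvironment

# Placeholder executable; with no_graphics=True the subprocess is launched
# with '-nographics' and '-batchmode' alongside '--port'.
env = UnityEnvironment(file_name="3DBall", worker_id=0, no_graphics=True)
brain_infos = env.reset(train_mode=True)   # AllBrainInfo keyed by brain name
env.close()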

python/unitytrainers/bc/trainer.py (19 changes)


normalize=False,
use_recurrent=trainer_parameters['use_recurrent'],
brain=self.brain)
self.inference_run_list = [self.model.sample_action]
if self.use_recurrent:
self.inference_run_list += [self.model.memory_out]
def __str__(self):

else:
return 0
def increment_step(self):
def increment_step_and_update_last_reward(self):
Increment the step count of the trainer
Increment the step count of the trainer and updates the last reward
def update_last_reward(self):
"""
Updates the last reward
"""
return
def take_action(self, all_brain_info: AllBrainInfo):

agent_brain = all_brain_info[self.brain_name]
feed_dict = {self.model.dropout_rate: 1.0, self.model.sequence_length: 1}
run_list = [self.model.sample_action]
if self.use_observations:
for i, _ in enumerate(agent_brain.visual_observations):
feed_dict[self.model.visual_in[i]] = agent_brain.visual_observations[i]

if agent_brain.memories.shape[1] == 0:
agent_brain.memories = np.zeros((len(agent_brain.agents), self.m_size))
feed_dict[self.model.memory_in] = agent_brain.memories
run_list += [self.model.memory_out]
agent_action, memories = self.sess.run(run_list, feed_dict)
agent_action, memories = self.sess.run(self.inference_run_list, feed_dict)
agent_action = self.sess.run(run_list, feed_dict)
agent_action = self.sess.run(self.inference_run_list, feed_dict)
return agent_action, None, None, None
def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs):
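The optimization in this file (and, more extensively, in the PPO trainer below) is to assemble the list of tensors to fetch once in the constructor rather than rebuilding it on every `take_action` call. A minimal TF1-style sketch of that pattern, using illustrative constants in place of the model's tensors:

import tensorflow as tf

# Stand-ins for self.model.sample_action / self.model.memory_out.
sample_action = tf.constant([0])
memory_out = tf.constant([0.0])
use_recurrent = True

# Built once at construction time ...
inference_run_list = [sample_action]
if use_recurrent:
    inference_run_list += [memory_out]

with tf.Session() as sess:
    # ... and reused unchanged on every step.
    outputs = sess.run(inference_run_list)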

python/unitytrainers/buffer.py (14 changes)


class AgentBufferField(list):
"""
AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to his
AgentBufferField with the append method.
"""

self[:] = []
self[:] = list(np.array(data))
def get_batch(self, batch_size=None, training_length=None, sequential=True):
def get_batch(self, batch_size=None, training_length=1, sequential=True):
:param batch_size: The number of elements to retrieve. If None:
:param sequential: If true and training_length is not None: the elements
if training_length is None:
# When the training length is None, the method returns a list of elements,
if training_length == 1:
# When the training length is 1, the method returns a list of elements,
# not a list of sequences of elements.
if batch_size is None:
# If batch_size is None : All the elements of the AgentBufferField are returned.

def check_length(self, key_list):
"""
Some methods will require that some fields have the same length.
check_length will return true if the fields in key_list
have the same length.
:param key_list: The fields which length will be compared
"""

python/unitytrainers/models.py (15 changes)


initializer=tf.zeros_initializer())
self.running_variance = tf.get_variable("running_variance", [self.o_size], trainable=False,
dtype=tf.float32, initializer=tf.ones_initializer())
self.new_mean = tf.placeholder(shape=[self.o_size], dtype=tf.float32, name='new_mean')
self.new_variance = tf.placeholder(shape=[self.o_size], dtype=tf.float32, name='new_variance')
self.update_mean = tf.assign(self.running_mean, self.new_mean)
self.update_variance = tf.assign(self.running_variance, self.new_variance)
self.update_mean, self.update_variance = self.create_normalizer_update(self.vector_in)
self.normalized_state = tf.clip_by_value((self.vector_in - self.running_mean) / tf.sqrt(
self.running_variance / (tf.cast(self.global_step, tf.float32) + 1)), -5, 5,

else:
self.vector_in = tf.placeholder(shape=[None, 1], dtype=tf.int32, name='vector_observation')
return self.vector_in
def create_normalizer_update(self, vector_input):
mean_current_observation = tf.reduce_mean(vector_input, axis=0)
new_mean = self.running_mean + (mean_current_observation - self.running_mean) / \
tf.cast(self.global_step + 1, tf.float32)
new_variance = self.running_variance + (mean_current_observation - new_mean) * \
(mean_current_observation - self.running_mean)
update_mean = tf.assign(self.running_mean, new_mean)
update_variance = tf.assign(self.running_variance, new_variance)
return update_mean, update_variance
@staticmethod
def create_continuous_observation_encoder(observation_input, h_size, activation, num_layers, scope, reuse):
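The new `create_normalizer_update` builds the incremental mean/variance update directly into the graph, replacing the `new_mean`/`new_variance` placeholders that previously had to be fed from Python. A NumPy sketch of the same arithmetic (the function name and arguments here are illustrative):

import numpy as np

def normalizer_update(vector_input, running_mean, running_variance, step):
    # Mean of the current batch of vector observations.
    mean_current_observation = vector_input.mean(axis=0)
    # Incremental updates matching the tf.assign ops above.
    new_mean = running_mean + (mean_current_observation - running_mean) / (step + 1)
    new_variance = running_variance + \
        (mean_current_observation - new_mean) * (mean_current_observation - running_mean)
    return new_mean, new_variance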

python/unitytrainers/ppo/trainer.py (234 changes)


self.use_recurrent = trainer_parameters["use_recurrent"]
self.use_curiosity = bool(trainer_parameters['use_curiosity'])
self.sequence_length = 1
self.step = 0
if self.use_recurrent:
if self.m_size == 0:
raise UnityTrainerException("The memory size for brain {0} is 0 even though the trainer uses recurrent."
.format(brain_name))

self.summary_writer = tf.summary.FileWriter(self.summary_path)
self.inference_run_list = [self.model.output, self.model.all_probs, self.model.value,
self.model.entropy, self.model.learning_rate]
if self.is_continuous_action:
self.inference_run_list.append(self.model.output_pre)
if self.use_recurrent:
self.inference_run_list.extend([self.model.memory_out])
if (self.is_training and self.is_continuous_observation and
self.use_vector_obs and self.trainer_parameters['normalize']):
self.inference_run_list.extend([self.model.update_mean, self.model.update_variance])
return '''Hypermarameters for the PPO Trainer of brain {0}: \n{1}'''.format(
return '''Hyperparameters for the PPO Trainer of brain {0}: \n{1}'''.format(
self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys]))
@property

Returns the number of steps the trainer has performed
:return: the step count of the trainer
"""
return self.sess.run(self.model.global_step)
return self.step
@property
def get_last_reward(self):

"""
return self.sess.run(self.model.last_reward)
def increment_step(self):
def increment_step_and_update_last_reward(self):
Increment the step count of the trainer
"""
self.sess.run(self.model.increment_step)
def update_last_reward(self):
"""
Updates the last reward
Increment the step count of the trainer and updates the last reward
self.sess.run(self.model.update_reward, feed_dict={self.model.new_reward: mean_reward})
def running_average(self, data, steps, running_mean, running_variance):
"""
Computes new running mean and variances.
:param data: New piece of data.
:param steps: Total number of data so far.
:param running_mean: TF op corresponding to stored running mean.
:param running_variance: TF op corresponding to stored running variance.
:return: New mean and variance values.
"""
mean, var = self.sess.run([running_mean, running_variance])
current_x = np.mean(data, axis=0)
new_mean = mean + (current_x - mean) / (steps + 1)
new_variance = var + (current_x - new_mean) * (current_x - mean)
return new_mean, new_variance
self.sess.run([self.model.update_reward,
self.model.increment_step],
feed_dict={self.model.new_reward: mean_reward})
else:
self.sess.run(self.model.increment_step)
self.step = self.sess.run(self.model.global_step)
def take_action(self, all_brain_info: AllBrainInfo):
"""

to be passed to add experiences
"""
steps = self.get_step
feed_dict = {self.model.batch_size: len(curr_brain_info.vector_observations), self.model.sequence_length: 1}
run_list = [self.model.output, self.model.all_probs, self.model.value, self.model.entropy,
self.model.learning_rate]
if self.is_continuous_action:
run_list.append(self.model.output_pre)
feed_dict = {self.model.batch_size: len(curr_brain_info.vector_observations),
self.model.sequence_length: 1}
feed_dict[self.model.prev_action] = np.reshape(curr_brain_info.previous_vector_actions, [-1])
feed_dict[self.model.prev_action] = curr_brain_info.previous_vector_actions.flatten()
run_list += [self.model.memory_out]
if (self.is_training and self.is_continuous_observation and
self.use_vector_obs and self.trainer_parameters['normalize']):
new_mean, new_variance = self.running_average(
curr_brain_info.vector_observations, steps, self.model.running_mean, self.model.running_variance)
feed_dict[self.model.new_mean] = new_mean
feed_dict[self.model.new_variance] = new_variance
run_list = run_list + [self.model.update_mean, self.model.update_variance]
values = self.sess.run(self.inference_run_list, feed_dict=feed_dict)
run_out = dict(zip(self.inference_run_list, values))
values = self.sess.run(run_list, feed_dict=feed_dict)
run_out = dict(zip(run_list, values))
return (run_out[self.model.output],
run_out[self.model.memory_out],
[str(v) for v in run_out[self.model.value]],
run_out)
return run_out[self.model.output], run_out[self.model.memory_out], None, run_out
return (run_out[self.model.output],
None,
[str(v) for v in run_out[self.model.value]],
run_out)
return run_out[self.model.output], None, None, run_out
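With the fetch list fixed as `self.inference_run_list`, a single `sess.run` returns all values and `dict(zip(...))` maps each result back to its tensor. A self-contained TF1 sketch of that retrieval pattern (constants stand in for the model's tensors):

import tensorflow as tf

output = tf.constant([1.0])
value = tf.constant([0.5])
entropy = tf.constant([0.1])

run_list = [output, value, entropy]            # chosen once, reused every step

with tf.Session() as sess:
    values = sess.run(run_list)                # one session call
    run_out = dict(zip(run_list, values))      # index results by tensor
    action = run_out[output]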
def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainInfo, take_action_outputs):
def generate_intrinsic_rewards(self, curr_info, next_info):
Adds experiences to each agent's experience history.
:param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param take_action_outputs: The outputs of the take action method.
Generates intrinsic reward used for Curiosity-based training.
:param curr_info: Current BrainInfo.
:param next_info: Next BrainInfo.
:return: Intrinsic rewards for all agents.
curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]
intrinsic_rewards = np.array([])
run_list = [self.model.intrinsic_reward]
run_list.append(self.model.output)
feed_dict[self.model.output] = next_info.previous_vector_actions.flatten()
feed_dict[self.model.action_holder] = np.reshape(take_action_outputs[self.model.output], [-1])
feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
for i, _ in enumerate(curr_info.visual_observations):
for i in range(len(curr_info.visual_observations)):
if self.use_recurrent:
feed_dict[self.model.prev_action] = np.reshape(curr_info.previous_vector_actions, [-1])
if curr_info.memories.shape[1] == 0:
curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
feed_dict[self.model.memory_in] = curr_info.memories
run_list += [self.model.memory_out]
intrinsic_rewards = self.sess.run(self.model.intrinsic_reward, feed_dict=feed_dict) * \
float(self.has_updated)
intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
feed_dict=feed_dict) * float(self.has_updated)
return intrinsic_rewards
else:
return None
def generate_value_estimate(self, brain_info, idx):
"""
Generates value estimates for bootstrapping.
:param brain_info: BrainInfo to be used for bootstrapping.
:param idx: Index in BrainInfo of agent.
:return: Value estimate.
"""
feed_dict = {self.model.batch_size: 1, self.model.sequence_length: 1}
if self.use_visual_obs:
for i in range(len(brain_info.visual_observations)):
feed_dict[self.model.visual_in[i]] = brain_info.visual_observations[i][idx]
if self.use_vector_obs:
feed_dict[self.model.vector_in] = [brain_info.vector_observations[idx]]
if self.use_recurrent:
if brain_info.memories.shape[1] == 0:
brain_info.memories = np.zeros(
(len(brain_info.vector_observations), self.m_size))
feed_dict[self.model.memory_in] = [brain_info.memories[idx]]
if not self.is_continuous_action and self.use_recurrent:
feed_dict[self.model.prev_action] = brain_info.previous_vector_actions[idx].flatten()
value_estimate = self.sess.run(self.model.value, feed_dict)
return value_estimate
def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainInfo, take_action_outputs):
"""
Adds experiences to each agent's experience history.
:param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param take_action_outputs: The outputs of the take action method.
"""
curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]
intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
for agent_id in curr_info.agents:
self.training_buffer[agent_id].last_brain_info = curr_info

stored_info = self.training_buffer[agent_id].last_brain_info
stored_take_action_outputs = self.training_buffer[agent_id].last_take_action_outputs
if stored_info is None:
continue
else:
if stored_info is not None:
idx = stored_info.agents.index(agent_id)
next_idx = next_info.agents.index(agent_id)
if not stored_info.local_done[idx]:

else:
bootstrapping_info = info
idx = l
feed_dict = {self.model.batch_size: len(bootstrapping_info.vector_observations),
self.model.sequence_length: 1}
if self.use_visual_obs:
for i in range(len(bootstrapping_info.visual_observations)):
feed_dict[self.model.visual_in[i]] = bootstrapping_info.visual_observations[i]
if self.use_vector_obs:
feed_dict[self.model.vector_in] = bootstrapping_info.vector_observations
if self.use_recurrent:
if bootstrapping_info.memories.shape[1] == 0:
bootstrapping_info.memories = np.zeros(
(len(bootstrapping_info.vector_observations), self.m_size))
feed_dict[self.model.memory_in] = bootstrapping_info.memories
if not self.is_continuous_action and self.use_recurrent:
feed_dict[self.model.prev_action] = np.reshape(bootstrapping_info.previous_vector_actions, [-1])
value_next = self.sess.run(self.model.value, feed_dict)[idx]
value_next = self.generate_value_estimate(bootstrapping_info, idx)
self.training_buffer[agent_id]['advantages'].set(
get_gae(

gamma=self.trainer_parameters['gamma'],
lambd=self.trainer_parameters['lambd'])
)
lambd=self.trainer_parameters['lambd']))
self.training_buffer.append_update_buffer(agent_id,
batch_size=None, training_length=self.sequence_length)
self.training_buffer.append_update_buffer(agent_id, batch_size=None,
training_length=self.sequence_length)
self.training_buffer[agent_id].reset_agent()
if info.local_done[l]:

Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not update_model() can be run
"""
return len(self.training_buffer.update_buffer['actions']) > \
max(int(self.trainer_parameters['buffer_size'] / self.sequence_length), 1)
size_of_buffer = len(self.training_buffer.update_buffer['actions'])
return size_of_buffer > max(int(self.trainer_parameters['buffer_size'] / self.sequence_length), 1)
num_epoch = self.trainer_parameters['num_epoch']
buffer = self.training_buffer.update_buffer
_buffer = self.training_buffer.update_buffer
self.model.mask_input: np.array(_buffer['masks'][start:end]).reshape(
[-1]),
self.model.returns_holder: np.array(_buffer['discounted_returns'][start:end]).reshape(
[-1]),
self.model.old_value: np.array(_buffer['value_estimates'][start:end]).reshape([-1]),
self.model.advantage: np.array(_buffer['advantages'][start:end]).reshape([-1, 1]),
self.model.all_old_probs: np.array(
_buffer['action_probs'][start:end]).reshape([-1, self.brain.vector_action_space_size])}
self.model.mask_input: np.array(buffer['masks'][start:end]).flatten(),
self.model.returns_holder: np.array(buffer['discounted_returns'][start:end]).flatten(),
self.model.old_value: np.array(buffer['value_estimates'][start:end]).flatten(),
self.model.advantage: np.array(buffer['advantages'][start:end]).reshape([-1, 1]),
self.model.all_old_probs: np.array(buffer['action_probs'][start:end]).reshape(
[-1, self.brain.vector_action_space_size])}
feed_dict[self.model.output_pre] = np.array(
_buffer['actions_pre'][start:end]).reshape([-1, self.brain.vector_action_space_size])
feed_dict[self.model.output_pre] = np.array(buffer['actions_pre'][start:end]).reshape(
[-1, self.brain.vector_action_space_size])
feed_dict[self.model.action_holder] = np.array(
_buffer['actions'][start:end]).reshape([-1])
feed_dict[self.model.action_holder] = np.array(buffer['actions'][start:end]).flatten()
feed_dict[self.model.prev_action] = np.array(
_buffer['prev_action'][start:end]).reshape([-1])
feed_dict[self.model.prev_action] = np.array(buffer['prev_action'][start:end]).flatten()
feed_dict[self.model.vector_in] = np.array(
_buffer['vector_obs'][start:end]).reshape(
[-1, self.brain.vector_observation_space_size * self.brain.num_stacked_vector_observations])
total_observation_length = self.brain.vector_observation_space_size * \
self.brain.num_stacked_vector_observations
feed_dict[self.model.vector_in] = np.array(buffer['vector_obs'][start:end]).reshape(
[-1, total_observation_length])
feed_dict[self.model.next_vector_obs] = np.array(
_buffer['next_vector_obs'][start:end]).reshape(
[-1,
self.brain.vector_observation_space_size * self.brain.num_stacked_vector_observations])
feed_dict[self.model.next_vector_obs] = np.array(buffer['next_vector_obs'][start:end])\
.reshape([-1, total_observation_length])
feed_dict[self.model.vector_in] = np.array(
_buffer['vector_obs'][start:end]).reshape([-1, self.brain.num_stacked_vector_observations])
feed_dict[self.model.vector_in] = np.array(buffer['vector_obs'][start:end]).reshape(
[-1, self.brain.num_stacked_vector_observations])
_obs = np.array(_buffer['visual_obs%d' % i][start:end])
_obs = np.array(buffer['visual_obs%d' % i][start:end])
_obs = np.array(_buffer['next_visual_obs%d' % i][start:end])
_obs = np.array(buffer['next_visual_obs%d' % i][start:end])
mem_in = np.array(_buffer['memory'][start:end])[:, 0, :]
mem_in = np.array(buffer['memory'][start:end])[:, 0, :]
feed_dict[self.model.memory_in] = mem_in
run_list = [self.model.value_loss, self.model.policy_loss, self.model.update_batch]

self.stats['forward_loss'].append(np.mean(forward_total))
self.stats['inverse_loss'].append(np.mean(inverse_total))
self.training_buffer.reset_update_buffer()
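Several `feed_dict` entries above switch from `np.reshape(x, [-1])` to `x.flatten()`; for these 1-D targets the two produce identical values (`flatten` always copies, `reshape` returns a view when it can), so the change is a readability cleanup. A quick check:

import numpy as np

a = np.arange(6).reshape(2, 3)
assert np.array_equal(np.reshape(a, [-1]), a.flatten())
# flatten() always returns a copy; reshape(-1) returns a view when possible.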
def discount_rewards(r, gamma=0.99, value_next=0.0):
"""

python/unitytrainers/trainer.py (23 changes)


"""
raise UnityTrainerException("The get_last_reward property was not implemented.")
def increment_step(self):
"""
Increment the step count of the trainer
"""
raise UnityTrainerException("The increment_step method was not implemented.")
def update_last_reward(self):
def increment_step_and_update_last_reward(self):
Updates the last reward
Increment the step count of the trainer and updates the last reward
raise UnityTrainerException("The update_last_reward method was not implemented.")
raise UnityTrainerException("The increment_step_and_update_last_reward method was not implemented.")
def take_action(self, all_brain_info: AllBrainInfo):
"""

"""
if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
self.is_training and self.get_step <= self.get_max_steps):
steps = self.get_step
.format(self.brain_name, steps,
.format(self.brain_name, self.get_step,
.format(self.brain_name, steps))
.format(self.brain_name, self.get_step))
summary = tf.Summary()
for key in self.stats:
if len(self.stats[key]) > 0:

summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
self.summary_writer.add_summary(summary, steps)
self.summary_writer.add_summary(summary, self.get_step)
self.summary_writer.flush()
def write_tensorboard_text(self, key, input_dict):

:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
s_op = tf.summary.text(key,
tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict]))
)
s_op = tf.summary.text(key, tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict])))
s = self.sess.run(s_op)
self.summary_writer.add_summary(s, self.get_step)
except:

python/unitytrainers/trainer_controller.py (15 changes)


class TrainerController(object):
def __init__(self, env_path, run_id, save_freq, curriculum_file, fast_simulation, load, train,
worker_id, keep_checkpoints, lesson, seed, docker_target_name, trainer_config_path):
worker_id, keep_checkpoints, lesson, seed, docker_target_name, trainer_config_path,
no_graphics):
:param env_path: Location to the environment executable to be loaded.
:param run_id: The sub-directory name for model and summary statistics
:param save_freq: Frequency at which to save model

:param seed: Random seed used for training.
:param docker_target_name: Name of docker volume that will contain all data.
:param trainer_config_path: Fully qualified path to location of trainer configuration file
:param no_graphics: Whether to run the Unity simulator in no-graphics mode
"""
self.trainer_config_path = trainer_config_path
if env_path is not None:

self.model_path = '/{docker_target_name}/models/{run_id}'.format(
docker_target_name=docker_target_name,
run_id=run_id)
if env_path is not None :
if env_path is not None:
env_path = '/{docker_target_name}/{env_name}'.format(docker_target_name=docker_target_name,
env_name=env_path)
if curriculum_file is None:

tf.set_random_seed(self.seed)
self.env = UnityEnvironment(file_name=env_path, worker_id=self.worker_id,
curriculum=self.curriculum_file, seed=self.seed,
docker_training=self.docker_training)
docker_training=self.docker_training,
no_graphics=no_graphics)
if env_path is None:
self.env_name = 'editor_'+self.env.academy_name
else:

# Write training statistics to Tensorboard.
trainer.write_summary(self.env.curriculum.lesson_number)
if self.train_model and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step()
trainer.update_last_reward()
if self.train_model and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step_and_update_last_reward()
if self.train_model:
global_step += 1
if global_step % self.save_freq == 0 and global_step != 0 and self.train_model:
# Save Tensorflow model
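The controller now makes one combined call per step, `increment_step_and_update_last_reward`, which inside the PPO trainer batches `model.update_reward` and `model.increment_step` into a single `sess.run` (see that hunk above). A minimal TF1 sketch of batching two ops into one session call, with illustrative variable names:

import tensorflow as tf

global_step = tf.Variable(0, trainable=False)
last_reward = tf.Variable(0.0, trainable=False)
new_reward = tf.placeholder(tf.float32, shape=())

increment_step = tf.assign_add(global_step, 1)
update_reward = tf.assign(last_reward, new_reward)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One round trip instead of two separate sess.run calls.
    sess.run([update_reward, increment_step], feed_dict={new_reward: 1.5})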
