
fixes on imitation trainer, now works with demo (#274)

/develop-generalizationTraining-TrainerController
Arthur Juliani, 7 years ago
Commit 3b8755d2
1 changed file with 7 additions and 19 deletions
python/trainers/imitation_trainer.py (26 changed lines)

 # # Unity ML Agents
-# ## ML-Agent Learning (PPO)
-# Contains an implementation of PPO as described [here](https://arxiv.org/abs/1707.06347).
+# ## ML-Agent Learning (Imitation)
+# Contains an implementation of Imitation Learning
 import logging
 import os

         :param training: Whether the trainer is set for training.
         """
-        self.param_keys = ['is_imitation', 'brain_to_imitate', 'batch_size', 'time_horizon', 'graph_scope',
-                           'summary_freq', 'max_steps']
+        self.param_keys = ['is_imitation', 'brain_to_imitate', 'batch_size', 'time_horizon', 'graph_scope',
+                           'summary_freq', 'max_steps', 'batches_per_epoch']
         for k in self.param_keys:
             if k not in trainer_parameters:

         self.variable_scope = trainer_parameters['graph_scope']
         self.brain_to_imitate = trainer_parameters['brain_to_imitate']
         self.batch_size = trainer_parameters['batch_size']
+        self.batches_per_epoch = trainer_parameters['batches_per_epoch']
         self.step = 0
         self.cumulative_rewards = {}
         self.episode_steps = {}

         Returns the maximum number of steps. Is used to know when the trainer should be stopped.
         :return: The maximum number of steps of the trainer
         """
-        return self.trainer_parameters['max_steps']
+        return float(self.trainer_parameters['max_steps'])

     @property
     def get_step(self):
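
Why the float() cast (my reading of the change, not stated in the commit): max_steps is read from a YAML trainer config, and PyYAML follows YAML 1.1, which only recognizes scientific notation with a signed exponent, so a value like 5.0e4 is loaded as a string. A minimal sketch of the gotcha:

    import yaml  # PyYAML resolves scalars per YAML 1.1 rules

    config = yaml.safe_load('max_steps: 5.0e4')
    print(type(config['max_steps']))   # <class 'str'> -- not parsed as a number
    print(float(config['max_steps']))  # 50000.0 -- the cast recovers the value

Without the cast, a comparison such as step > max_steps would mix int and str and raise a TypeError under Python 3.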

         Returns whether or not the trainer has enough elements to run update model
         :return: A boolean corresponding to whether or not update_model() can be run
         """
-        return len(self.training_buffer.update_buffer['actions']) > 1
+        return len(self.training_buffer.update_buffer['actions']) > self.batch_size
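
The new threshold changes when training is allowed to start. A toy restatement of the check (the numbers are illustrative, not from any config):

    def is_ready_update(buffer_length, batch_size):
        # old: buffer_length > 1 allowed updates on nearly empty buffers
        # new: wait until strictly more than one full batch is recorded
        return buffer_length > batch_size

    print(is_ready_update(2, 32))   # False: keep collecting demonstrations
    print(is_ready_update(33, 32))  # True: a full batch is available

With the old check, two recorded actions were enough to trigger an update, yet the integer division below could then yield zero batches, making the update a no-op; the new check guarantees at least one full batch.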
-        # num_epoch = self.trainer_parameters['num_epoch']
-        # strange from there
-        for j in range(len(self.training_buffer.update_buffer['actions']) // self.batch_size):
+        for j in range(min(len(self.training_buffer.update_buffer['actions']) // self.batch_size, self.batches_per_epoch)):
-            # batch_states = shuffle_states[j * batch_size:(j + 1) * batch_size]
-            # batch_actions = shuffle_actions[j * batch_size:(j + 1) * batch_size]
             batch_actions = np.array(_buffer['actions'][j * batch_size:(j + 1) * batch_size])
             if not self.is_continuous:
                 feed_dict = {
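
The min() bounds the work done per call: batches_per_epoch, added to param_keys above, caps how many batches each update consumes, so a large demonstration buffer cannot stall a training step on a very long gradient loop. A standalone sketch of the capped iteration (all names and numbers here are illustrative):

    buffer_length = 10000    # recorded demonstration actions
    batch_size = 32
    batches_per_epoch = 10   # the new trainer parameter

    available = buffer_length // batch_size            # 312 full batches
    num_batches = min(available, batches_per_epoch)    # capped at 10

    for j in range(num_batches):
        start, end = j * batch_size, (j + 1) * batch_size
        batch = list(range(start, end))  # stands in for buffer['actions'][start:end]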

             else:
                 self.stats['losses'].append(0)
+        self.training_buffer.reset_all()
-        # Do we clear it at some point ?
-        # self.training_buffer.reset_update_buffer()

     def write_summary(self, lesson_number):
         """
