
Fix for visual-only imitation learning

/hotfix-v0.9.2a
Arthur Juliani, 6 years ago
Current commit
6b359062
1 changed file with 19 additions and 23 deletions
python/unitytrainers/bc/trainer.py (42 changed lines)


  self.training_buffer = Buffer()
  self.is_continuous_action = (env.brains[brain_name].vector_action_space_type == "continuous")
  self.is_continuous_observation = (env.brains[brain_name].vector_observation_space_type == "continuous")
- self.use_observations = (env.brains[brain_name].number_visual_observations > 0)
- if self.use_observations:
+ self.use_visual_observations = (env.brains[brain_name].number_visual_observations > 0)
+ if self.use_visual_observations:
- self.use_states = (env.brains[brain_name].vector_observation_space_size > 0)
+ self.use_vector_observations = (env.brains[brain_name].vector_observation_space_size > 0)
  self.summary_path = trainer_parameters['summary_path']
  if not os.path.exists(self.summary_path):
      os.makedirs(self.summary_path)
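The renamed flags simply record, independently, which observation streams the brain provides, so a visual-only brain no longer looks "empty". A minimal, self-contained sketch of the same detection logic; the BrainParameters stand-in below is hypothetical, not the real ML-Agents brain object:

from dataclasses import dataclass

@dataclass
class BrainParameters:
    # Hypothetical stand-in for the brain metadata queried above.
    number_visual_observations: int = 1     # e.g. one camera
    vector_observation_space_size: int = 0  # visual-only: no vector observations

def observation_flags(brain):
    # Mirrors the renamed trainer flags: visual and vector inputs are tracked separately.
    use_visual_observations = brain.number_visual_observations > 0
    use_vector_observations = brain.vector_observation_space_size > 0
    return use_visual_observations, use_vector_observations

print(observation_flags(BrainParameters()))  # (True, False) for a visual-only brain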

  agent_brain = all_brain_info[self.brain_name]
  feed_dict = {self.model.dropout_rate: 1.0, self.model.sequence_length: 1}
- if self.use_observations:
+ if self.use_visual_observations:
- if self.use_states:
+ if self.use_vector_observations:
  if self.use_recurrent:
      agent_action, memories = self.sess.run(self.inference_run_list, feed_dict)
      return agent_action, memories, None, None
  else:
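For illustration, a rough sketch of how an inference feed dict can be assembled once visual and vector inputs are gated separately; the string keys and shapes below are placeholders, not the trainer's actual TensorFlow tensors:

import numpy as np

def build_feed_dict(visual_obs, vector_obs, use_visual, use_vector):
    # Only the observation types the brain actually provides are fed,
    # so a visual-only brain never tries to feed an empty vector input.
    feed_dict = {"dropout_rate": 1.0, "sequence_length": 1}
    if use_visual:
        for i, obs in enumerate(visual_obs):
            feed_dict["visual_in_%d" % i] = obs
    if use_vector:
        feed_dict["vector_in"] = vector_obs
    return feed_dict

# Visual-only example: one 84x84 RGB camera, no vector observations.
camera = np.zeros((1, 84, 84, 3), dtype=np.float32)
fd = build_feed_dict([camera], None, use_visual=True, use_vector=False)
print(sorted(fd.keys()))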

  info_teacher_record, next_info_teacher_record = "true", "true"
  if info_teacher_record == "true" and next_info_teacher_record == "true":
      if not stored_info_teacher.local_done[idx]:
-         if self.use_observations:
+         if self.use_visual_observations:
-         if self.use_states:
+         if self.use_vector_observations:
              self.training_buffer[agent_id]['vector_observations']\
                  .append(stored_info_teacher.vector_observations[idx])
          if self.use_recurrent:
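As a sketch of the same gating on the recording side, here is a toy per-agent buffer; the defaultdict structure and the 'visual_observations%d' key naming are assumptions for illustration, not the trainer's actual Buffer class:

from collections import defaultdict

# Hypothetical per-agent buffer: agent_id -> field name -> list of recorded steps.
training_buffer = defaultdict(lambda: defaultdict(list))

def record_teacher_step(agent_id, visual_obs, vector_obs, use_visual, use_vector):
    # Same gating as the hunk above: only the observation streams
    # this brain actually produces are appended for the teacher agent.
    if use_visual:
        for i, obs in enumerate(visual_obs):
            training_buffer[agent_id]['visual_observations%d' % i].append(obs)
    if use_vector:
        training_buffer[agent_id]['vector_observations'].append(vector_obs)

record_teacher_step("agent-0", visual_obs=[[0.0] * 4], vector_obs=None,
                    use_visual=True, use_vector=False)
print(len(training_buffer["agent-0"]['visual_observations0']))  # 1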

"""
Uses training_buffer to update model.
"""
self.training_buffer.update_buffer.shuffle()
batch_losses = []
for j in range(

end = (j + 1) * self.n_sequences
batch_states = np.array(_buffer['vector_observations'][start:end])
batch_actions = np.array(_buffer['actions'][start:end])
feed_dict[self.model.true_action] = batch_actions.reshape([-1, self.brain.vector_action_space_size])
else:
feed_dict[self.model.true_action] = batch_actions.reshape([-1])
if not self.is_continuous_observation:
feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.num_stacked_vector_observations])
feed_dict[self.model.true_action] = np.array(_buffer['actions'][start:end]).\
reshape([-1, self.brain.vector_action_space_size])
feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.vector_observation_space_size *
self.brain.num_stacked_vector_observations])
if self.use_observations:
feed_dict[self.model.true_action] = np.array(_buffer['actions'][start:end]).reshape([-1])
if self.use_vector_observations:
if not self.is_continuous_observation:
feed_dict[self.model.vector_in] = np.array(_buffer['vector_observations'][start:end])\
.reshape([-1, self.brain.num_stacked_vector_observations])
else:
feed_dict[self.model.vector_in] = np.array(_buffer['vector_observations'][start:end])\
.reshape([-1, self.brain.vector_observation_space_size * self.brain.num_stacked_vector_observations])
if self.use_visual_observations:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
feed_dict[self.model.visual_in[i]] = _obs
loss, _ = self.sess.run([self.model.loss, self.model.update], feed_dict=feed_dict)
batch_losses.append(loss)
if len(batch_losses) > 0:
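The recurrent visual branch above folds the batch and sequence axes together before feeding frames to the model's visual input. A small standalone numpy example of that reshape (the array sizes are illustrative):

import numpy as np

# A recurrent batch of visual observations is stored as
# (batch, sequence, width, height, channels); the visual input expects
# one frame per row, so batch and sequence are collapsed into one axis.
_obs = np.zeros((4, 8, 84, 84, 3), dtype=np.float32)
(_batch, _seq, _w, _h, _c) = _obs.shape
flat = _obs.reshape([-1, _w, _h, _c])
print(flat.shape)  # (32, 84, 84, 3)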
