
Fix for CC models w/ RNN and Curiosity (#860)

/develop-generalizationTraining-TrainerController
GitHub · 7 years ago
Current commit: 6e6e8d96
1 file changed, 7 insertions(+), 3 deletions(-)
  1. python/unitytrainers/ppo/trainer.py: 10 changed lines (+7, -3)

python/unitytrainers/ppo/trainer.py


 feed_dict = {self.model.batch_size: len(curr_brain_info.vector_observations),
              self.model.sequence_length: 1}
 if self.use_recurrent:
-    feed_dict[self.model.prev_action] = curr_brain_info.previous_vector_actions.flatten()
+    if not self.is_continuous_action:
+        feed_dict[self.model.prev_action] = curr_brain_info.previous_vector_actions.flatten()
     if curr_brain_info.memories.shape[1] == 0:
         curr_brain_info.memories = np.zeros((len(curr_brain_info.agents), self.m_size))
     feed_dict[self.model.memory_in] = curr_brain_info.memories
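
A minimal sketch of the guarded feed construction above, pulled out as a standalone helper. The helper name and the simplified model/brain_info attributes are assumptions for illustration only; the real trainer builds the feed_dict inline as shown in the hunk.

import numpy as np

def build_recurrent_feed(model, brain_info, m_size, is_continuous_action):
    """Hypothetical helper mirroring the guarded feed_dict construction."""
    feed_dict = {model.batch_size: len(brain_info.vector_observations),
                 model.sequence_length: 1}
    # prev_action is only part of the graph for discrete-action recurrent models,
    # so it must not be fed for continuous-control (CC) models.
    if not is_continuous_action:
        feed_dict[model.prev_action] = brain_info.previous_vector_actions.flatten()
    # Initialise the recurrent memories on the first step, when they are still empty.
    if brain_info.memories.shape[1] == 0:
        brain_info.memories = np.zeros((len(brain_info.agents), m_size))
    feed_dict[model.memory_in] = brain_info.memories
    return feed_dict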

 if self.use_curiosity:
     feed_dict = {self.model.batch_size: len(curr_info.vector_observations), self.model.sequence_length: 1}
     if self.is_continuous_action:
-        feed_dict[self.model.output] = next_info.previous_vector_actions.flatten()
+        feed_dict[self.model.output] = next_info.previous_vector_actions
     else:
         feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
     if self.use_visual_obs:
         ...  # visual-observation feeds omitted in this diff view
     if self.use_vector_obs:
         feed_dict[self.model.vector_in] = curr_info.vector_observations
         feed_dict[self.model.next_vector_in] = next_info.vector_observations
+    if self.use_recurrent:
+        if curr_info.memories.shape[1] == 0:
+            curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
+        feed_dict[self.model.memory_in] = curr_info.memories
     intrinsic_rewards = self.sess.run(self.model.intrinsic_reward,
                                       feed_dict=feed_dict) * float(self.has_updated)
     return intrinsic_rewards
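
The curiosity hunk also stops flattening the previous actions for continuous models, presumably because model.output is fed as a 2-D [batch, action_size] array, while the discrete action_holder still takes a flat vector of indices. A small numpy illustration of that shape difference, with made-up values and an assumed two-dimensional continuous action space:

import numpy as np

# Two agents: a continuous model with two action dimensions vs. a discrete model.
prev_continuous = np.array([[0.1, -0.3],
                            [0.7,  0.2]])        # fed to model.output as-is (2-D)
prev_discrete = np.array([[1], [3]]).flatten()   # fed to model.action_holder flattened

assert prev_continuous.shape == (2, 2)   # [batch, action_size], no .flatten()
assert prev_discrete.shape == (2,)       # flat vector of action indices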
