
Remove unnecessary feed_dicts for GAIL and Curiosity (#2348)

/develop-generalizationTraining-TrainerController
GitHub committed 6 years ago
Current commit: dd0d2a10
2 changed files with 1 addition and 25 deletions
  1. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (20 changes)
  2. ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py (6 changes)
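The change rests on a basic property of TF1 graph execution that the hunks below illustrate: Session.run only needs feed_dict entries for placeholders the fetched tensors actually depend on, so feeds that the fetched op never reads (for example, recurrent memories when computing the intrinsic reward) can simply be dropped. A minimal, self-contained sketch of that property; this is illustrative TF1-style code, not code from this repository, and the placeholder names are made up:

import tensorflow as tf

# Two placeholders, but the fetched tensor only depends on one of them.
obs = tf.placeholder(tf.float32, shape=[None, 4], name="obs")
memory_in = tf.placeholder(tf.float32, shape=[None, 8], name="memory_in")  # never read by `reward`
reward = tf.reduce_sum(tf.square(obs), axis=1)  # depends only on `obs`

with tf.Session() as sess:
    # Feeding only `obs` is sufficient; a feed for `memory_in` here would be
    # an unnecessary feed_dict entry of the kind this commit removes.
    print(sess.run(reward, feed_dict={obs: [[1.0, 2.0, 3.0, 4.0]]}))  # [30.]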

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (20 changes)


                feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
        if self.policy.use_vec_obs:
            feed_dict[self.model.next_vector_in] = next_info.vector_observations
        if self.policy.use_recurrent:
            if current_info.memories.shape[1] == 0:
                current_info.memories = self.policy.make_empty_memory(
                    len(current_info.agents)
                )
            feed_dict[self.policy.model.memory_in] = current_info.memories
        unscaled_reward = self.policy.sess.run(
            self.model.intrinsic_reward, feed_dict=feed_dict
        )
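The current_info.memories.shape[1] == 0 guard shown above handles the case where no recurrent memories have been collected yet; make_empty_memory then presumably supplies an all-zeros matrix with one row per agent. A self-contained sketch of that behaviour (the helper name and the memory size below are assumptions for illustration, not values from the repository):

import numpy as np

def make_empty_memory_sketch(num_agents, m_size=16):
    # Hypothetical stand-in for policy.make_empty_memory: one zero row per agent.
    return np.zeros((num_agents, m_size), dtype=np.float32)

memories = np.zeros((3, 0), dtype=np.float32)  # "no memories yet": second dimension is 0
if memories.shape[1] == 0:
    memories = make_empty_memory_sketch(num_agents=3)
print(memories.shape)  # (3, 16)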

        feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"].reshape(
            [-1, self.policy.model.act_size[0]]
        )
        feed_dict[self.policy.model.epsilon] = mini_batch[
            "random_normal_epsilon"
        ].reshape([-1, self.policy.model.act_size[0]])
        if self.policy.use_recurrent:
            feed_dict[self.policy.model.prev_action] = mini_batch[
                "prev_action"
            ].reshape([-1, len(self.policy.model.act_size)])
        feed_dict[self.policy.model.action_masks] = mini_batch[
            "action_mask"
        ].reshape([-1, sum(self.policy.brain.vector_action_space_size)])
        if self.policy.use_vec_obs:
            feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"].reshape(
                [-1, self.policy.vec_obs_size]
            )
            else:
                feed_dict[self.model.next_visual_in[i]] = _obs
        if self.policy.use_recurrent:
            mem_in = mini_batch["memory"][:, 0, :]
            feed_dict[self.policy.model.memory_in] = mem_in
        self.has_updated = True
        run_out = self.policy._execute_model(feed_dict, self.update_dict)
        return run_out
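For context on the reshape([-1, ...]) calls in the snippet above: whatever leading dimensions the mini-batch arrays carry (for example, a sequence dimension when recurrence is used) are flattened into a single batch dimension before the array is fed to a flat placeholder. A small NumPy sketch of that pattern; the shapes here are assumptions for illustration only:

import numpy as np

act_size = 2                                           # hypothetical action size
actions_pre = np.zeros((8, 16, act_size), np.float32)  # hypothetical (sequences, steps, act_size) batch
flat = actions_pre.reshape([-1, act_size])             # collapse leading dims into one batch dim
print(flat.shape)  # (128, 2)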

ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py (6 changes)


            feed_dict[
                self.policy.model.action_holder
            ] = next_info.previous_vector_actions
        if self.policy.use_recurrent:
            if current_info.memories.shape[1] == 0:
                current_info.memories = self.policy.make_empty_memory(
                    len(current_info.agents)
                )
            feed_dict[self.policy.model.memory_in] = current_info.memories
        unscaled_reward = self.policy.sess.run(
            self.model.intrinsic_reward, feed_dict=feed_dict
        )
