浏览代码

Fix bug; add critic_obs to buffer

/comms-grad
Ervin Teng 4 年前
当前提交
f479ce83
共有 2 个文件被更改,包括 9 次插入7 次删除
  1. 14
      ml-agents/mlagents/trainers/agent_processor.py
  2. 2
      ml-agents/mlagents/trainers/trajectory.py

14
ml-agents/mlagents/trainers/agent_processor.py


if global_id in self.last_experience:
experience = self.last_experience[global_id]
terminated = isinstance(step, TerminalStep)
# Add remaining obs to AgentExperience
for _id, _exp in self.last_experience.items():
if _id == global_id:
continue
else:
self.last_experience[global_id].collab_obs.append(_exp.obs)
# Add the value outputs if needed
self.experience_buffers[global_id].append(experience)
self.episode_rewards[global_id] += step.reward

len(self.experience_buffers[global_id]) >= self.max_trajectory_length
or terminated
):
# Add remaining obs to AgentExperience
for _id, _exp in self.last_experience.items():
if _id == global_id:
continue
else:
self.last_experience[global_id].collab_obs.append(_exp.obs)
next_obs = step.obs
trajectory = Trajectory(
steps=self.experience_buffers[global_id],

self._safe_delete(self.last_step_result, global_id)
self._safe_delete(self.episode_steps, global_id)
self._safe_delete(self.episode_rewards, global_id)
self._safe_delete(self.last_experience, global_id)
self.policy.remove_previous_action([global_id])
self.policy.remove_memories([global_id])

2
ml-agents/mlagents/trainers/trajectory.py


agent_buffer_trajectory["next_visual_obs%d" % i].append(
next_vec_vis_obs.visual_observations[i]
)
agent_buffer_trajectory["critic_obs"].append(exp.collab_obs)
agent_buffer_trajectory["vector_obs"].append(
vec_vis_obs.vector_observations
)

正在加载...
取消
保存