|
|
|
|
|
|
) # Needed for mypy to pass since ndarray has no content type |
|
|
|
curr_agent_step = batched_step_result.get_agent_step_result(local_id) |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
stored_step = self.last_step_result.get(global_id, None) |
|
|
|
stored_agent_step, idx = self.last_step_result.get(global_id, (None, None)) |
|
|
|
if stored_step is not None and stored_take_action_outputs is not None: |
|
|
|
if stored_agent_step is not None and stored_take_action_outputs is not None: |
|
|
|
stored_agent_step = stored_step.get_agent_step_result(local_id) |
|
|
|
idx = stored_step.agent_id_to_index[local_id] |
|
|
|
obs = stored_agent_step.obs |
|
|
|
if not stored_agent_step.done: |
|
|
|
if self.policy.use_recurrent: |
|
|
|
|
|
|
elif not curr_agent_step.done: |
|
|
|
self.episode_steps[global_id] += 1 |
|
|
|
|
|
|
|
self.last_step_result[global_id] = batched_step_result |
|
|
|
for _gid in action_global_agent_ids: |
|
|
|
if _gid in self.last_step_result: |
|
|
|
if not self.last_step_result[_gid][0].done: |
|
|
|
if "action" in take_action_outputs: |
|
|
|
self.policy.save_previous_action( |
|
|
|
[global_id], take_action_outputs["action"] |
|
|
|
) |
|
|
|
else: |
|
|
|
# If it was done, delete it |
|
|
|
del self.last_step_result[_gid] |
|
|
|
if "action" in take_action_outputs: |
|
|
|
self.policy.save_previous_action( |
|
|
|
previous_action.agent_ids, take_action_outputs["action"] |
|
|
|
# Index is needed to grab from last_take_action_outputs |
|
|
|
self.last_step_result[global_id] = ( |
|
|
|
curr_agent_step, |
|
|
|
batched_step_result.agent_id_to_index[_id], |
|
|
|
) |
|
|
|
|
|
|
|
for terminated_id in terminated_agents: |
|
|
|
|
|
|
del self.last_take_action_outputs[global_id] |
|
|
|
del self.episode_steps[global_id] |
|
|
|
del self.episode_rewards[global_id] |
|
|
|
del self.last_step_result[global_id] |
|
|
|
self.policy.remove_previous_action([global_id]) |
|
|
|
self.policy.remove_memories([global_id]) |
|
|
|
|
|
|
|