 import sys
-from typing import List, Dict, Deque, TypeVar, Generic
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
-from mlagents_envs.base_env import BatchedStepResult
+from mlagents_envs.base_env import BatchedStepResult, StepResult
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
-from mlagents.trainers.tf_policy import TFPolicy
+from mlagents.trainers.policy import Policy

         :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
         """
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_step_result: Dict[str, BatchedStepResult] = {}
+        self.last_step_result: Dict[str, Tuple[StepResult, int]] = {}
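
This type change is the core of the diff: instead of mapping each agent to the whole BatchedStepResult, the processor keeps only that agent's StepResult together with its index into the batch. A minimal sketch of the new shape, using a simplified stand-in for StepResult (the real one lives in mlagents_envs.base_env):

    from dataclasses import dataclass
    from typing import Dict, Tuple

    @dataclass
    class StepResult:  # simplified stand-in, not the mlagents_envs class
        obs: list
        reward: float
        done: bool

    # Old scheme: every per-agent entry held the entire BatchedStepResult, so one
    # lingering entry kept the whole batch (all agents' observations) alive.
    # New scheme: each entry holds only that agent's step and its batch index.
    last_step_result: Dict[str, Tuple[StepResult, int]] = {}
    last_step_result["worker0-agent3"] = (StepResult(obs=[0.0], reward=1.0, done=False), 3)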
         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
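
The comment above hides an off-by-one that is easy to misread, so a compact trace may help: when the processor handles the observation produced at step t+1, last_take_action_outputs holds a_t, while the policy's saved previous action is still a_(t-1). A toy trace with hypothetical names, nothing from the real API:

    last_take_action_outputs = {}
    policy_previous_action = {}

    for t, action in enumerate(["a0", "a1", "a2"]):
        # What the policy hands back as "previous" is the action before this one: a_(t-1).
        policy_previous_action["agent"] = f"a{t - 1}" if t > 0 else None
        # What the processor stores is the action that produced the incoming obs: a_t.
        last_take_action_outputs["agent"] = action

    # After handling step 2's observation:
    assert last_take_action_outputs["agent"] == "a2"   # a_t
    assert policy_previous_action["agent"] == "a1"     # a_(t-1)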
             "Policy/Learning Rate", take_action_outputs["learning_rate"]
         )

-        terminated_agents: List[str] = []
+        terminated_agents: Set[str] = set()

-            self.last_take_action_outputs[global_id] = take_action_outputs
+            if global_id in self.last_step_result:  # Don't store if agent just reset
+                self.last_take_action_outputs[global_id] = take_action_outputs
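
The new guard is what stops a destroyed-and-recreated agent from inheriting stale action outputs: no entry in last_step_result means the agent just reset, so its old outputs are dropped rather than attributed to the fresh episode. A minimal sketch of the pattern (dictionary names mirror the diff; everything else is simplified):

    last_step_result = {"agent-1": ("step", 0)}      # agent-2 just reset: no entry
    last_take_action_outputs = {}

    def store_action_outputs(global_id: str, outputs: dict) -> None:
        # Only agents with a stored last step keep their action outputs;
        # a freshly reset agent silently drops them.
        if global_id in last_step_result:
            last_take_action_outputs[global_id] = outputs

    store_action_outputs("agent-1", {"action": [0.1]})
    store_action_outputs("agent-2", {"action": [0.9]})   # ignored: agent just reset
    assert "agent-2" not in last_take_action_outputs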
         for _id in batched_step_result.agent_id:  # Assume agent_id is 1-D
             local_id = int(
                 _id
             )
             curr_agent_step = batched_step_result.get_agent_step_result(local_id)
             global_id = get_global_agent_id(worker_id, local_id)
-            stored_step = self.last_step_result.get(global_id, None)
+            stored_agent_step, idx = self.last_step_result.get(global_id, (None, None))
             stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
-            if stored_step is not None and stored_take_action_outputs is not None:
-                stored_agent_step = stored_step.get_agent_step_result(local_id)
-                idx = stored_step.agent_id_to_index[local_id]
+            if stored_agent_step is not None and stored_take_action_outputs is not None:
                 obs = stored_agent_step.obs
                 if not stored_agent_step.done:
                     if self.policy.use_recurrent:
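
Worth noting in the new lookup: .get(global_id, (None, None)) always returns a two-tuple, so the unpacking assignment never raises, and a single None check replaces the old get_agent_step_result / agent_id_to_index pair. The idiom in isolation:

    from typing import Dict, Tuple

    last_step_result: Dict[str, Tuple[str, int]] = {"agent-1": ("step-data", 4)}

    # Known agent: unpacks to the stored step and its batch index.
    stored_agent_step, idx = last_step_result.get("agent-1", (None, None))
    assert stored_agent_step == "step-data" and idx == 4

    # Unknown agent: the (None, None) default keeps the tuple-unpack safe.
    stored_agent_step, idx = last_step_result.get("agent-9", (None, None))
    assert stored_agent_step is None and idx is None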
                         "Environment/Episode Length",
                         self.episode_steps.get(global_id, 0),
                     )
-                    terminated_agents += [global_id]
+                    terminated_agents.add(global_id)

-            self.last_step_result[global_id] = batched_step_result
+            # Index is needed to grab from last_take_action_outputs
+            self.last_step_result[global_id] = (
+                curr_agent_step,
+                batched_step_result.agent_id_to_index[_id],
+            )

-        if "action" in take_action_outputs:
-            self.policy.save_previous_action(
-                previous_action.agent_ids, take_action_outputs["action"]
-            )
+        for _gid in action_global_agent_ids:
+            # If the ID doesn't have a last step result, the agent just reset,
+            # don't store the action.
+            if _gid in self.last_step_result:
+                if "action" in take_action_outputs:
+                    self.policy.save_previous_action(
+                        [_gid], take_action_outputs["action"]
+                    )
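
Previous actions are now saved one agent at a time, gated by the same reset check, rather than in one bulk call over previous_action.agent_ids. A sketch of the pattern with a stub policy (StubPolicy is hypothetical; only the call shape matches the real save_previous_action):

    class StubPolicy:
        def __init__(self):
            self.previous_actions = {}
        def save_previous_action(self, agent_ids, actions):
            for agent_id in agent_ids:
                self.previous_actions[agent_id] = actions

    policy = StubPolicy()
    last_step_result = {"w0-a1": ("step", 0)}        # "w0-a2" just reset
    take_action_outputs = {"action": [[0.2, -0.4]]}

    for _gid in ["w0-a1", "w0-a2"]:
        if _gid in last_step_result:                 # skip freshly reset agents
            if "action" in take_action_outputs:
                policy.save_previous_action([_gid], take_action_outputs["action"])

    assert "w0-a2" not in policy.previous_actions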
     def _clean_agent_data(self, global_id: str) -> None:
         """
         Removes the data for an Agent.
         """
         del self.last_step_result[global_id]
         self.policy.remove_previous_action([global_id])
         self.policy.remove_memories([global_id])
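
One caveat a reviewer might flag: _clean_agent_data uses del, which raises KeyError if an agent terminates without ever storing a step. A defensive variant would reach for dict.pop with a default; this is a hedged alternative, not what the diff does:

    def clean_agent_data_defensive(store: dict, global_id: str) -> None:
        # dict.pop with a default never raises, unlike `del store[global_id]`,
        # so cleanup is safe even if the agent never stored a step.
        store.pop(global_id, None)

    last_step_result = {}
    clean_agent_data_defensive(last_step_result, "w0-a7")   # no KeyError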