|
|
|
|
|
|
1 for shape in behavior_spec.observation_shapes if len(shape) == 3 |
|
|
|
) |
|
|
|
self.use_continuous_act = self.behavior_spec.action_spec.is_continuous() |
|
|
|
self.previous_action_dict: Dict[str, Dict[str, np.ndarray]] = {} |
|
|
|
self.previous_action_dict: Dict[str, np.ndarray] = {} |
|
|
|
self.memory_dict: Dict[str, np.ndarray] = {} |
|
|
|
self.normalize = trainer_settings.network_settings.normalize |
|
|
|
self.use_recurrent = self.network_settings.memory is not None |
|
|
|
|
|
|
if agent_id in self.memory_dict: |
|
|
|
self.memory_dict.pop(agent_id) |
|
|
|
|
|
|
|
def make_empty_previous_action(self, num_agents: int) -> Dict[str, np.ndarray]: |
|
|
|
def make_empty_previous_action(self, num_agents: int) -> np.ndarray: |
|
|
|
act_dict: Dict[str, np.ndarray] = {} |
|
|
|
action_tuple = self.behavior_spec.action_spec.empty_action(num_agents) |
|
|
|
if self.behavior_spec.action_spec.continuous_size > 0: |
|
|
|
act_dict["continuous_action"] = action_tuple.continuous |
|
|
|
if self.behavior_spec.action_spec.discrete_size > 0: |
|
|
|
act_dict["discrete_action"] = action_tuple.discrete |
|
|
|
return act_dict |
|
|
|
return np.zeros( |
|
|
|
(num_agents, self.behavior_spec.action_spec.discrete_size), dtype=np.int32 |
|
|
|
) |
|
|
|
if action_dict is None: |
|
|
|
if action_dict is None or "discrete_action" not in action_dict: |
|
|
|
agent_action_dict: Dict[str, np.ndarray] = {} |
|
|
|
for act_type in action_dict: |
|
|
|
agent_action_dict[act_type] = action_dict[act_type][index, :] |
|
|
|
self.previous_action_dict[agent_id] = agent_action_dict |
|
|
|
self.previous_action_dict[agent_id] = action_dict["discrete_action"][ |
|
|
|
index, : |
|
|
|
] |
|
|
|
def retrieve_previous_action(self, agent_ids: List[str]) -> Dict[str, np.ndarray]: |
|
|
|
action_dict = self.make_empty_previous_action(len(agent_ids)) |
|
|
|
def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray: |
|
|
|
action_matrix = self.make_empty_previous_action(len(agent_ids)) |
|
|
|
for act_type in action_dict: |
|
|
|
action_dict[act_type][index, :] = self.previous_action_dict[ |
|
|
|
agent_id |
|
|
|
][act_type] |
|
|
|
return action_dict |
|
|
|
action_matrix[index, :] = self.previous_action_dict[agent_id] |
|
|
|
return action_matrix |
|
|
|
|
|
|
|
def remove_previous_action(self, agent_ids): |
|
|
|
for agent_id in agent_ids: |
|
|
|