|
|
|
|
|
|
|
|
|
|
|
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
class TeammateStatus: |
|
|
|
class GroupmateStatus: |
|
|
|
""" |
|
|
|
Stores data related to an agent's teammate. |
|
|
|
""" |
|
|
|
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
class AgentExperience: |
|
|
|
obs: List[np.ndarray] |
|
|
|
teammate_status: List[TeammateStatus] |
|
|
|
group_status: List[GroupmateStatus] |
|
|
|
reward: float |
|
|
|
done: bool |
|
|
|
action: ActionTuple |
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class TeamObsUtil: |
|
|
|
class GroupObsUtil: |
|
|
|
return f"team_obs_{index}" |
|
|
|
return f"group_obs_{index}" |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def get_name_at_next(index: int) -> str: |
|
|
|
|
|
|
return f"team_obs_next_{index}" |
|
|
|
return f"group_obs_next_{index}" |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _padded_time_to_batch( |
|
|
|
|
|
|
""" |
|
|
|
# Find the first observation. This should be USUALLY O(1) |
|
|
|
obs_shape = None |
|
|
|
for _team_obs in agent_buffer_field: |
|
|
|
if _team_obs: |
|
|
|
obs_shape = _team_obs[0].shape |
|
|
|
for _group_obs in agent_buffer_field: |
|
|
|
if _group_obs: |
|
|
|
obs_shape = _group_obs[0].shape |
|
|
|
break |
|
|
|
# If there were no critic obs at all |
|
|
|
if obs_shape is None: |
|
|
|
|
|
|
separated_obs: List[np.array] = [] |
|
|
|
for i in range(num_obs): |
|
|
|
separated_obs.append( |
|
|
|
TeamObsUtil._padded_time_to_batch(batch[TeamObsUtil.get_name_at(i)]) |
|
|
|
GroupObsUtil._padded_time_to_batch(batch[GroupObsUtil.get_name_at(i)]) |
|
|
|
result = TeamObsUtil._transpose_list_of_lists(separated_obs) |
|
|
|
result = GroupObsUtil._transpose_list_of_lists(separated_obs) |
|
|
|
return result |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
|
separated_obs: List[np.array] = [] |
|
|
|
for i in range(num_obs): |
|
|
|
separated_obs.append( |
|
|
|
TeamObsUtil._padded_time_to_batch( |
|
|
|
batch[TeamObsUtil.get_name_at_next(i)] |
|
|
|
GroupObsUtil._padded_time_to_batch( |
|
|
|
batch[GroupObsUtil.get_name_at_next(i)] |
|
|
|
result = TeamObsUtil._transpose_list_of_lists(separated_obs) |
|
|
|
result = GroupObsUtil._transpose_list_of_lists(separated_obs) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
np.ndarray |
|
|
|
] # Observation following the trajectory, for bootstrapping |
|
|
|
next_collab_obs: List[List[np.ndarray]] |
|
|
|
next_group_obs: List[List[np.ndarray]] |
|
|
|
agent_id: str |
|
|
|
behavior_id: str |
|
|
|
|
|
|
|
|
|
|
[], |
|
|
|
[], |
|
|
|
) |
|
|
|
for teammate_status in exp.teammate_status: |
|
|
|
teammate_rewards.append(teammate_status.reward) |
|
|
|
teammate_continuous_actions.append(teammate_status.action.continuous) |
|
|
|
teammate_discrete_actions.append(teammate_status.action.discrete) |
|
|
|
for group_status in exp.group_status: |
|
|
|
teammate_rewards.append(group_status.reward) |
|
|
|
teammate_continuous_actions.append(group_status.action.continuous) |
|
|
|
teammate_discrete_actions.append(group_status.action.discrete) |
|
|
|
|
|
|
|
# Team actions |
|
|
|
agent_buffer_trajectory["team_continuous_action"].append( |
|
|
|
|
|
|
teammate_disc_next_actions = [] |
|
|
|
if not is_last_step: |
|
|
|
next_exp = self.steps[step + 1] |
|
|
|
for teammate_status in next_exp.teammate_status: |
|
|
|
teammate_cont_next_actions.append(teammate_status.action.continuous) |
|
|
|
teammate_disc_next_actions.append(teammate_status.action.discrete) |
|
|
|
for group_status in next_exp.group_status: |
|
|
|
teammate_cont_next_actions.append(group_status.action.continuous) |
|
|
|
teammate_disc_next_actions.append(group_status.action.discrete) |
|
|
|
for teammate_status in exp.teammate_status: |
|
|
|
teammate_cont_next_actions.append(teammate_status.action.continuous) |
|
|
|
teammate_disc_next_actions.append(teammate_status.action.discrete) |
|
|
|
for group_status in exp.group_status: |
|
|
|
teammate_cont_next_actions.append(group_status.action.continuous) |
|
|
|
teammate_disc_next_actions.append(group_status.action.discrete) |
|
|
|
|
|
|
|
agent_buffer_trajectory["team_next_continuous_action"].append( |
|
|
|
teammate_cont_next_actions |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
for i in range(num_obs): |
|
|
|
ith_team_obs = [] |
|
|
|
for _teammate_status in exp.teammate_status: |
|
|
|
ith_group_obs = [] |
|
|
|
for _group_status in exp.group_status: |
|
|
|
ith_team_obs.append(_teammate_status.obs[i]) |
|
|
|
agent_buffer_trajectory[TeamObsUtil.get_name_at(i)].append(ith_team_obs) |
|
|
|
ith_group_obs.append(_group_status.obs[i]) |
|
|
|
agent_buffer_trajectory[GroupObsUtil.get_name_at(i)].append( |
|
|
|
ith_group_obs |
|
|
|
) |
|
|
|
ith_team_obs_next = [] |
|
|
|
ith_group_obs_next = [] |
|
|
|
for _obs in self.next_collab_obs: |
|
|
|
ith_team_obs_next.append(_obs[i]) |
|
|
|
for _obs in self.next_group_obs: |
|
|
|
ith_group_obs_next.append(_obs[i]) |
|
|
|
next_teammate_status = self.steps[step + 1].teammate_status |
|
|
|
for _teammate_status in next_teammate_status: |
|
|
|
next_group_status = self.steps[step + 1].group_status |
|
|
|
for _group_status in next_group_status: |
|
|
|
ith_team_obs_next.append(_teammate_status.obs[i]) |
|
|
|
agent_buffer_trajectory[TeamObsUtil.get_name_at_next(i)].append( |
|
|
|
ith_team_obs_next |
|
|
|
ith_group_obs_next.append(_group_status.obs[i]) |
|
|
|
agent_buffer_trajectory[GroupObsUtil.get_name_at_next(i)].append( |
|
|
|
ith_group_obs_next |
|
|
|
) |
|
|
|
|
|
|
|
if exp.memory is not None: |
|
|
|
|
|
|
agent_buffer_trajectory["done"].append(exp.done) |
|
|
|
agent_buffer_trajectory["team_dones"].append( |
|
|
|
[_status.done for _status in exp.teammate_status] |
|
|
|
[_status.done for _status in exp.group_status] |
|
|
|
) |
|
|
|
|
|
|
|
# Adds the log prob and action of continuous/discrete separately |
|
|
|
|
|
|
Returns true if all teammates are done at the end of the trajectory. |
|
|
|
Combine with done_reached to check if the whole team is done. |
|
|
|
""" |
|
|
|
return all(_status.done for _status in self.steps[-1].teammate_status) |
|
|
|
return all(_status.done for _status in self.steps[-1].group_status) |
|
|
|
|
|
|
|
@property |
|
|
|
def interrupted(self) -> bool: |
|
|
|