        next_critic_obs: List[List[np.ndarray]],
        done: bool,
    ) -> Tuple[
        Dict[str, np.ndarray],
        Dict[str, np.ndarray],
        Dict[str, np.ndarray],
        Dict[str, np.ndarray],
    ]:
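        """
        Evaluate the centralized critic over a sampled trajectory.

        Returns, keyed by reward signal, the value estimates and the marginalized
        value estimates (the baseline term of the COMA-style update) for every step
        of the trajectory, together with the corresponding estimates for the next
        step. The next-step estimates are zeroed out when the trajectory ended in a
        terminal state (done).
        """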
        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        team_obs = TeamObsUtil.from_buffer(batch, n_obs)
        next_team_obs = TeamObsUtil.from_buffer_next(batch, n_obs)

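        # Start from a blank (all-zero) critic memory for the passes below.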
        memory = torch.zeros([1, 1, self.policy.m_size])

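        # Convert the buffered observations to torch tensors. Team observations are
        # nested: one list of observation tensors per teammate.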
        team_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _teammate_obs]
            for _teammate_obs in team_obs
        ]
        next_team_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _teammate_obs]
            for _teammate_obs in next_team_obs
        ]
        # The buffered next observations span the whole trajectory, so they only
        # need tensor conversion here, not an extra batch dimension.
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

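        # Actions taken at the next step, plus the teammates' actions for the
        # current and next steps, read back from the buffer.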
        next_actions = AgentAction.from_dict_next(batch)
        team_actions = AgentAction.from_team_dict(batch)
        next_team_actions = AgentAction.from_team_dict_next(batch)

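        # Per-teammate observations fed to the centralized critic; next_critic_obs
        # holds a single (next-step) observation per teammate and is converted and
        # given a batch dimension below.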
        critic_obs = TeamObsUtil.from_buffer(batch, n_obs)
        critic_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _teammate_obs]
            for _teammate_obs in critic_obs
        ]
        next_critic_obs = [
            ModelUtils.list_to_tensor_list(_list_obs) for _list_obs in next_critic_obs
        ]
        next_critic_obs = [
            [_obs.unsqueeze(0) for _obs in _list_obs] for _list_obs in next_critic_obs
        ]

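        # Critic pass over the sampled trajectory: value estimates and marginalized
        # value estimates (the baseline term) for every reward signal.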
        value_estimates, marg_val_estimates, mem = self.policy.actor_critic.critic_pass(
            critic_obs=critic_obs,
            team_obs=team_obs,
            team_act=team_actions,
        )

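        # The same pass on the next-step observations and actions gives the
        # bootstrap estimates for the value targets.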
        next_value_estimates, next_marg_val_estimates, next_mem = self.policy.actor_critic.critic_pass(
            next_obs,
            next_actions,
            memory,
            sequence_length=batch.num_experiences,
            team_obs=next_team_obs,
            team_act=next_team_actions,
        )

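        # Move the estimates back to numpy. If the trajectory ended in a terminal
        # state there is nothing to bootstrap from, so the next-step estimates are
        # zeroed out.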
        for name, estimate in value_estimates.items():
            value_estimates[name] = ModelUtils.to_numpy(estimate)
        for name, estimate in marg_val_estimates.items():
            marg_val_estimates[name] = ModelUtils.to_numpy(estimate)
        for name, estimate in next_value_estimates.items():
            next_value_estimates[name] = ModelUtils.to_numpy(estimate)
        for name, estimate in next_marg_val_estimates.items():
            next_marg_val_estimates[name] = ModelUtils.to_numpy(estimate)

        if done:
            for k in next_value_estimates:
                next_value_estimates[k] = 0.0

        return (
            value_estimates,
            marg_val_estimates,
            next_value_estimates,
            next_marg_val_estimates,
        )