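# Stack the collected memories (if any) and add a leading dimension so they
# can be passed to the policy as a single tensor.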
if len(memories) > 0:
    memories = torch.stack(memories).unsqueeze(0)
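# Generate this agent's outgoing communication. comms[1], used below, is
# presumably the differentiable one-hot message returned by get_comms.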
comms = self.policy.get_comms(
    comm_obs[0],
    masks=act_masks,
    memories=memories,
    seq_len=self.policy.sequence_length,
)
# This is a little bit of a hack, but it's what's recommended in the
# Gumbel-Softmax documentation.
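# Straight-through estimator: the detached and non-detached comms[1] terms
# cancel numerically, so the forward value of obs[-1] is unchanged, while
# gradients flow back through the non-detached comms[1].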
one_hot_diff_comms = obs[-1] - comms[1].detach() + comms[1]
obs[-1] = one_hot_diff_comms
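# Because obs[-1] now carries the differentiable communication values, the
# gradients from the losses computed on evaluate_actions below can also
# propagate back into the network that produced comms[1].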
log_probs, entropy, values = self.policy.evaluate_actions(
    obs,