 def __init__(self, trainer: Trainer):
     self.experience_buffers: Dict[str, List] = defaultdict(list)
-    self.last_brain_info: Dict[str, BrainInfo] = defaultdict(BrainInfo)
+    self.last_brain_info: Dict[str, BrainInfo] = {}
-    self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = defaultdict(
-        ActionInfoOutputs
-    )
+    self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
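The core of this change is dropping `defaultdict` for the per-agent caches. Indexing `defaultdict(BrainInfo)` with an `agent_id` the trainer has never seen silently constructs and stores an empty `BrainInfo`, which the rest of the loop then treats as real stored data; a plain dict with `.get(agent_id, None)` surfaces the missing entry so the caller can skip that agent. Note that the `+` line for `last_take_action_outputs` is not visible in the extracted hunk; initializing it to a plain dict is an assumption made here to mirror the `last_brain_info` change and to match the later `.get(agent_id, None)` lookup. A minimal sketch of the behavioural difference, using a stand-in dataclass instead of the real `BrainInfo`:

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeBrainInfo:  # stand-in for BrainInfo, for illustration only
    agents: List[str] = field(default_factory=list)

# Old behaviour: merely looking up an unseen agent_id fabricates an empty entry.
by_default = defaultdict(FakeBrainInfo)
info = by_default["agent-42"]   # no KeyError; an empty FakeBrainInfo is created
print(len(by_default))          # 1 -- the cache now "contains" an agent that never acted

# New behaviour: a plain dict makes the missing entry explicit.
plain: dict = {}
info = plain.get("agent-42", None)  # None -> the caller can skip this agent
print(len(plain))                   # 0 -- nothing was silently inserted
```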
 tmp_environment = np.array(next_info.rewards)

 for agent_id in next_info.agents:
-    stored_info = self.last_brain_info[agent_id]
-    stored_take_action_outputs = self.last_take_action_outputs[agent_id]
+    stored_info = self.last_brain_info.get(agent_id, None)
+    stored_take_action_outputs = self.last_take_action_outputs.get(
+        agent_id, None
+    )
     if stored_info is not None:
         idx = stored_info.agents.index(agent_id)
         next_idx = next_info.agents.index(agent_id)
         max_step = next_info.max_reached[next_idx]

         # Add the outputs of the last eval
-        action = take_action_outputs["action"][idx]
+        action = stored_take_action_outputs["action"][idx]
-        action_pre = take_action_outputs["pre_action"][idx]
+        action_pre = stored_take_action_outputs["pre_action"][idx]
-        action_probs = take_action_outputs["log_probs"][idx]
+        action_probs = stored_take_action_outputs["log_probs"][idx]
         action_masks = stored_info.action_masks[idx]
         prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]
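`idx` and `next_idx` are looked up separately because an agent's position in the previously stored `BrainInfo` need not match its position in the current one: agents can finish, spawn, or be reordered between steps. Per-decision data stored earlier (actions, log-probs, masks) is therefore indexed with `idx`, while fields of `next_info` such as `max_reached` use `next_idx`. A rough illustration with plain lists standing in for the two BrainInfo batches (all names here are made up for the example):

```python
# Previous step: three agents were batched in this order.
stored_agents = ["a", "b", "c"]
stored_actions = [[0.1], [0.2], [0.3]]      # one action per agent, same order

# Current step: "a" is gone and a new agent "d" appeared, so positions shift.
next_agents = ["b", "d", "c"]
next_max_reached = [False, False, True]

for agent_id in next_agents:
    if agent_id not in stored_agents:       # mirrors the stored_info None check
        continue
    idx = stored_agents.index(agent_id)     # position in the *previous* batch
    next_idx = next_agents.index(agent_id)  # position in the *current* batch
    action = stored_actions[idx]
    max_step = next_max_reached[next_idx]
    print(agent_id, action, max_step)       # b [0.2] False, then c [0.3] True
```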
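Reading `action`, `action_pre`, and `action_probs` from `stored_take_action_outputs` rather than `take_action_outputs` ties the experience to the policy outputs recorded when this agent's action was actually taken, instead of whatever the most recent evaluation happened to return; the two can belong to different batches when agents step out of sync. A hedged sketch of that store-then-read pattern, with simplified names and a plain dict standing in for `ActionInfoOutputs` (not the real trainer API):

```python
from typing import Any, Dict

import numpy as np

last_take_action_outputs: Dict[str, Dict[str, Any]] = {}

def on_decision(agent_id: str, outputs: Dict[str, Any]) -> None:
    """Record the policy outputs at the moment the action is chosen."""
    last_take_action_outputs[agent_id] = outputs

def on_next_observation(agent_id: str, idx: int) -> None:
    """Later, build the experience from the *stored* outputs for this agent."""
    stored = last_take_action_outputs.get(agent_id, None)
    if stored is None:
        return  # the trainer never saw this agent act; skip it
    action = stored["action"][idx]
    log_probs = stored["log_probs"][idx]
    print(agent_id, action, log_probs)  # would be appended to the agent's buffer

# Usage: outputs are stored when the action is taken ...
on_decision("agent-0", {"action": np.zeros((1, 2)), "log_probs": np.zeros((1, 2))})
# ... and read back when the next observation for that agent arrives.
on_next_observation("agent-0", idx=0)
```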