|
|
|
|
|
|
self.network_settings: NetworkSettings = trainer_settings.network_settings |
|
|
|
self.seed = seed |
|
|
|
self.previous_action_dict: Dict[str, np.ndarray] = {} |
|
|
|
self.previous_memory_dict: Dict[str, np.ndarray] = {} |
|
|
|
self.memory_dict: Dict[str, np.ndarray] = {} |
|
|
|
self.normalize = trainer_settings.network_settings.normalize |
|
|
|
self.use_recurrent = self.network_settings.memory is not None |
|
|
|
|
|
|
if memory_matrix is None: |
|
|
|
return |
|
|
|
|
|
|
|
# Pass old memories into previous_memory_dict |
|
|
|
for agent_id in agent_ids: |
|
|
|
self.previous_memory_dict[agent_id] = self.memory_dict[agent_id] |
|
|
|
|
|
|
|
for index, agent_id in enumerate(agent_ids): |
|
|
|
self.memory_dict[agent_id] = memory_matrix[index, :] |
|
|
|
|
|
|
|
|
|
|
memory_matrix[index, :] = self.memory_dict[agent_id] |
|
|
|
return memory_matrix |
|
|
|
|
|
|
|
def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray: |
|
|
|
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32) |
|
|
|
for index, agent_id in enumerate(agent_ids): |
|
|
|
if agent_id in self.memory_dict: |
|
|
|
memory_matrix[index, :] = self.previous_memory_dict[agent_id] |
|
|
|
return memory_matrix |
|
|
|
|
|
|
|
if agent_id in self.previous_memory_dict: |
|
|
|
self.previous_memory_dict.pop(agent_id) |
|
|
|
|
|
|
|
def make_empty_previous_action(self, num_agents: int) -> np.ndarray: |
|
|
|
""" |
|
|
|