|
|
|
|
|
|
) |
|
|
|
self.stream_names = stream_names |
|
|
|
self.critic = CentralizedValueNetwork( |
|
|
|
stream_names, observation_shapes, network_settings, num_agents=3 |
|
|
|
stream_names, observation_shapes, network_settings, num_agents=2 |
|
|
|
) |
|
|
|
|
|
|
|
@property |
|
|
|
|
|
|
) -> Tuple[ |
|
|
|
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor |
|
|
|
]: |
|
|
|
print(len(critic_obs)) |
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic and actor |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1) |
|
|
|