|
|
|
|
|
|
from typing import Callable, List, Dict, Tuple, Optional |
|
|
|
import attr |
|
|
|
import abc |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
# Resize to (batch, sequence length, encoding size) |
|
|
|
encoding = encoding.reshape([-1, sequence_length, self.h_size]) |
|
|
|
# memories = torch.split(memories, self.m_size // 2, dim=-1) |
|
|
|
# memories = torch.cat(memories, dim=-1) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Give the Actor only half the memories. Note we previously validated |
|
|
|
# that memory_size must be a multiple of 4. |
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
if network_settings.memory is not None: |
|
|
|
self.half_mem_size = network_settings.memory.memory_size // 2 |
|
|
|
new_memory_settings = attr.evolve( |
|
|
|
network_settings.memory, memory_size=self.half_mem_size |
|
|
|
) |
|
|
|
use_network_settings = attr.evolve( |
|
|
|
network_settings, memory=new_memory_settings |
|
|
|
) |
|
|
|
else: |
|
|
|
use_network_settings = network_settings |
|
|
|
self.half_mem_size = 0 |
|
|
|
# if network_settings.memory is not None: |
|
|
|
# self.half_mem_size = network_settings.memory.memory_size // 2 |
|
|
|
# new_memory_settings = attr.evolve( |
|
|
|
# network_settings.memory, memory_size=self.half_mem_size |
|
|
|
# ) |
|
|
|
# use_network_settings = attr.evolve( |
|
|
|
# network_settings, memory=new_memory_settings |
|
|
|
# ) |
|
|
|
# else: |
|
|
|
# use_network_settings = network_settings |
|
|
|
# self.half_mem_size = 0 |
|
|
|
use_network_settings, |
|
|
|
network_settings, |
|
|
|
act_type, |
|
|
|
act_size, |
|
|
|
conditional_sigma, |
|
|
|
|
|
|
self.critic = ValueNetwork( |
|
|
|
stream_names, observation_shapes, use_network_settings |
|
|
|
) |
|
|
|
self.critic = ValueNetwork(stream_names, observation_shapes, network_settings) |
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|