|
|
|
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.normalize = network_settings.normalize |
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
self.use_memory = network_settings.memory is not None |
|
|
|
self.h_size = network_settings.hidden_units |
|
|
|
self.m_size = ( |
|
|
|
network_settings.memory.memory_size |
|
|
|
|
|
|
total_enc_size, network_settings.num_layers, self.h_size |
|
|
|
) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
self.lstm = GRU(self.h_size, self.m_size) |
|
|
|
if self.use_memory: |
|
|
|
self.memory = GRU(self.h_size, self.m_size) |
|
|
|
self.lstm = None # type: ignore |
|
|
|
self.memory = None # type: ignore |
|
|
|
|
|
|
|
def update_normalization(self, buffer: AgentBuffer) -> None: |
|
|
|
obs = ObsUtil.from_buffer(buffer, len(self.processors)) |
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|
return self.lstm.memory_size if self.use_lstm else 0 |
|
|
|
return self.memory.memory_size if self.use_memory else 0 |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
|
|
|
encoded_self = torch.cat([encoded_self, actions], dim=1) |
|
|
|
encoding = self.linear_encoder(encoded_self) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
if self.use_memory: |
|
|
|
encoding, memories = self.lstm(encoding, memories) |
|
|
|
encoding = encoding.reshape([-1, self.lstm.hidden_size]) |
|
|
|
encoding, memories = self.memory(encoding, memories) |
|
|
|
encoding = encoding.reshape([-1, self.memory.output_size]) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.network_body = NetworkBody( |
|
|
|
observation_specs, network_settings, encoded_act_size=encoded_act_size |
|
|
|
) |
|
|
|
if network_settings.memory is not None: |
|
|
|
encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
if self.network_body.memory is not None: |
|
|
|
encoding_size = self.network_body.memory.output_size |
|
|
|
else: |
|
|
|
encoding_size = network_settings.hidden_units |
|
|
|
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) |
|
|
|
|
|
|
requires_grad=False, |
|
|
|
) |
|
|
|
self.network_body = NetworkBody(observation_specs, network_settings) |
|
|
|
if network_settings.memory is not None: |
|
|
|
self.encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
if self.network_body.memory is not None: |
|
|
|
self.encoding_size = self.network_body.memory.output_size |
|
|
|
else: |
|
|
|
self.encoding_size = network_settings.hidden_units |
|
|
|
self.memory_size_vector = torch.nn.Parameter( |
|
|
|