|
|
|
|
|
|
for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): |
|
|
|
n1.copy_normalization(n2) |
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|
return self.lstm.memory_size if self.use_lstm else 0 |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
|
|
|
else: |
|
|
|
encoding_size = network_settings.hidden_units |
|
|
|
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) |
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|
return self.network_body.memory_size |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|
if self.network_body.lstm is not None: |
|
|
|
return self.network_body.lstm.memory_size |
|
|
|
else: |
|
|
|
return 0 |
|
|
|
return self.network_body.memory_size |
|
|
|
|
|
|
|
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: |
|
|
|
self.network_body.update_normalization(vector_obs) |
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|
if self.network_body.lstm is not None: |
|
|
|
return 2 * self.network_body.lstm.memory_size |
|
|
|
else: |
|
|
|
return 0 |
|
|
|
return self.network_body.memory_size + self.critic.memory_size |
|
|
|
|
|
|
|
def critic_pass( |
|
|
|
self, |
|
|
|