|
|
|
|
|
|
from typing import Callable, List, Dict, Tuple, Optional |
|
|
|
import attr |
|
|
|
import abc |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
from mlagents.trainers.settings import NetworkSettings |
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads |
|
|
|
from mlagents.trainers.torch.layers import lstm_layer |
|
|
|
from mlagents.trainers.torch.layers import LSTM |
|
|
|
|
|
|
|
ActivationFunction = Callable[[torch.Tensor], torch.Tensor] |
|
|
|
EncoderFunction = Callable[ |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) |
|
|
|
self.lstm = LSTM(self.h_size, self.m_size) |
|
|
|
self.lstm = None |
|
|
|
self.lstm = None # type: ignore |
|
|
|
|
|
|
|
def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None: |
|
|
|
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders): |
|
|
|
|
|
|
for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): |
|
|
|
n1.copy_normalization(n2) |
|
|
|
|
|
|
|
@property
def memory_size(self) -> int:
    """Return the memory size of the recurrent layer, or 0 without one.

    When ``self.use_lstm`` is falsy no recurrent layer is consulted and the
    reported size is 0; otherwise the value is read from ``self.lstm``.
    """
    if not self.use_lstm:
        return 0
    return self.lstm.memory_size
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
# Resize to (batch, sequence length, encoding size) |
|
|
|
encoding = encoding.reshape([-1, sequence_length, self.h_size]) |
|
|
|
memories = torch.split(memories, self.m_size // 2, dim=-1) |
|
|
|
memories = torch.cat(memories, dim=-1) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoding_size = network_settings.hidden_units |
|
|
|
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) |
|
|
|
|
|
|
|
@property
def memory_size(self) -> int:
    """Return the memory size reported by ``self.network_body``."""
    body = self.network_body
    return body.memory_size
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
@property
@abc.abstractmethod
def memory_size(self) -> int:
    """
    Returns the size of the memory (same size used as input and output in the other
    methods) used by this Actor.

    Declared with ``@property`` stacked on ``@abc.abstractmethod`` because
    ``abc.abstractproperty`` has been deprecated since Python 3.3; subclasses
    must still override this as a property before they can be instantiated.
    """
    pass
|
|
|
|
|
|
|
|
|
|
|
class SimpleActor(nn.Module, Actor): |
|
|
|
def __init__( |
|
|
|
|
|
|
self.act_type = act_type |
|
|
|
self.act_size = act_size |
|
|
|
self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) |
|
|
|
self.memory_size = torch.nn.Parameter(torch.Tensor([0])) |
|
|
|
self.is_continuous_int = torch.nn.Parameter( |
|
|
|
torch.Tensor([int(act_type == ActionType.CONTINUOUS)]) |
|
|
|
) |
|
|
|
|
|
|
self.encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
else: |
|
|
|
self.encoding_size = network_settings.hidden_units |
|
|
|
|
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
self.distribution = GaussianDistribution( |
|
|
|
self.encoding_size, |
|
|
|
|
|
|
self.encoding_size, act_size |
|
|
|
) |
|
|
|
|
|
|
|
@property
def memory_size(self) -> int:
    """Return the memory size exposed by the underlying ``network_body``."""
    wrapped_body = self.network_body
    return wrapped_body.memory_size
|
|
|
|
|
|
|
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
    """Forward the vector observations to ``network_body.update_normalization``.

    :param vector_obs: Vector observations, passed through unchanged.
    """
    body = self.network_body
    body.update_normalization(vector_obs)
|
|
|
|
|
|
|
|
|
|
sampled_actions, |
|
|
|
log_probs, |
|
|
|
self.version_number, |
|
|
|
self.memory_size, |
|
|
|
torch.Tensor([self.network_body.memory_size]), |
|
|
|
self.is_continuous_int, |
|
|
|
self.act_size_vector, |
|
|
|
) |
|
|
|
|
|
|
# Give the Actor only half the memories. Note we previously validate |
|
|
|
# that memory_size must be a multiple of 4. |
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
if network_settings.memory is not None: |
|
|
|
self.half_mem_size = network_settings.memory.memory_size // 2 |
|
|
|
new_memory_settings = attr.evolve( |
|
|
|
network_settings.memory, memory_size=self.half_mem_size |
|
|
|
) |
|
|
|
use_network_settings = attr.evolve( |
|
|
|
network_settings, memory=new_memory_settings |
|
|
|
) |
|
|
|
else: |
|
|
|
use_network_settings = network_settings |
|
|
|
self.half_mem_size = 0 |
|
|
|
use_network_settings, |
|
|
|
network_settings, |
|
|
|
act_type, |
|
|
|
act_size, |
|
|
|
conditional_sigma, |
|
|
|
|
|
|
self.critic = ValueNetwork( |
|
|
|
stream_names, observation_shapes, use_network_settings |
|
|
|
) |
|
|
|
self.critic = ValueNetwork(stream_names, observation_shapes, network_settings) |
|
|
|
|
|
|
|
@property
def memory_size(self) -> int:
    """Return the combined memory size of ``network_body`` and ``critic``."""
    actor_size = self.network_body.memory_size
    critic_size = self.critic.memory_size
    return actor_size + critic_size
|
|
|
|
|
|
|
def critic_pass( |
|
|
|
self, |
|
|
|
|
|
|
actor_mem, critic_mem = None, None |
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1) |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1) |
|
|
|
value_outputs, critic_mem_out = self.critic( |
|
|
|
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
if self.use_lstm: |
|
|
|
# Use only the back half of memories for critic and actor |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1) |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1) |
|
|
|
else: |
|
|
|
critic_mem = None |
|
|
|
actor_mem = None |
|
|
|