|
|
|
|
|
|
from mlagents.trainers.settings import NetworkSettings |
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads |
|
|
|
from mlagents.trainers.torch.layers import lstm_layer |
|
|
|
|
|
|
|
ActivationFunction = Callable[[torch.Tensor], torch.Tensor] |
|
|
|
EncoderFunction = Callable[ |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1) |
|
|
|
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) |
|
|
|
else: |
|
|
|
self.lstm = None |
|
|
|
|
|
|
|
|
|
|
raise Exception("No valid inputs to network.") |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
encoding = encoding.view([sequence_length, -1, self.h_size]) |
|
|
|
# Resize to (batch, sequence length, encoding size) |
|
|
|
encoding = encoding.reshape([-1, sequence_length, self.h_size]) |
|
|
|
encoding, memories = self.lstm( |
|
|
|
encoding.contiguous(), |
|
|
|
(memories[0].contiguous(), memories[1].contiguous()), |
|
|
|
) |
|
|
|
encoding = encoding.view([-1, self.m_size // 2]) |
|
|
|
encoding, memories = self.lstm(encoding, memories) |
|
|
|
encoding = encoding.reshape([-1, self.m_size // 2]) |
|
|
|
memories = torch.cat(memories, dim=-1) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Dict[str, torch.Tensor]: |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
""" |
|
|
|
Get value outputs for the given obs. |
|
|
|
:param vec_inputs: List of vector inputs as tensors. |
|
|
|
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Dict[str, torch.Tensor]: |
|
|
|
encoding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories) |
|
|
|
return self.value_heads(encoding) |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
encoding, memories_out = self.network_body( |
|
|
|
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
return self.value_heads(encoding), memories_out |
|
|
|
|
|
|
|
def get_dist_and_value( |
|
|
|
self, |
|
|
|
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Dict[str, torch.Tensor]: |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = None, None |
|
|
|
_, critic_mem = torch.split(memories, self.half_mem_size, -1) |
|
|
|
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1) |
|
|
|
value_outputs, critic_mem_out = self.critic( |
|
|
|
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
if actor_mem is not None: |
|
|
|
# Make memories with the actor mem unchanged |
|
|
|
memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1) |
|
|
|
critic_mem = None |
|
|
|
value_outputs, _memories = self.critic( |
|
|
|
vec_inputs, vis_inputs, memories=critic_mem |
|
|
|
) |
|
|
|
return value_outputs |
|
|
|
memories_out = None |
|
|
|
return value_outputs, memories_out |
|
|
|
|
|
|
|
def get_dist_and_value( |
|
|
|
self, |
|
|
|
|
|
|
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
if self.use_lstm: |
|
|
|
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1) |
|
|
|
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1) |
|
|
|
else: |
|
|
|
mem_out = None |
|
|
|
return dists, value_outputs, mem_out |
|
|
|