import abc
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import nn

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.distributions import DistInstance
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[
    [torch.Tensor, int, ActivationFunction, int, str], torch.Tensor
]


class NetworkBody(nn.Module):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        # Memory size is zero when the network has no recurrent component.
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        self.encoders, self.embedding_sizes = ModelUtils.create_input_processors(
            observation_shapes,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )

        total_enc_size = sum(self.embedding_sizes) + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore
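    # A minimal construction sketch (hypothetical shapes and settings, not from
    # this file): one 8-float vector sensor plus one 84x84 RGB camera.
    #
    #     settings = NetworkSettings(hidden_units=128, num_layers=2, normalize=True)
    #     body = NetworkBody([(8,), (84, 84, 3)], settings)
    #     # body.encoders        -> [VectorInput, <visual encoder>]
    #     # body.embedding_sizes -> [8, 128]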
|
|
|
|
    def update_normalization(self, buffer: AgentBuffer) -> None:
        obs = ObsUtil.from_buffer(buffer, len(self.encoders))
        for vec_input, enc in zip(obs, self.encoders):
            if isinstance(enc, VectorInput):
                enc.update_normalization(torch.as_tensor(vec_input))
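    # ObsUtil.from_buffer returns one batched array per observation, in the same
    # order as self.encoders, so each vector observation lines up with its
    # VectorInput normalizer; visual observations are skipped by the isinstance
    # check above.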
|
|
|
    def copy_normalization(self, other_network: "NetworkBody") -> None:
        for n1, n2 in zip(self.encoders, other_network.encoders):
            if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                n1.copy_normalization(n2)
|
|
|
|
    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0
    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encodes = []
        for idx, processor in enumerate(self.encoders):
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)

        if len(encodes) == 0:
            raise Exception("No valid inputs to network.")
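        # The remainder of forward() is not shown here. A sketch of the usual
        # continuation, assuming the standard ml-agents flow (concatenate the
        # encodings with any actions, apply the linear encoder, then the
        # optional LSTM):
        #
        #     inputs = torch.cat(encodes, dim=-1)
        #     if actions is not None:
        #         inputs = torch.cat([inputs, actions], dim=-1)
        #     encoding = self.linear_encoder(inputs)
        #     if self.use_lstm:
        #         encoding = encoding.reshape([-1, sequence_length, self.h_size])
        #         encoding, memories = self.lstm(encoding, memories)
        #         encoding = encoding.reshape([-1, self.m_size // 2])
        #     return encoding, memories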
|
|
|
|
|
|
|
|
|
|
class ValueNetwork(nn.Module):
    # __init__ (not shown) builds self.network_body and a ValueHeads module.

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories
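    # self.value_heads maps the encoding to one value estimate per reward stream,
    # e.g. {"extrinsic": tensor([...])} (an illustration; the keys depend on the
    # configured reward signals).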
|
|
|
|
|
|
class Actor(abc.ABC):
    @abc.abstractmethod
    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Updates normalization of Actor based on the provided AgentBuffer of observations.
        :param buffer: An AgentBuffer of observations.
        """
        pass
    @abc.abstractmethod
    def get_dists(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
        pass
    @abc.abstractmethod
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int]:
        pass
class ActorCritic(Actor):
    @abc.abstractmethod
    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        pass
    @abc.abstractmethod
    def get_dist_and_value(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        pass
class SimpleActor(nn.Module, Actor):
    # __init__ (not shown) builds self.network_body, self.distribution, and
    # self.action_spec.

    def get_dists(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        if self.action_spec.is_continuous():
            dists = self.distribution(encoding)
        else:
            dists = self.distribution(encoding, masks)
        return dists, memories
""" |
|
|
|
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs. |
|
|
|
""" |
|
|
|
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1) |
|
|
|
|
|
|
|
# This code will convert the ugly vec and obs into glorious unified list of inputs |
|
|
|
concatenated_vec_obs = vec_inputs[0] |
|
|
|
inputs = [] |
|
|
|
start = 0 |
|
|
|
end = 0 |
|
|
|
vis_index = 0 |
|
|
|
for i, enc in enumerate(self.network_body.encoders): |
|
|
|
if isinstance(enc, VectorInput): |
|
|
|
# This is a vec_obs |
|
|
|
vec_size = self.network_body.embedding_sizes[i] |
|
|
|
end = start + vec_size |
|
|
|
inputs.append(concatenated_vec_obs[:, start:end]) |
|
|
|
start = end |
|
|
|
else: |
|
|
|
inputs.append(vis_inputs[vis_index]) |
|
|
|
vis_index += 1 |
|
|
|
# End of code to convert the ugly vec and obs into glorious unified list of inputs |
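        # A worked example of the slicing above (hypothetical sensors): with
        # encoders [VectorInput, <visual>, VectorInput], embedding_sizes
        # [8, 128, 4], and vec_inputs[0] of shape [batch, 12]:
        #
        #     inputs[0] = vec_inputs[0][:, 0:8]    # first vector sensor
        #     inputs[1] = vis_inputs[0]            # visual sensor, passed through
        #     inputs[2] = vec_inputs[0][:, 8:12]   # second vector sensor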
        dists, _ = self.get_dists(inputs, masks, memories, 1)
        if self.action_spec.is_continuous():
            action_list = self.sample_action(dists)
            action_out = torch.stack(action_list, dim=-1)
        # ... (the discrete branch and the ONNX export outputs are not shown)
|
|
|
class SharedActorCritic(SimpleActor, ActorCritic):
    # __init__ (not shown) adds self.value_heads on top of SimpleActor.

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories_out = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories_out

    def get_dist_and_value(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        if self.action_spec.is_continuous():
            dists = self.distribution(encoding)
        else:
            dists = self.distribution(encoding, masks)
        value_outputs = self.value_heads(encoding)
        return dists, value_outputs, memories
|
|
|
class SeparateActorCritic(SimpleActor, ActorCritic):
    # __init__ (not shown) builds self.critic, a separate value network beside
    # the actor.

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for the critic.
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        value_outputs, critic_mem_out = self.critic(
            inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Make memories with the actor mem unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out
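    # Memory layout convention implied by the split above: the first half of each
    # memory vector belongs to the actor's LSTM and the second half to the
    # critic's, so self.memory_size covers both halves.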
|
|
|
|
|
|
|
|
|
|
    def get_dist_and_value(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        if self.use_lstm:
            # Use the front half of memories for the actor, the back half for the critic.
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        else:
            critic_mem = None
            actor_mem = None
        dists, actor_mem_outs = self.get_dists(
            inputs,
            memories=actor_mem,
            sequence_length=sequence_length,
            masks=masks,
        )
        value_outputs, critic_mem_outs = self.critic(
            inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return dists, value_outputs, mem_out
|
|
|
    def update_normalization(self, buffer: AgentBuffer) -> None:
        super().update_normalization(buffer)
        self.critic.network_body.update_normalization(buffer)