from typing import Callable, List, Dict, Tuple, Optional
import abc

from mlagents.torch_utils import torch, nn

from mlagents_envs.base_env import ObservationType
from mlagents.trainers.settings import NetworkSettings, ConditioningType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import (
    LSTM,
    LinearEncoder,
    ConditionalEncoder,
    Initialization,
)
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)
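
# Torch building blocks for the trainers' networks: per-observation encoders, an
# optional self-attention block for variable-length observations, optional goal
# conditioning through a ConditionalEncoder, an optional LSTM, and the Critic
# interface with its ValueNetwork implementation.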

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

EncoderFunction = Callable[
    [torch.Tensor, int, ActivationFunction, int, str], torch.Tensor
]


        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )
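
        # One encoder ("processor") is created per observation; embedding_sizes
        # holds the output width of each encoder and is used below to size the
        # inputs of the body encoder.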
        total_enc_size, total_goal_size = 0, 0
        for idx, embedding_size in enumerate(self.embedding_sizes):
            if (
                self.obs_types[idx] == ObservationType.GOAL
                and self.conditioning_type != ConditioningType.DEFAULT
            ):
                total_goal_size += embedding_size
            else:
                total_enc_size += embedding_size
        total_enc_size += encoded_act_size
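
        # Variable-length observations are handled separately: each has an
        # EntityEmbedding processor and their embeddings are pooled by a
        # ResidualSelfAttention block whose h_size-wide output is appended to the
        # fixed-size encodings.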
        entity_num_max: int = 0
        var_processors = [p for p in self.processors if isinstance(p, EntityEmbedding)]
        for processor in var_processors:
            entity_max: int = processor.entity_num_max_elements
            # Only adds entity max if it was known at construction
            if entity_max > 0:
                entity_num_max += entity_max
        if len(var_processors) > 0:
            if sum(self.embedding_sizes):
                self.x_self_encoder = LinearEncoder(
                    sum(self.embedding_sizes),
                    1,
                    self.h_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / self.h_size) ** 0.5,
                )
            self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
            total_enc_size += self.h_size
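
        # Choose the body encoder: when goal observations are present and a
        # conditioning type other than DEFAULT is requested, a ConditionalEncoder
        # conditions the encoding of the other observations on the goal encoding;
        # otherwise a plain LinearEncoder is used.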
        if (
            ObservationType.GOAL in self.obs_types
            and self.conditioning_type != ConditioningType.DEFAULT
        ):
            self.linear_encoder = ConditionalEncoder(
                total_enc_size,
                total_goal_size,
                self.h_size,
                network_settings.num_layers,
                1,  # number of conditional layers
            )
        else:
            self.linear_encoder = LinearEncoder(
                total_enc_size, network_settings.num_layers, self.h_size
            )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
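        """
        Encodes the observations, splitting goal observations from the rest,
        optionally appends actions, applies the (conditional) body encoder and,
        if configured, the LSTM.
        :returns: Tuple of the encoding and the (possibly updated) memories.
        """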
        obs_encodes, goal_encodes = [], []
        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

        for idx, processor in enumerate(self.processors):
            if not isinstance(processor, EntityEmbedding):
                # The input can be encoded without having to process other inputs
                obs_input = inputs[idx]
                processed_obs = processor(obs_input)
                if self.obs_types[idx] == ObservationType.DEFAULT:
                    obs_encodes.append(processed_obs)
                elif self.obs_types[idx] == ObservationType.GOAL:
                    goal_encodes.append(processed_obs)
                else:
                    raise Exception(
                        "Something other than a goal or observation was passed to the agent."
                    )
            else:
                var_len_processor_inputs.append((processor, inputs[idx]))
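
        # Concatenate the fixed-size encodings, then run the variable-length inputs
        # through their EntityEmbedding processors and the ResidualSelfAttention
        # block, masking out all-zero (padding) entities.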
        if len(obs_encodes) != 0:
            encoded_self = torch.cat(obs_encodes, dim=1)
            input_exist = True
        else:
            input_exist = False
        if len(var_len_processor_inputs) > 0:
            # Some inputs need to be processed with a variable length encoder
            masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
            embeddings: List[torch.Tensor] = []
            processed_self = self.x_self_encoder(encoded_self) if input_exist else None
            for processor, var_len_input in var_len_processor_inputs:
                embeddings.append(processor(processed_self, var_len_input))
            qkv = torch.cat(embeddings, dim=1)
            attention_embedding = self.rsa(qkv, masks)
            if not input_exist:
                encoded_self = torch.cat([attention_embedding], dim=1)
                input_exist = True
            else:
                encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)

        if not input_exist:
            raise Exception(
                "The trainer was unable to process any of the provided inputs. "
                "Make sure the trained agents have at least one sensor attached to them."
            )
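
        # Append the actions (if provided) and apply goal conditioning: with DEFAULT
        # conditioning the goal encodings are simply concatenated to the other
        # encodings, otherwise they are fed to the ConditionalEncoder as a separate
        # conditioning input.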
        # Constants don't work in Barracuda
        if actions is not None:
            encoded_self = torch.cat([encoded_self, actions], dim=1)

        if self.conditioning_type == ConditioningType.DEFAULT:
            # Goals are treated like any other observation: concatenate them with
            # the rest of the encodings before the body encoder.
            encoded_self = torch.cat([encoded_self] + goal_encodes, dim=1)
            encoding = self.linear_encoder(encoded_self)
        elif len(goal_encodes) > 0:
            goal_inputs = torch.cat(goal_encodes, dim=1)
            encoding = self.linear_encoder(encoded_self, goal_inputs)
        else:
            encoding = self.linear_encoder(encoded_self)
|
|
|
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
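
    # Example usage (hypothetical names): given one observation tensor per sensor
    # for a batch of agents, forward returns a (batch, h_size) encoding and, when
    # an LSTM is configured, the updated memories:
    #
    #     encoding, memories = body(obs_tensors, memories=mem, sequence_length=seq_len)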
|
|
|
|
|
|
|
|
|
|
|
class Critic(abc.ABC):
    @abc.abstractmethod
    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Get value outputs for the given observations.
        :returns: Dict of reward stream to output tensor for values.
        """
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ValueNetwork(nn.Module, Critic): |
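    """
    Critic that runs the observations through a network body and outputs one value
    estimate per reward stream (see ValueHeads).
    """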
|
|
|
|
|
|
class LearningRate(nn.Module):
    def __init__(self, lr):
        # Todo: add learning rate decay
        super().__init__()
        self.learning_rate = torch.Tensor([lr])