|
|
|
|
|
|
MultiCategoricalDistribution, |
|
|
|
) |
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
from mlagents.trainers.models import EncoderType |
|
|
|
|
|
|
|
# Type alias for an activation: a callable taking one tensor and returning a tensor.
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
|
|
|
EncoderFunction = Callable[ |
|
|
|
|
|
|
# Small positive constant. Its use sites are outside this excerpt
# (presumably a numerical-stability guard -- confirm against callers).
EPSILON = 1e-7
|
|
|
|
|
|
|
|
|
|
|
class EncoderType(Enum):
    """Names of the available visual encoder architectures.

    Each member's value is the lowercase string identifier of that encoder,
    so a member can be looked up from its string form, e.g.
    ``EncoderType("resnet")``.
    """

    SIMPLE = "simple"
    NATURE_CNN = "nature_cnn"
    RESNET = "resnet"
|
|
|
|
|
|
|
|
|
|
|
class ActionType(Enum): |
|
|
|
|
|
|
hidden = encoder(vis_input) |
|
|
|
vis_embeds.append(hidden) |
|
|
|
|
|
|
|
#embedding = vec_embeds[0] |
|
|
|
# embedding = vec_embeds[0] |
|
|
|
if len(vec_embeds) > 0: |
|
|
|
vec_embeds = torch.stack(vec_embeds, dim=-1).sum(dim=-1) |
|
|
|
if len(vis_embeds) > 0: |
|
|
|
|
|
|
vec_inputs, vis_inputs, masks, memories, sequence_length |
|
|
|
) |
|
|
|
sampled_actions = self.sample_action(dists) |
|
|
|
return sampled_actions, dists[0].pdf(sampled_actions), self.version_number, self.memory_size, self.is_continuous_int, self.act_size_vector |
|
|
|
return ( |
|
|
|
sampled_actions, |
|
|
|
dists[0].pdf(sampled_actions), |
|
|
|
self.version_number, |
|
|
|
self.memory_size, |
|
|
|
self.is_continuous_int, |
|
|
|
self.act_size_vector, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class Critic(nn.Module): |
|
|
|
|
|
|
self.layers = [] |
|
|
|
last_channel = initial_channels |
|
|
|
for _, channel in enumerate(n_channels): |
|
|
|
self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)) |
|
|
|
self.layers.append( |
|
|
|
nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1) |
|
|
|
) |
|
|
|
self.layers.append(nn.MaxPool2d([3, 3], [2, 2])) |
|
|
|
height, width = pool_out_shape((height, width), 3) |
|
|
|
for _ in range(n_blocks): |
|
|
|
|
|
|
def forward(self, visual_obs): |
|
|
|
batch_size = visual_obs.shape[0] |
|
|
|
hidden = visual_obs |
|
|
|
for idx, layer in enumerate(self.layers): |
|
|
|
for layer in self.layers: |
|
|
|
if isinstance(layer, nn.Module): |
|
|
|
hidden = layer(hidden) |
|
|
|
elif isinstance(layer, list): |
|
|
|
|
|
|
EncoderType.NATURE_CNN: NatureVisualEncoder, |
|
|
|
EncoderType.RESNET: ResNetVisualEncoder, |
|
|
|
} |
|
|
|
print(encoder_type, ENCODER_FUNCTION_BY_TYPE.get(encoder_type)) |
|
|
|
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|