from typing import Callable

import torch
from torch import nn

from mlagents.trainers.distributions_torch import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    CategoricalDistInstance,
)
from mlagents.trainers.exception import UnityTrainerException

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[
    [torch.Tensor, int, ActivationFunction, int, str, bool], torch.Tensor
]
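# Both aliases carry over the TF implementation's encoder signature:
# (input, hidden units, activation, num_layers, scope, reuse) -> output tensor.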


class NetworkBody(nn.Module):
    ...

    def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1):
        vec_embeds = []
        for idx, encoder in enumerate(self.vector_encoders):
            vec_embeds.append(encoder(vec_inputs[idx]))

        vis_embeds = []
        for idx, encoder in enumerate(self.visual_encoders):
            # Unity provides visual observations as NHWC; Conv2d expects NCHW.
            hidden = encoder(vis_inputs[idx].permute([0, 3, 1, 2]))
            vis_embeds.append(hidden)

        # Fuse multiple observation streams by summing their embeddings,
        # which keeps the fused embedding size fixed at h_size.
        if len(vec_embeds) > 0:
            vec_embeds = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
        if len(vis_embeds) > 0:
            vis_embeds = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
        if len(vec_embeds) > 0 and len(vis_embeds) > 0:
            embedding = torch.stack([vec_embeds, vis_embeds], dim=-1).sum(dim=-1)
        elif len(vec_embeds) > 0:
            embedding = vec_embeds
        elif len(vis_embeds) > 0:
            embedding = vis_embeds
        else:
            raise UnityTrainerException("No valid inputs to network.")

        if self.use_lstm:
            # Fold [batch * seq, h_size] into [seq, batch, h_size] for the LSTM,
            # and split the flat memory tensor into its (hidden, cell) halves.
            embedding = embedding.view([sequence_length, -1, self.h_size])
            memories = torch.split(memories, self.m_size // 2, dim=-1)
            embedding, memories = self.lstm(embedding, memories)
            embedding = embedding.view([-1, self.m_size // 2])
            memories = torch.cat(memories, dim=-1)
        return embedding, memories
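
    # Usage sketch (hypothetical sizes): for a body with two vector encoders,
    # no visual encoders, h_size=128, and a batch of 4,
    #     embedding, _ = body(vec_inputs=[torch.rand(4, 8), torch.rand(4, 8)], vis_inputs=[])
    # returns an embedding of shape [4, 128].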


class ActorCritic(nn.Module):
    ...

    def get_dist_and_value(
        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
    ):
        ...
        return dists, value_outputs, memories
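
    # get_dist_and_value is the training-time entry point (the trainers use the
    # returned distributions and value estimates); forward below is the leaner
    # path for action sampling, e.g. at inference time.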

    def forward(
        self, vec_inputs, vis_inputs=None, masks=None, memories=None, sequence_length=1
    ):
        embedding, memories = self.network_body(
            vec_inputs, vis_inputs, memories, sequence_length
        )
        # The continuous (Gaussian) head takes no masks; the discrete
        # multi-categorical head consumes the action-branch masks.
        if isinstance(self.distribution, GaussianDistribution):
            dists = self.distribution(embedding)
        else:
            dists = self.distribution(embedding, masks=masks)
        sampled_actions = torch.stack([dist.sample() for dist in dists], dim=-1)
        return sampled_actions, dists[0].pdf(sampled_actions)
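
    # Returning the action probabilities alongside the samples lets callers of
    # forward() (e.g. the exported inference model) avoid rebuilding the
    # distribution to score the chosen actions.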


class Critic(nn.Module):
    ...


class SimpleVisualEncoder(nn.Module):
    ...

    def forward(self, visual_obs):
        conv_1 = torch.relu(self.conv1(visual_obs))
        conv_2 = torch.relu(self.conv2(conv_1))
        # Flatten [batch, C, H, W] into [batch, final_flat] for the dense layer.
        hidden = self.dense(conv_2.view([-1, self.final_flat]))
        return hidden
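
    # Shape check (assuming 84x84 inputs and the 16/32-channel layout of the
    # elided __init__): conv1 (8x8, stride 4) -> 20x20, conv2 (4x4, stride 2)
    # -> 9x9, so final_flat = 9 * 9 * 32 = 2592.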


class NatureVisualEncoder(nn.Module):
    def __init__(self, height, width, initial_channels, output_size):
        super(NatureVisualEncoder, self).__init__()
        self.h_size = output_size
        # conv_output_shape computes a conv layer's output height/width from
        # its input size, kernel size, and stride.
        conv_1_hw = conv_output_shape((height, width), 8, 4)
        conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
        conv_3_hw = conv_output_shape(conv_2_hw, 3, 1)
        self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64

        self.conv1 = nn.Conv2d(initial_channels, 32, [8, 8], [4, 4])
        self.conv2 = nn.Conv2d(32, 64, [4, 4], [2, 2])
        self.conv3 = nn.Conv2d(64, 64, [3, 3], [1, 1])
        self.dense = nn.Linear(self.final_flat, self.h_size)

    def forward(self, visual_obs):
        conv_1 = torch.relu(self.conv1(visual_obs))
        conv_2 = torch.relu(self.conv2(conv_1))
        conv_3 = torch.relu(self.conv3(conv_2))
        hidden = self.dense(conv_3.view([-1, self.final_flat]))
        return hidden
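
    # Shape check (assuming the 84x84 frames of the Nature DQN architecture):
    #   conv1 (8x8, stride 4): 84 -> 20
    #   conv2 (4x4, stride 2): 20 -> 9
    #   conv3 (3x3, stride 1): 9 -> 7
    #   final_flat = 7 * 7 * 64 = 3136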


class ResNetVisualEncoder(nn.Module):
    def __init__(self, height, width, initial_channels, final_hidden):
        super(ResNetVisualEncoder, self).__init__()
        n_channels = [16, 32, 32]  # channel for each stack
        n_blocks = 2  # number of residual blocks
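        # Each entry of n_channels defines one stack: a 3x3 conv, a 3x3
        # stride-2 max pool, and n_blocks residual blocks, following the
        # IMPALA ResNet of Espeholt et al. (2018).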