|
|
|
|
|
|
entropies = [] |
|
|
|
for idx, action_dist in enumerate(dists): |
|
|
|
action = actions[..., idx] |
|
|
|
log_probs.append(action_dist.log_prob(action)) |
|
|
|
log_prob = action_dist.log_prob(action) |
|
|
|
log_probs.append(log_prob) |
|
|
|
entropies.append(action_dist.entropy()) |
|
|
|
log_probs = torch.stack(log_probs, dim=-1) |
|
|
|
entropies = torch.stack(entropies, dim=-1) |
|
|
|
|
|
|
return h, w |
|
|
|
|
|
|
|
|
|
|
|
def pool_out_shape(h_w, kernel_size, stride=2):
    """Return the (height, width) left after a pooling window is applied.

    Each dimension is computed as ``(size - kernel_size) // stride + 1``;
    no padding term is included in the formula.

    :param h_w: Indexable pair whose items 0 and 1 are the input height
        and width.
    :param kernel_size: Size of the pooling window, applied to both
        dimensions.
    :param stride: Step between consecutive windows. Defaults to 2.
    :return: Tuple ``(height, width)`` of the pooled output.
    """
    height = (h_w[0] - kernel_size) // stride + 1
    width = (h_w[1] - kernel_size) // stride + 1
    return height, width
|
|
|
|
|
|
|
|
|
|
|
class SimpleVisualEncoder(nn.Module): |
|
|
|
def __init__(self, height, width, initial_channels, output_size): |
|
|
|
super(SimpleVisualEncoder, self).__init__() |
|
|
|
|
|
|
def forward(self, visual_obs):
    """Encode visual observations into a ReLU-activated hidden tensor.

    Applies ``self.conv1`` and ``self.conv2``, each followed by a ReLU,
    reshapes the result to rows of ``self.final_flat`` features, and
    passes those rows through ``self.dense`` followed by a ReLU.

    :param visual_obs: Tensor fed to ``self.conv1``.
        NOTE(review): presumably (batch, channels, height, width); the
        layer definitions are outside this view, so confirm.
    :return: ``torch.relu(self.dense(...))`` evaluated on the flattened
        convolution output.
    """
    conv_1 = torch.relu(self.conv1(visual_obs))
    conv_2 = torch.relu(self.conv2(conv_1))
    # Flatten to (-1, final_flat) so the dense layer sees one row per sample.
    flat = conv_2.view([-1, self.final_flat])
    # Project exactly once; the activation is applied to the dense output.
    hidden = torch.relu(self.dense(flat))
    return hidden
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conv_1 = torch.relu(self.conv1(visual_obs)) |
|
|
|
conv_2 = torch.relu(self.conv2(conv_1)) |
|
|
|
conv_3 = torch.relu(self.conv3(conv_2)) |
|
|
|
hidden = self.dense(conv_3.view([-1, self.final_flat])) |
|
|
|
hidden = torch.relu(self.dense(conv_3.view([-1, self.final_flat]))) |
|
|
|
return hidden |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
n_channels = [16, 32, 32] # channel for each stack |
|
|
|
n_blocks = 2 # number of residual blocks |
|
|
|
self.layers = [] |
|
|
|
last_channel = initial_channels |
|
|
|
self.layers.append(nn.Conv2d(initial_channels, channel, [3, 3], [1, 1])) |
|
|
|
self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)) |
|
|
|
height, width = pool_out_shape((height, width), 3) |
|
|
|
last_channel = channel |
|
|
|
self.dense = nn.Linear(n_channels[-1] * height * width, final_hidden) |
|
|
|
nn.Conv2d(channel, channel, [3, 3], [1, 1]), |
|
|
|
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), |
|
|
|
nn.Conv2d(channel, channel, [3, 3], [1, 1]), |
|
|
|
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), |
|
|
|
] |
|
|
|
return block_layers |
|
|
|
|
|
|
|
|
|
|
return hidden + input_hidden |
|
|
|
|
|
|
|
def forward(self, visual_obs): |
|
|
|
batch_size = visual_obs.shape[0] |
|
|
|
for layer in self.layers: |
|
|
|
if layer is nn.Module: |
|
|
|
for idx, layer in enumerate(self.layers): |
|
|
|
if isinstance(layer, nn.Module): |
|
|
|
elif layer is list: |
|
|
|
elif isinstance(layer, list): |
|
|
|
return hidden.flatten() |
|
|
|
before_out = hidden.view(batch_size, -1) |
|
|
|
return torch.relu(self.dense(before_out)) |
|
|
|
|
|
|
|
|
|
|
|
class ModelUtils: |
|
|
|