浏览代码

Fix ResNet

/develop/add-fire
Arthur Juliani 4 年前
当前提交
b6dfb4ac
共有 1 个文件被更改,包括 23 次插入10 次删除
  1. 33
      ml-agents/mlagents/trainers/models_torch.py

33
ml-agents/mlagents/trainers/models_torch.py


entropies = []
for idx, action_dist in enumerate(dists):
action = actions[..., idx]
log_probs.append(action_dist.log_prob(action))
log_prob = action_dist.log_prob(action)
log_probs.append(log_prob)
entropies.append(action_dist.entropy())
log_probs = torch.stack(log_probs, dim=-1)
entropies = torch.stack(entropies, dim=-1)

return h, w
def pool_out_shape(h_w, kernel_size):
height = (h_w[0] - kernel_size) // 2 + 1
width = (h_w[1] - kernel_size) // 2 + 1
return height, width
class SimpleVisualEncoder(nn.Module):
def __init__(self, height, width, initial_channels, output_size):
super(SimpleVisualEncoder, self).__init__()

def forward(self, visual_obs):
conv_1 = torch.relu(self.conv1(visual_obs))
conv_2 = torch.relu(self.conv2(conv_1))
hidden = self.dense(conv_2.view([-1, self.final_flat]))
hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
return hidden

conv_1 = torch.relu(self.conv1(visual_obs))
conv_2 = torch.relu(self.conv2(conv_1))
conv_3 = torch.relu(self.conv3(conv_2))
hidden = self.dense(conv_3.view([-1, self.final_flat]))
hidden = torch.relu(self.dense(conv_3.view([-1, self.final_flat])))
return hidden

n_channels = [16, 32, 32] # channel for each stack
n_blocks = 2 # number of residual blocks
self.layers = []
last_channel = initial_channels
self.layers.append(nn.Conv2d(initial_channels, channel, [3, 3], [1, 1]))
self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1))
height, width = pool_out_shape((height, width), 3)
last_channel = channel
self.dense = nn.Linear(n_channels[-1] * height * width, final_hidden)
nn.Conv2d(channel, channel, [3, 3], [1, 1]),
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
nn.Conv2d(channel, channel, [3, 3], [1, 1]),
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
]
return block_layers

return hidden + input_hidden
def forward(self, visual_obs):
batch_size = visual_obs.shape[0]
for layer in self.layers:
if layer is nn.Module:
for idx, layer in enumerate(self.layers):
if isinstance(layer, nn.Module):
elif layer is list:
elif isinstance(layer, list):
return hidden.flatten()
before_out = hidden.view(batch_size, -1)
return torch.relu(self.dense(before_out))
class ModelUtils:

正在加载...
取消
保存