浏览代码

Make lists modulelists

/goal-conditioning/new
Arthur Juliani 3 年前
当前提交
b8e81b00
共有 2 个文件被更改,包括 15 次插入8 次删除
  1. 17
      ml-agents/mlagents/trainers/torch/layers.py
  2. 6
      ml-agents/mlagents/trainers/torch/networks.py

17
ml-agents/mlagents/trainers/torch/layers.py


super().__init__()
self.layers = []
self.goal_encoders = []
prev_size = input_size
for i in range(num_layers):
prev_size = input_size + goal_size
for _ in range(num_layers):
self.layers.append(
linear_layer(
prev_size,

)
)
self.goal_encoders.append(LinearEncoder(goal_size, 2, hidden_size, final_activation=False))
self.goal_encoders.append(
LinearEncoder(goal_size, 2, hidden_size, final_activation=True)
)
self.layers = torch.nn.ModuleList(self.layers)
self.goal_encoders = torch.nn.ModuleList(self.goal_encoders)
activation = input_tensor
activation = torch.cat([input_tensor, goal_tensor], dim=-1)
activation = layer(activation) * self.goal_encoders[idx//2](goal_tensor)
activation = layer(activation) + self.goal_encoders[idx // 2](
goal_tensor
)
return activation

HyperNetwork(prev_size, hidden_size, goal_size, 2, hidden_size)
)
self.layers.append(Swish())
self.layers = torch.nn.ModuleList(self.layers)
prev_size = hidden_size
def forward(

6
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.trainers.torch.layers import (
LSTM,
LinearEncoder,
HyperNetwork,
ConditionalEncoder, HyperEncoder,
ConditionalEncoder,
HyperEncoder,
)
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer

total_goal_size,
network_settings.num_layers,
self.h_size,
num_hyper_layers=0
num_hyper_layers=1,
)
elif (
ObservationType.GOAL in self.obs_types

正在加载...
取消
保存