Use hypernetwork if there is a goal

Branch: /goal-conditioning/new
Arthur Juliani, 4 years ago
Current commit: e8d54b6f
4 files changed, 100 insertions(+), 63 deletions(-)

In short: when any of an agent's observations is typed ObservationType.GOAL, the network body now builds a HyperNetwork, whose encoder weights are generated from the encoded goal observations, in place of the plain LinearEncoder.
  1. Project/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity (63 lines changed)
  2. ml-agents/mlagents/trainers/torch/layers.py (50 lines changed)
  3. ml-agents/mlagents/trainers/torch/networks.py (42 lines changed)
  4. ml-agents/mlagents/trainers/torch/utils.py (8 lines changed)

Project/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity (63 lines changed)


(The GridWorld.unity hunks arrive heavily truncated; they are repetitive m_Modification override blocks applied to instances of the same Area prefab, guid 5c2bd19e4bbda4991b74387ca5d28156. The recoverable changes:)

- Prefab instances are named FloatAgent and Area (1) through Area (7), each with m_IsActive: 1.
- The duplicated areas are positioned via m_LocalPosition.x (values -1, 0, 1), given m_RootOrder values (6, 10, ...), and have m_LocalEulerAnglesHint.x/y/z zeroed.
- On the agent component (fileID 114650561397225712), m_UseHeuristic is set to 0; on the sensor component (fileID 114889700908650620), both compression and m_Compression are set to 0.
- Every block ends with m_RemovedComponents: [] and m_SourcePrefab: {fileID: 100100000, guid: 5c2bd19e4bbda4991b74387ca5d28156, type: 3}.
ml-agents/mlagents/trainers/torch/layers.py (50 lines changed)


         lstm_out, hidden_out = self.lstm(input_tensor, hidden)
         output_mem = torch.cat(hidden_out, dim=-1)
         return lstm_out, output_mem
+
+
+class HyperNetwork(torch.nn.Module):
+    def __init__(
+        self, input_size, output_size, hyper_input_size, num_layers, layer_size
+    ):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        layer_in_size = hyper_input_size
+        layers = []
+        for _ in range(num_layers):
+            layers.append(
+                linear_layer(
+                    layer_in_size,
+                    layer_size,
+                    kernel_init=Initialization.KaimingHeNormal,
+                    kernel_gain=1.0,
+                    bias_init=Initialization.Zero,
+                )
+            )
+            layers.append(Swish())
+            layer_in_size = layer_size
+        flat_output = linear_layer(
+            layer_size,
+            input_size * output_size + output_size,
+            kernel_init=Initialization.KaimingHeNormal,
+            kernel_gain=0.1,
+            bias_init=Initialization.Zero,
+        )
+        self.hypernet = torch.nn.Sequential(*layers, flat_output)
+
+    def forward(self, input_activation, hyper_input):
+        flat_output_weights = self.hypernet(hyper_input)
+        batch_size = input_activation.size(0)
+        output_weights, output_bias = torch.split(
+            flat_output_weights, self.input_size * self.output_size, dim=-1
+        )
+        output_weights = output_weights.view(
+            batch_size, self.input_size, self.output_size
+        )
+        output_bias = output_bias.view(batch_size, self.output_size)
+        output = (
+            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
+            + output_bias
+        )
+        return output
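
What HyperNetwork does: it feeds the goal (hyper_input) through a small MLP that emits a flat vector, slices that vector into a per-sample weight matrix and bias, and applies them to the encoded observations with a batched matmul, so each batch element passes through its own generated linear layer. Below is a minimal, self-contained sketch of the same pattern, not the ML-Agents implementation: it swaps plain torch.nn.Linear and torch.nn.SiLU for the repo's linear_layer, Initialization, and Swish helpers, and TinyHyperNetwork plus every size in it is illustrative.

import torch

class TinyHyperNetwork(torch.nn.Module):
    # Hypothetical sketch: generate a per-sample linear layer from a goal vector.
    def __init__(self, input_size, output_size, hyper_input_size, layer_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Maps the goal vector to a flat weights-plus-bias vector, mirroring
        # the flat_output head in the class above.
        self.hypernet = torch.nn.Sequential(
            torch.nn.Linear(hyper_input_size, layer_size),
            torch.nn.SiLU(),  # SiLU computes the same function as Swish
            torch.nn.Linear(layer_size, input_size * output_size + output_size),
        )

    def forward(self, input_activation, hyper_input):
        flat = self.hypernet(hyper_input)
        # First input_size * output_size entries are the weights; the rest is the bias.
        weights, bias = torch.split(flat, self.input_size * self.output_size, dim=-1)
        weights = weights.view(-1, self.input_size, self.output_size)
        # Batched matmul: every sample goes through its own generated layer.
        return torch.bmm(input_activation.unsqueeze(1), weights).squeeze(1) + bias

obs = torch.randn(8, 32)   # stand-in for encoded non-goal observations
goal = torch.randn(8, 4)   # stand-in for encoded goal observations
net = TinyHyperNetwork(input_size=32, output_size=16, hyper_input_size=4, layer_size=64)
print(net(obs, goal).shape)  # torch.Size([8, 16])

Because the weights come from the goal, two samples with different goals are transformed by two different effective layers. The kernel_gain=0.1 on flat_output in the real class presumably keeps the generated weights close to zero at initialization.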

ml-agents/mlagents/trainers/torch/networks.py (42 lines changed)


 from mlagents.torch_utils import torch, nn
-from mlagents_envs.base_env import ActionSpec, ObservationSpec
+from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch.action_model import ActionModel
 from mlagents.trainers.torch.agent_action import AgentAction
 from mlagents.trainers.torch.action_log_probs import ActionLogProbs

-from mlagents.trainers.torch.layers import LSTM, LinearEncoder
+from mlagents.trainers.torch.layers import LSTM, LinearEncoder, HyperNetwork
 from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
             else 0
         )
-        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
+        self.processors, self.embedding_sizes, self.obs_types = ModelUtils.create_input_processors(
             observation_specs,
             self.h_size,
             network_settings.vis_encode_type,

-        total_enc_size = sum(self.embedding_sizes) + encoded_act_size
-        self.linear_encoder = LinearEncoder(
-            total_enc_size, network_settings.num_layers, self.h_size
-        )
+        total_enc_size, total_goal_size = 0, 0
+        for idx, embedding_size in enumerate(self.embedding_sizes):
+            if self.obs_types[idx] == ObservationType.DEFAULT:
+                total_enc_size += embedding_size
+            if self.obs_types[idx] == ObservationType.GOAL:
+                total_goal_size += embedding_size
+        total_enc_size += encoded_act_size
+        if ObservationType.GOAL in self.obs_types:
+            self.linear_encoder = HyperNetwork(
+                total_enc_size,
+                self.h_size,
+                total_goal_size,
+                network_settings.num_layers,
+                self.h_size,
+            )
+        else:
+            self.linear_encoder = LinearEncoder(
+                total_enc_size, network_settings.num_layers, self.h_size
+            )
         if self.use_lstm:
             self.lstm = LSTM(self.h_size, self.m_size)
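
A quick worked example of the size bookkeeping in this hunk, using illustrative numbers and plain strings in place of ObservationType members:

# One 64-dim DEFAULT embedding and one 16-dim GOAL embedding (made-up sizes).
embedding_sizes = [64, 16]
obs_types = ["DEFAULT", "GOAL"]  # stand-ins for ObservationType members
encoded_act_size = 0

total_enc_size = sum(s for s, t in zip(embedding_sizes, obs_types) if t == "DEFAULT")
total_goal_size = sum(s for s, t in zip(embedding_sizes, obs_types) if t == "GOAL")
total_enc_size += encoded_act_size

print(total_enc_size, total_goal_size)  # 64 16

total_enc_size becomes the input width of the generated layer; total_goal_size becomes the width of the hypernetwork's conditioning input.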

         sequence_length: int = 1,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         encodes = []
+        goal_signal = None
-            encodes.append(processed_obs)
+            if self.obs_types[idx] == ObservationType.DEFAULT:
+                encodes.append(processed_obs)
+            else:
+                goal_signal = processed_obs
         if len(encodes) == 0:
             raise Exception("No valid inputs to network.")

             inputs = torch.cat(encodes + [actions], dim=-1)
         else:
             inputs = torch.cat(encodes, dim=-1)
-        encoding = self.linear_encoder(inputs)
+        if goal_signal is None:
+            encoding = self.linear_encoder(inputs)
+        else:
+            encoding = self.linear_encoder(inputs, goal_signal)
         if self.use_lstm:
             # Resize to (batch, sequence length, encoding size)
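
The forward hunk routes each processed observation by type: DEFAULT embeddings are concatenated into the encoder input, while a GOAL embedding becomes the conditioning signal. A self-contained sketch of that routing, with stand-in strings for ObservationType and made-up shapes:

import torch

processed = [torch.randn(8, 24), torch.randn(8, 8), torch.randn(8, 4)]
obs_types = ["DEFAULT", "DEFAULT", "GOAL"]  # stand-ins for ObservationType members

encodes, goal_signal = [], None
for processed_obs, obs_type in zip(processed, obs_types):
    if obs_type == "DEFAULT":
        encodes.append(processed_obs)  # concatenated into the encoder input
    else:
        goal_signal = processed_obs    # routed to the hypernetwork instead

inputs = torch.cat(encodes, dim=-1)
print(inputs.shape, goal_signal.shape)  # torch.Size([8, 32]) torch.Size([8, 4])

Note that a later GOAL observation overwrites an earlier one in goal_signal, so this hunk effectively assumes a single goal observation; with several, the constructor's total_goal_size (which sums all goal embeddings) would no longer match the tensor handed to the hypernetwork.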

ml-agents/mlagents/trainers/torch/utils.py (8 lines changed)


 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.exception import UnityTrainerException
-from mlagents_envs.base_env import ObservationSpec
+from mlagents_envs.base_env import ObservationSpec, ObservationType


 class ModelUtils:

         h_size: int,
         vis_encode_type: EncoderType,
         normalize: bool = False,
-    ) -> Tuple[nn.ModuleList, List[int]]:
+    ) -> Tuple[nn.ModuleList, List[int], List[ObservationType]]:
         """
         Creates visual and vector encoders, along with their normalizers.
         :param observation_specs: List of ObservationSpec that represent the observation dimensions.

         """
         encoders: List[nn.Module] = []
         embedding_sizes: List[int] = []
+        obs_types: List[ObservationType] = []
         for obs_spec in observation_specs:
             encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                 obs_spec.shape, normalize, h_size, vis_encode_type

+            obs_types.append(obs_spec.observation_type)
-        return (nn.ModuleList(encoders), embedding_sizes)
+        return (nn.ModuleList(encoders), embedding_sizes, obs_types)

     @staticmethod
     def list_to_tensor(
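
Since create_input_processors now returns a third list, callers unpack three values, as the networks.py hunk above already does. A tiny self-contained stand-in, with a hypothetical ObsSpec namedtuple in place of the real ObservationSpec:

from collections import namedtuple

ObsSpec = namedtuple("ObsSpec", ["shape", "observation_type"])  # stand-in only

def create_input_processors(observation_specs):
    encoders, embedding_sizes, obs_types = [], [], []
    for spec in observation_specs:
        encoders.append(f"encoder{spec.shape}")  # placeholder for a real encoder
        embedding_sizes.append(spec.shape[0])
        obs_types.append(spec.observation_type)  # the newly returned third list
    return encoders, embedding_sizes, obs_types

specs = [ObsSpec((64,), "DEFAULT"), ObsSpec((4,), "GOAL")]
_, sizes, types = create_input_processors(specs)
print(sizes, types)  # [64, 4] ['DEFAULT', 'GOAL']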
