
Multi-input network

/comms-grad
Ervin Teng, 4 years ago
Current commit: 6846af21
3 files changed, 149 insertions and 4 deletions
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (3 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (3 changes)
  3. ml-agents/mlagents/trainers/torch/networks.py (147 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (3 changes)


        actions: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
-           obs, masks, memories, seq_len
+           obs, masks, memories, critic_obs, seq_len
        )
        action_list = [actions[..., i] for i in range(actions.shape[-1])]
        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
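
Aside (not part of the commit): the per-dimension split on the last context line above can be sketched standalone; the tensor values are illustrative only.

# Standalone sketch of the action split: each action dimension is pulled
# out so it can be scored against its own distribution.
import torch

actions = torch.tensor([[0.1, 0.2], [0.3, 0.4]])  # illustrative: batch of 2, 2 action dims
action_list = [actions[..., i] for i in range(actions.shape[-1])]
# action_list holds two tensors of shape (2,), one per action dimension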

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (3 changes)


        obs = ModelUtils.list_to_tensor_list(
            AgentBuffer.obs_list_to_obs_batch(batch["obs"])
        )
        critic_obs = [
            ModelUtils.list_to_tensor_list(AgentBuffer.obs_list_to_obs_batch(agent_obs))
            for agent_obs in batch["critic_obs"]
        ]
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        if self.policy.use_continuous_act:
            actions = ModelUtils.list_to_tensor(batch["actions_pre"]).unsqueeze(-1)

            masks=act_masks,
            actions=actions,
            memories=memories,
            critic_obs=critic_obs,
            seq_len=self.policy.sequence_length,
        )
        loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
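
For reference, a minimal sketch, not part of the commit, of the batch layout the new critic_obs comprehension appears to expect: batch["critic_obs"] is indexed first by other agent, then by observation. The shapes and the two-teammate setup below are assumptions.

# Hypothetical layout: two teammates, each with one 6-dim vector observation
# over a batch of 4 steps. The diff's comprehension turns each teammate's
# numpy observation batch into a list of torch tensors, much like this.
import numpy as np
import torch

batch_size = 4
teammate_obs = [np.zeros((batch_size, 6), dtype=np.float32)]  # assumed shape
batch_critic_obs = [teammate_obs, teammate_obs]  # one entry per other agent
critic_obs = [
    [torch.as_tensor(obs) for obs in agent_obs] for agent_obs in batch_critic_obs
]
# critic_obs[i][j] is a (4, 6) tensor: agent i's j-th observation batch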

ml-agents/mlagents/trainers/torch/networks.py (147 changes)


import enum
from typing import Callable, List, Dict, Tuple, Optional
import abc

        return encoding, memories


# NOTE: this class will be replaced with a multi-head attention when the time comes
class MultiInputNetworkBody(nn.Module):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        num_obs_heads: int = 1,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        self.processors = []
        encoder_input_size = 0
        for i in range(num_obs_heads):
            _proc, _input_size = ModelUtils.create_input_processors(
                observation_shapes,
                self.h_size,
                network_settings.vis_encode_type,
                normalize=self.normalize,
            )
            self.processors.append(_proc)
            encoder_input_size += _input_size
        total_enc_size = encoder_input_size + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )
        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:
        for _proc in self.processors:
            for _in, enc in zip(net_inputs, _proc):
                enc.update_normalization(_in)

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        if self.normalize:
            for _proc in self.processors:
                for n1, n2 in zip(_proc, other_network.processors):
                    n1.copy_normalization(n2)

    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0

    def forward(
        self,
        all_net_inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encodes = []
        for net_inputs, processor_set in zip(all_net_inputs, self.processors):
            for idx, processor in enumerate(processor_set):
                net_input = net_inputs[idx]
                if not exporting_to_onnx.is_exporting() and len(net_input.shape) > 3:
                    net_input = net_input.permute([0, 3, 1, 2])
                processed_vec = processor(net_input)
                encodes.append(processed_vec)
        if len(encodes) == 0:
            raise Exception("No valid inputs to network.")
        # Constants don't work in Barracuda
        if actions is not None:
            inputs = torch.cat(encodes + [actions], dim=-1)
        else:
            inputs = torch.cat(encodes, dim=-1)
        encoding = self.linear_encoder(inputs)
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
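
A minimal usage sketch for the new class, not part of the commit. It assumes NetworkSettings from mlagents.trainers.settings with library defaults, and a single 6-dim vector observation per head.

# Hypothetical usage: two observation heads (e.g. self plus one teammate),
# no LSTM, no encoded actions. Shapes are assumptions for illustration.
import torch
from mlagents.trainers.settings import NetworkSettings

body = MultiInputNetworkBody(
    observation_shapes=[(6,)],
    network_settings=NetworkSettings(),  # assumed defaults: 128 hidden units, 2 layers
    num_obs_heads=2,
)
all_net_inputs = [[torch.zeros(4, 6)], [torch.zeros(4, 6)]]  # one obs list per head
encoding, memories = body(all_net_inputs)
# encoding should come out (4, 128); memories stays None since no LSTM is configured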
class ValueNetwork(nn.Module):
    def __init__(
        self,

        return output, memories


class CentralizedValueNetwork(ValueNetwork):
    def __init__(
        self,
        stream_names: List[str],
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
        num_agents: int = 1,
    ):
        # This is not a typo, we want to call __init__ of nn.Module
        nn.Module.__init__(self)
        self.network_body = MultiInputNetworkBody(
            observation_shapes,
            network_settings,
            encoded_act_size=encoded_act_size,
            num_obs_heads=num_agents,
        )
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    def forward(
        self,
        all_net_inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            all_net_inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories
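
Likewise, a hedged sketch of calling the centralized value network directly; the "extrinsic" stream name, shapes, and default settings are assumptions, not from the commit.

# Hypothetical usage: a single reward stream and two agents' observations
# feeding one centralized value estimate per batch entry.
import torch
from mlagents.trainers.settings import NetworkSettings

value_net = CentralizedValueNetwork(
    stream_names=["extrinsic"],  # assumed stream name
    observation_shapes=[(6,)],
    network_settings=NetworkSettings(),
    num_agents=2,
)
all_net_inputs = [[torch.zeros(4, 6)], [torch.zeros(4, 6)]]
values, _ = value_net(all_net_inputs)
# values["extrinsic"] should have shape (4, 1): one value per batch entry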
class Actor(abc.ABC):
    @abc.abstractmethod
    def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:

        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        """

        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(

            tanh_squash,
        )
        self.stream_names = stream_names
-       self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)
+       self.critic = CentralizedValueNetwork(
+           stream_names, observation_shapes, network_settings, num_agents=3
+       )
    @property
    def memory_size(self) -> int:

        self,
        net_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        critic_obs: List[List[torch.Tensor]] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = None, None

        all_net_inputs = [net_inputs]
        if critic_obs is not None:
            all_net_inputs.extend(critic_obs)
-           net_inputs, memories=critic_mem, sequence_length=sequence_length
+           all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Make memories with the actor mem unchanged

        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        if self.use_lstm:

        dists, actor_mem_outs = self.get_dists(
            net_inputs, memories=actor_mem, sequence_length=sequence_length, masks=masks
        )
        all_net_inputs = [net_inputs]
        if critic_obs is not None:
            all_net_inputs.extend(critic_obs)
-           net_inputs, memories=critic_mem, sequence_length=sequence_length
+           all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
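
The final torch.cat mirrors the memory convention used here: the combined memory vector carries the actor's recurrent state in the first half and the critic's in the second. A standalone sketch, with an illustrative memory size:

# Hypothetical sketch of the actor/critic memory split and re-merge.
import torch

memories = torch.zeros(1, 1, 256)  # assumed combined memory size
half = memories.shape[-1] // 2
actor_mem, critic_mem = memories[..., :half], memories[..., half:]
# ... each half would be fed to its own LSTM ...
mem_out = torch.cat([actor_mem, critic_mem], dim=-1)  # same merge as in the diff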
