
Self-attention Centralized Critic

/develop/centralizedcritic
Ervin Teng, 4 years ago
Current commit
efa67290
2 files changed, 109 insertions and 35 deletions
  1. ml-agents/mlagents/trainers/torch/attention.py (61 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (83 changes)

ml-agents/mlagents/trainers/torch/attention.py (61 changes)


             for ent in observations
         ]
         return key_masks
+
+
+class SmallestAttention(torch.nn.Module):
+    def __init__(
+        self,
+        x_self_size: int,
+        entities_sizes: List[int],
+        embedding_size: int,
+        output_size: Optional[int] = None,
+    ):
+        super().__init__()
+        self.self_size = x_self_size
+        self.entities_sizes = entities_sizes
+        self.entities_num_max_elements: Optional[List[int]] = None
+        self.ent_encoders = torch.nn.ModuleList(
+            [
+                LinearEncoder(self.self_size + ent_size, 2, embedding_size)
+                # LinearEncoder(self.self_size + ent_size, 3, embedding_size)
+                # LinearEncoder(self.self_size + ent_size, 1, embedding_size)
+                for ent_size in self.entities_sizes
+            ]
+        )
+        self.importance_layer = LinearEncoder(embedding_size, 1, 1)
+
+    def forward(
+        self,
+        x_self: torch.Tensor,
+        entities: List[torch.Tensor],
+        key_masks: List[torch.Tensor],
+    ) -> torch.Tensor:
+        # Gather the maximum number of entities information
+        if self.entities_num_max_elements is None:
+            self.entities_num_max_elements = []
+            for ent in entities:
+                self.entities_num_max_elements.append(ent.shape[1])
+        # Concatenate all observations with self
+        self_and_ent: List[torch.Tensor] = []
+        for num_entities, ent in zip(self.entities_num_max_elements, entities):
+            expanded_self = x_self.reshape(-1, 1, self.self_size)
+            # .repeat(
+            #     1, num_entities, 1
+            # )
+            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
+        # Generate the tensor that will serve as query, key and value to self attention
+        qkv = torch.cat(
+            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
+            dim=1,
+        )
+        mask = torch.cat(key_masks, dim=1)
+        # Feed to self attention
+        max_num_ent = sum(self.entities_num_max_elements)
+        importance = self.importance_layer(qkv) + mask.unsqueeze(2) * -1e6
+        importance = torch.softmax(importance, dim=1)
+        weighted_qkv = qkv * importance
+        output = torch.sum(weighted_qkv, dim=1)
+        output = torch.cat([output, x_self], dim=1)
+        return output
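For context, a minimal shape check of the new module (not part of the commit; sizes and tensors below are made up). It assumes LinearEncoder and SimpleTransformer.get_masks from this module, where get_masks appears to flag near-zero entity rows with 1.0 so that the -1e6 term removes them from the softmax:

    import torch
    from mlagents.trainers.torch.attention import SimpleTransformer, SmallestAttention

    batch, num_entities = 4, 3
    x_self_size, ent_size, embed_size = 8, 6, 32

    attn = SmallestAttention(x_self_size, [ent_size], embed_size)
    x_self = torch.randn(batch, x_self_size)
    entities = torch.randn(batch, num_entities, ent_size)
    entities[:, -1] = 0.0  # an all-zero row, which get_masks treats as padding
    masks = SimpleTransformer.get_masks([entities])  # 1.0 where a row is near zero
    out = attn(x_self, [entities], masks)
    # Importance-weighted sum over entities, then concatenated with x_self:
    assert out.shape == (batch, embed_size + x_self_size)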

ml-agents/mlagents/trainers/torch/networks.py (83 changes)


 from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
+from mlagents.trainers.torch.attention import SmallestAttention, SimpleTransformer

 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

             if network_settings.memory is not None
             else 0
         )
-        self.processors = []
-        encoder_input_size = 0
-        for i in range(num_obs_heads):
-            _proc, _input_size = ModelUtils.create_input_processors(
-                sensor_specs,
-                self.h_size,
-                network_settings.vis_encode_type,
-                normalize=self.normalize,
-            )
-            self.processors.append(_proc)
-            encoder_input_size += sum(_input_size)
+        self.processors, _input_size = ModelUtils.create_input_processors(
+            sensor_specs,
+            self.h_size,
+            network_settings.vis_encode_type,
+            normalize=self.normalize,
+        )
+        self.transformer = SmallestAttention(
+            sum(_input_size), [sum(_input_size)], self.h_size, self.h_size
+        )
+        encoder_input_size = self.h_size + sum(_input_size)
         total_enc_size = encoder_input_size + encoded_act_size
         self.linear_encoder = LinearEncoder(

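A note on the last added line above: SmallestAttention.forward returns the attention-pooled embedding (embedding_size wide, h_size here) concatenated with x_self (sum(_input_size) wide), which is where encoder_input_size comes from. With made-up observation widths, not from the commit:

    h_size = 128
    _input_size = [8, 6]  # hypothetical per-observation encoder output widths
    encoder_input_size = h_size + sum(_input_size)  # 128 + 14 = 142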
         else:
             self.lstm = None  # type: ignore

     @property
     def memory_size(self) -> int:
         return self.lstm.memory_size if self.use_lstm else 0

-        for _proc in self.processors:
-            for _in, enc in zip(obs, _proc):
-                enc.update_normalization(_in)
+        for vec_input, enc in zip(obs, self.processors):
+            if isinstance(enc, VectorInput):
+                enc.update_normalization(torch.as_tensor(vec_input))

-        for _proc in self.processors:
-            for n1, n2 in zip(_proc, other_network.processors):
+        for n1, n2 in zip(self.processors, other_network.processors):
             if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):

     @property
     def memory_size(self) -> int:
         return self.lstm.memory_size if self.use_lstm else 0
     def forward(
         self,
         all_net_inputs: List[List[torch.Tensor]],

     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        encodes = []
-        for inputs, processor_set in zip(all_net_inputs, self.processors):
-            for idx, processor in enumerate(processor_set):
+        concat_encoded_obs = []
+        x_self = None
+        self_encodes = []
+        inputs = all_net_inputs[0]
+        for idx, processor in enumerate(self.processors):
+            obs_input = inputs[idx]
+            processed_obs = processor(obs_input)
+            self_encodes.append(processed_obs)
+        x_self = torch.cat(self_encodes, dim=-1)
+        # Get the self encoding separately, but keep it in the entities
+        concat_encoded_obs = [x_self]
+        for inputs in all_net_inputs[1:]:
+            encodes = []
+            for idx, processor in enumerate(self.processors):
+                obs_input = inputs[idx]
+                processed_obs = processor(obs_input)
+                encodes.append(processed_obs)
+            concat_encoded_obs.append(torch.cat(encodes, dim=-1))
-        if len(encodes) == 0:
+        concat_entites = torch.stack(concat_encoded_obs, dim=1)
+        encoded_state = self.transformer(
+            x_self, [concat_entites], SimpleTransformer.get_masks([concat_entites])
+        )
+        if len(concat_encoded_obs) == 0:
             raise Exception("No valid inputs to network.")
         if actions is not None:
-            inputs = torch.cat(encodes + [actions], dim=-1)
+            inputs = torch.cat([encoded_state, actions], dim=-1)
         else:
-            inputs = torch.cat(encodes, dim=-1)
+            inputs = encoded_state
         encoding = self.linear_encoder(inputs)

         if self.use_lstm:

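In the rewritten forward, all_net_inputs[0] holds this agent's observations and becomes x_self, while each agent's concatenated encoding (self included) becomes one entity row for the transformer. A rough sketch with made-up sizes, not code from the commit:

    import torch

    batch, enc_size, num_agents = 4, 14, 3
    # One concatenated observation encoding per agent, self first:
    concat_encoded_obs = [torch.randn(batch, enc_size) for _ in range(num_agents)]
    x_self = concat_encoded_obs[0]
    concat_entites = torch.stack(concat_encoded_obs, dim=1)  # (batch, agents, enc_size)
    # encoded_state = self.transformer(x_self, [concat_entites], masks)
    # -> (batch, h_size + enc_size), matching encoder_input_size in __init__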
         network_settings: NetworkSettings,
         encoded_act_size: int = 0,
         outputs_per_stream: int = 1,
-        num_agents: int = 1,
     ):

         self.network_body = NetworkBody(
-            observation_shapes,
-            network_settings,
-            encoded_act_size=encoded_act_size,
-            num_obs_heads=num_agents,
+            observation_shapes, network_settings, encoded_act_size=encoded_act_size,
         )
         if network_settings.memory is not None:
             encoding_size = network_settings.memory.memory_size // 2

         )
         self.stream_names = stream_names
         self.critic = CentralizedValueNetwork(
-            stream_names, sensor_specs, network_settings, num_agents=2
+            stream_names, sensor_specs, network_settings
         )

     @property

         if critic_obs is not None:
             all_net_inputs.extend(critic_obs)
         value_outputs, critic_mem_outs = self.critic(
-            all_net_inputs,
-            memories=critic_mem,
-            sequence_length=sequence_length,
+            all_net_inputs, memories=critic_mem, sequence_length=sequence_length,
         )
         return log_probs, entropies, value_outputs
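End to end, the actor-critic extends its own observations with the other agents' before calling the centralized critic. A hedged sketch of the expected input structure (names and sizes made up, not code from the commit):

    import torch

    batch = 64
    my_obs = [torch.randn(batch, 10)]        # this agent's observation tensors
    critic_obs = [[torch.randn(batch, 10)]]  # one observation list per other agent
    all_net_inputs = [my_obs]
    all_net_inputs.extend(critic_obs)        # List[List[torch.Tensor]], self first
    # value_outputs, critic_mem_outs = self.critic(
    #     all_net_inputs, memories=critic_mem, sequence_length=sequence_length
    # )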
