
Trainer with attention

/exp-alternate-atten
vincentpierre, 4 years ago
Current commit
d3d4eb90
5 changed files with 78 additions and 26 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (19 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (19 changes)
  3. ml-agents/mlagents/trainers/torch/layers.py (7 changes)
  4. ml-agents/mlagents/trainers/torch/networks.py (37 changes)
  5. ml-agents/mlagents/trainers/torch/utils.py (22 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (19 changes)


         self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
         vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
-        if self.policy.use_vis_obs:
-            visual_obs = []
-            for idx, _ in enumerate(
-                self.policy.actor_critic.network_body.visual_processors
-            ):
-                visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
-                visual_obs.append(visual_ob)
-        else:
-            visual_obs = []
+        # if self.policy.use_vis_obs:
+        #     visual_obs = []
+        #     for idx, _ in enumerate(
+        #         self.policy.actor_critic.network_body.visual_processors
+        #     ):
+        #         visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
+        #         visual_obs.append(visual_ob)
+        # else:
+        #     visual_obs = []
+        visual_obs = [ModelUtils.list_to_tensor(batch["visual_obs%d" % 0])]
         memory = torch.zeros([1, 1, self.policy.m_size])
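Note: the change above bypasses the per-processor loop and always feeds the first visual observation buffer to the value estimator. A minimal sketch of that behaviour, assuming a dict-like batch of numpy arrays; the helper name select_first_visual_obs is hypothetical, and the 20 entities x 4 features layout is borrowed from the hardcoded reshape in networks.py below:

import numpy as np
import torch

def select_first_visual_obs(batch):
    # Hypothetical helper mirroring the hardcoded line above: ignore how many
    # visual processors the policy has and only convert the buffer named
    # "visual_obs0" into a float tensor.
    return [torch.as_tensor(np.asarray(batch["visual_obs0"]), dtype=torch.float32)]

# Example: a batch holding a single "visual" buffer of 20 entities x 4 features.
batch = {"visual_obs0": np.zeros((32, 20, 4), dtype=np.float32)}
print(select_first_visual_obs(batch)[0].shape)  # torch.Size([32, 20, 4])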

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (19 changes)


         if len(memories) > 0:
             memories = torch.stack(memories).unsqueeze(0)
-        if self.policy.use_vis_obs:
-            vis_obs = []
-            for idx, _ in enumerate(
-                self.policy.actor_critic.network_body.visual_processors
-            ):
-                vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
-                vis_obs.append(vis_ob)
-        else:
-            vis_obs = []
+        # if self.policy.use_vis_obs:
+        #     vis_obs = []
+        #     for idx, _ in enumerate(
+        #         self.policy.actor_critic.network_body.visual_processors
+        #     ):
+        #         vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
+        #         vis_obs.append(vis_ob)
+        # else:
+        #     vis_obs = []
+        vis_obs = [ModelUtils.list_to_tensor(batch["visual_obs%d" % 0])]
         log_probs, entropy, values = self.policy.evaluate_actions(
             vec_obs,
             vis_obs,
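As context for the memories handling at the top of this hunk, the stack/unsqueeze idiom turns a Python list of per-sequence memory vectors into the (1, num_sequences, m_size) layout consumed downstream. A small self-contained illustration with arbitrary sizes:

import torch

m_size, num_sequences = 128, 4
memories = [torch.zeros(m_size) for _ in range(num_sequences)]
stacked = torch.stack(memories).unsqueeze(0)
print(stacked.shape)  # torch.Size([1, 4, 128])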

ml-agents/mlagents/trainers/torch/layers.py (7 changes)


         self.fc_q = torch.nn.Linear(query_size, self.n_heads * self.embedding_size)
         self.fc_k = torch.nn.Linear(key_size, self.n_heads * self.embedding_size)
         self.fc_v = torch.nn.Linear(value_size, self.n_heads * self.embedding_size)
+        # self.fc_q = LinearEncoder(query_size, 2, self.n_heads * self.embedding_size)
+        # self.fc_k = LinearEncoder(key_size, 2, self.n_heads * self.embedding_size)
+        # self.fc_v = LinearEncoder(value_size, 2, self.n_heads * self.embedding_size)

-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_mask: torch.Tensor

-        key_mask = torch.sum(key ** 2, axis=2) < 0.01
+        # key_mask = torch.sum(key ** 2, axis=2) < 0.01
         key_mask = key_mask.reshape(b, 1, 1, n_k)
         query = self.fc_q(query)  # (b, n_q, h*d)
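The new key_mask argument lets callers mark padded entities instead of having MultiHeadAttention infer them from near-zero keys. A minimal sketch of how such a (b, 1, 1, n_k) mask is typically applied to the attention logits, assuming the convention from the diff (1 = masked, 0 = attend); the scaling factor and the -1e9 fill value are illustrative, not necessarily what layers.py uses:

import torch

def masked_attention_weights(q, k, key_mask):
    # q: (b, h, n_q, d), k: (b, h, n_k, d), key_mask: (b, 1, 1, n_k)
    d = q.shape[-1]
    logits = torch.matmul(q, k.transpose(-2, -1)) / d ** 0.5   # (b, h, n_q, n_k)
    logits = logits.masked_fill(key_mask.bool(), -1e9)         # masked keys get ~0 weight
    return torch.softmax(logits, dim=-1)

b, h, n_q, n_k, d = 2, 4, 21, 21, 64
q, k = torch.randn(b, h, n_q, d), torch.randn(b, h, n_k, d)
mask = torch.zeros(b, 1, 1, n_k)
mask[:, :, :, -5:] = 1                                          # last five keys are padding
print(masked_attention_weights(q, k, mask)[0, 0, 0, -5:])       # near-zero attention weights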

ml-agents/mlagents/trainers/torch/networks.py (37 changes)


             else 0
         )
-        self.visual_processors, self.vector_processors, encoder_input_size = ModelUtils.create_input_processors(
+        self.visual_processors, self.vector_processors, self.attention, encoder_input_size = ModelUtils.create_input_processors(
             observation_shapes,
             self.h_size,
             network_settings.vis_encode_type,

         self.linear_encoder = LinearEncoder(
             total_enc_size, network_settings.num_layers, self.h_size
         )
+        self.self_embedding = LinearEncoder(6, 1, 64)
+        self.obs_embeding = LinearEncoder(4, 1, 64)
+        self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
         if self.use_lstm:
             self.lstm = LSTM(self.h_size, self.m_size)
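For reference, the three hardcoded encoders above map 6 self-observation features and 4 per-entity features into a shared 64-dim space, then fuse them. LinearEncoder in ml-agents is a small MLP; plain nn.Linear layers stand in below purely to show the shapes (batch size and entity count are illustrative):

import torch

batch, n_entities = 8, 20
self_obs = torch.randn(batch, 1, 6)            # processed vector obs as one "self" token
entity_obs = torch.randn(batch, n_entities, 4)

self_emb = torch.nn.Linear(6, 64)(self_obs)                   # (8, 1, 64)
obs_emb = torch.nn.Linear(4, 64)(entity_obs)                  # (8, 20, 64)
fused = torch.nn.Linear(64 + 64, 64)(
    torch.cat([obs_emb, self_emb.expand(-1, n_entities, -1)], dim=2)
)                                                             # (8, 20, 64)
print(fused.shape)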

             inputs = torch.cat(encodes + [actions], dim=-1)
         else:
             inputs = torch.cat(encodes, dim=-1)
-        encoding = self.linear_encoder(inputs)
+        # TODO : This is a Hack
+        var_len_input = vis_inputs[0].reshape(-1, 20, 4)
+        key_mask = torch.sum(var_len_input ** 2, axis=2) < 0.01  # 1 means mask and 0 means let through
+        self_encoding = processed_vec.reshape(-1, 1, processed_vec.shape[1])
+        self_encoding = self.self_embedding(self_encoding)  # (b, 64)
+        obs_encoding = self.obs_embeding(var_len_input)
+        expanded_self_encoding = self_encoding.reshape(-1, 1, 64).repeat(1, 20, 1)
+        self_and_key_emb = self.self_and_obs_embedding(
+            torch.cat([obs_encoding, expanded_self_encoding], dim=2)
+        )
+        # add the self to the entities
+        self_and_key_emb = torch.cat([self_encoding, self_and_key_emb], dim=1)
+        key_mask = torch.cat([torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1)
+        output, _ = self.attention(self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask)
+        output = torch.sum(output * (1 - key_mask).reshape(-1, 21, 1), dim=1) / torch.sum(1 - key_mask, dim=1, keepdim=True)
+        # output = torch.cat([inputs, output], dim=1)
+        encoding = self.linear_encoder(output)
         if self.use_lstm:
             # Resize to (batch, sequence length, encoding size)

         if self.use_lstm:
             # Use only the back half of memories for critic
             actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
+        if len(vis_inputs) == 0:
+            print("O vis obs for some reason")
         value_outputs, critic_mem_out = self.critic(
             vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
         )
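Putting the forward-pass hack together: the first visual input is reinterpreted as 20 entities with 4 features each, embedded, conditioned on the self encoding, prepended with a self token, run through masked self-attention, and mean-pooled over the unmasked tokens. A hedged, free-standing sketch of that flow (the function is hypothetical; the embedding modules and attention are assumed to follow the interfaces used above):

import torch

def entity_attention_pool(processed_vec, vis_input, self_emb, obs_emb, joint_emb, attention):
    # vis_input arrives flattened; reinterpret it as (b, 20, 4) entities.
    entities = vis_input.reshape(-1, 20, 4)
    b, n_k, _ = entities.shape
    key_mask = (entities.pow(2).sum(dim=2) < 0.01).float()      # 1 = padded entity, 0 = real
    self_enc = self_emb(processed_vec.reshape(b, 1, -1))        # (b, 1, 64)
    ent_enc = obs_emb(entities)                                 # (b, 20, 64)
    fused = joint_emb(torch.cat([ent_enc, self_enc.expand(-1, n_k, -1)], dim=2))
    tokens = torch.cat([self_enc, fused], dim=1)                # prepend the self token
    key_mask = torch.cat([torch.zeros(b, 1), key_mask], dim=1)  # self token is never masked
    out, _ = attention(tokens, tokens, tokens, key_mask)        # (b, 21, 64)
    keep = (1 - key_mask).unsqueeze(-1)
    return (out * keep).sum(dim=1) / keep.sum(dim=1)            # masked mean over tokens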

ml-agents/mlagents/trainers/torch/utils.py (22 changes)


     SmallVisualEncoder,
     VectorInput,
 )
+from mlagents.trainers.torch.layers import MultiHeadAttention
 from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents_envs.base_env import ActionSpec

         Creates visual and vector encoders, along with their normalizers.
         :param observation_shapes: List of Tuples that represent the action dimensions.
         :param action_size: Number of additional un-normalized inputs to each vector encoder. Used for
-            conditioining network on other values (e.g. actions for a Q function)
+            conditioning network on other values (e.g. actions for a Q function)
         :param h_size: Number of hidden units per layer.
         :param vis_encode_type: Type of visual encoder to use.
         :param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector

         vector_size = 0
         visual_output_size = 0
         for i, dimension in enumerate(observation_shapes):
-            if len(dimension) == 3:
+            if False:  # len(dimension) == 3:
                 ModelUtils._check_resolution_for_encoder(
                     dimension[0], dimension[1], vis_encode_type
                 )

             elif len(dimension) == 1:
                 vector_size += dimension[0]
             else:
-                raise UnityTrainerException(
+                print(  # raise UnityTrainerException(
                     f"Unsupported shape of {dimension} for observation {i}"
                 )
         if vector_size > 0:

+        # HardCoded
+        max_observables, observable_size, output_size = (20, 4, 64)
+        attention = MultiHeadAttention(
+            query_size=output_size,
+            key_size=output_size,
+            value_size=output_size,
+            output_size=output_size,
+            num_heads=4,
+            embedding_size=64,
+        )
-            total_processed_size,
+            attention,
+            output_size  # total_processed_size + output_size,
         )
@staticmethod
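With the extra return values above, create_input_processors now hands back four items (visual encoders, vector encoders, the hardcoded attention module, and the 64-dim output size) instead of three, which is why the call in networks.py unpacks an additional self.attention. The disabled 3-D branch also means no visual encoder is ever built; a standalone illustration of that shape routing (route_observation_shapes is a hypothetical stand-in, not the repo function):

def route_observation_shapes(observation_shapes):
    vector_size = 0
    for i, dimension in enumerate(observation_shapes):
        if False:  # was: len(dimension) == 3 — the visual-encoder branch is skipped entirely
            pass
        elif len(dimension) == 1:
            vector_size += dimension[0]
        else:
            print(f"Unsupported shape of {dimension} for observation {i}")
    return vector_size

print(route_observation_shapes([(8,), (20, 4), (84, 84, 3)]))  # warns twice, returns 8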
