
POCA Attention will use h_size for embedding size and not 128 (#5281)

/check-for-ModelOverriders
GitHub committed 4 years ago
Commit 1678be1c
1 file changed, 5 insertions and 6 deletions
ml-agents/mlagents/trainers/torch/networks.py (11 changed lines)

@@ class MultiAgentNetworkBody @@
 class MultiAgentNetworkBody(torch.nn.Module):
-    ATTENTION_EMBEDDING_SIZE = 128
     """
     A network body that uses a self attention layer to handle state
     and action input from a potentially variable number of agents that
@@ MultiAgentNetworkBody.__init__ @@
             + self.action_spec.continuous_size
         )
+        attention_embeding_size = self.h_size
         self.obs_encoder = EntityEmbedding(
-            obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+            obs_only_ent_size, None, attention_embeding_size
         )
         self.obs_action_encoder = EntityEmbedding(
-            q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+            q_ent_size, None, attention_embeding_size
         )
-        self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)
+        self.self_attn = ResidualSelfAttention(attention_embeding_size)
         self.linear_encoder = LinearEncoder(
-            self.ATTENTION_EMBEDDING_SIZE,
+            attention_embeding_size,
             network_settings.num_layers,
             self.h_size,
             kernel_gain=(0.125 / self.h_size) ** 0.5,
         )
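
For context, here is a minimal runnable sketch of the sizing pattern this commit adopts: the attention embedding width is derived from h_size (the trainer's hidden_units setting) rather than hard-coded to 128, so the entity encoders, the self-attention layer, and the final linear encoder all scale with a single setting. Previously, raising hidden_units above 128 still squeezed the attention output through a 128-wide embedding; deriving the width from h_size removes that mismatch. SimpleEntityEmbedding and SimpleMultiAgentBody below are hypothetical stand-ins written for this sketch, and torch.nn.MultiheadAttention stands in for ml-agents' ResidualSelfAttention; only the sizing pattern mirrors the commit.

import torch
import torch.nn as nn


class SimpleEntityEmbedding(nn.Module):
    # Hypothetical stand-in for ml-agents' EntityEmbedding: projects each
    # entity's observation vector up to the attention embedding width.
    def __init__(self, entity_size: int, embedding_size: int):
        super().__init__()
        self.proj = nn.Linear(entity_size, embedding_size)

    def forward(self, entities: torch.Tensor) -> torch.Tensor:
        # entities: [batch, n_agents, entity_size]
        return self.proj(entities)


class SimpleMultiAgentBody(nn.Module):
    # Sketch of the pattern from this commit: every attention-related
    # width is derived from h_size instead of a hard-coded 128.
    def __init__(self, obs_size: int, h_size: int, num_heads: int = 4):
        super().__init__()
        attention_embedding_size = h_size  # was: ATTENTION_EMBEDDING_SIZE = 128
        self.obs_encoder = SimpleEntityEmbedding(obs_size, attention_embedding_size)
        # torch.nn.MultiheadAttention stands in for ResidualSelfAttention here.
        self.self_attn = nn.MultiheadAttention(
            attention_embedding_size, num_heads, batch_first=True
        )
        # Project back to the body's hidden size, roughly as LinearEncoder does.
        self.linear_encoder = nn.Linear(attention_embedding_size, h_size)

    def forward(self, agent_obs: torch.Tensor) -> torch.Tensor:
        # agent_obs: [batch, n_agents, obs_size]
        embedded = self.obs_encoder(agent_obs)
        attended, _ = self.self_attn(embedded, embedded, embedded)
        pooled = attended.mean(dim=1)  # average over the variable agent axis
        return self.linear_encoder(pooled)


if __name__ == "__main__":
    body = SimpleMultiAgentBody(obs_size=32, h_size=256)
    print(body(torch.randn(8, 3, 32)).shape)  # torch.Size([8, 256])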
