
Memory size abstraction and fixes

/develop/add-fire/memoryclass
Ervin Teng, 4 years ago
Commit cb0085a7
3 files changed: 38 insertions, 15 deletions
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (1 change)
  2. ml-agents/mlagents/trainers/torch/layers.py (26 changes)
  3. ml-agents/mlagents/trainers/torch/networks.py (26 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (1 change)


             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )
+        self.m_size = self.actor_critic.memory_size
         self.actor_critic.to(TestingConfiguration.device)
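
The single added line above replaces a memory size derived from the trainer settings with one reported by the actor-critic itself. A minimal sketch of that pattern, using a hypothetical ToyActorCritic rather than the real ml-agents classes:

import torch


class ToyActorCritic(torch.nn.Module):
    # Hypothetical stand-in: the module that owns the recurrent layer is the
    # one that knows how much memory a single step needs.
    def __init__(self, hidden_size: int, use_lstm: bool):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(hidden_size, hidden_size) if use_lstm else None

    @property
    def memory_size(self) -> int:
        # h and c stacked in one flat vector; 0 when the network is not recurrent.
        return 2 * self.hidden_size if self.lstm is not None else 0


actor_critic = ToyActorCritic(hidden_size=128, use_lstm=True)
m_size = actor_critic.memory_size  # the policy reads the size instead of deriving it
assert m_size == 256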

ml-agents/mlagents/trainers/torch/layers.py (26 changes)


             self.layers.append(Swish())
         self.seq_layers = torch.nn.Sequential(*self.layers)

-    def forward(self, input_tensor, h0_c0):
-        hidden = h0_c0
+    @property
+    def memory_size(self) -> int:
+        return self.hidden_size // 2 + 2 * self.hidden_size
+
+    def forward(self, input_tensor, memories):
+        # memories is hidden_size // 2 (accumulant) + hidden_size (h0) + hidden_size (c0)
+        acc, h0, c0 = torch.split(
+            memories,
+            [self.hidden_size // 2, self.hidden_size, self.hidden_size],
+            dim=-1,
+        )
+        hidden = (h0, c0)
-        m = None
-        lstm_out, hidden = self.lstm(input_tensor, hidden)
+        m = acc.permute([1, 0, 2])
+        lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden)
-            if m is None:
-                m = h_half_subt
-            else:
-                m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
+            m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
-        return concat_out, hidden
+        output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1)
+        return concat_out, output_mem

     class PassthroughMax(torch.autograd.Function):
         @staticmethod
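
In the new forward signature the recurrent state travels as one flat tensor: it is split into a max-accumulant plus the LSTM's (h0, c0), the accumulant is updated through a straight-through max, and everything is re-packed on the way out, with memory_size describing the packed width. The following self-contained sketch shows the same packing scheme under simplifying assumptions: a (1, batch, memory_size) layout and a single max update per call instead of the per-timestep accumulation the real AMRLMax performs; MaxMemoryLSTM and its names are illustrative, not the ml-agents API.

import torch


class PassthroughMax(torch.autograd.Function):
    # Elementwise max whose backward pass copies the incoming gradient to both
    # inputs (straight-through), so the max never blocks gradient flow.
    @staticmethod
    def forward(ctx, tensor1, tensor2):
        return torch.max(tensor1, tensor2)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone(), grad_output.clone()


class MaxMemoryLSTM(torch.nn.Module):
    # Illustrative, stripped-down variant of the idea above: an LSTM whose
    # exported memory packs a max-accumulated slice alongside h and c.
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)

    @property
    def memory_size(self) -> int:
        # accumulant (hidden_size // 2) + h0 (hidden_size) + c0 (hidden_size)
        return self.hidden_size // 2 + 2 * self.hidden_size

    def forward(self, input_tensor, memories):
        # memories: (1, batch, memory_size), matching the LSTM state layout
        acc, h0, c0 = torch.split(
            memories,
            [self.hidden_size // 2, self.hidden_size, self.hidden_size],
            dim=-1,
        )
        lstm_out, (h0_out, c0_out) = self.lstm(
            input_tensor, (h0.contiguous(), c0.contiguous())
        )
        # Max-accumulate the first half of the new hidden state into the accumulant.
        half_h = h0_out[..., : self.hidden_size // 2]
        acc = PassthroughMax.apply(acc, half_h)
        output_mem = torch.cat([acc, h0_out, c0_out], dim=-1)
        return lstm_out, output_mem


layer = MaxMemoryLSTM(input_size=8, hidden_size=16)
memories = torch.zeros(1, 4, layer.memory_size)         # batch of 4, zeroed memory
out, memories = layer(torch.randn(4, 5, 8), memories)   # 5-step, batch-first input
assert memories.shape[-1] == layer.memory_size

A plain elementwise max would route gradient only to whichever input won at each position; copying the gradient to both inputs keeps earlier timesteps trainable, which appears to be the motivation for the PassthroughMax class.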

ml-agents/mlagents/trainers/torch/networks.py (26 changes)


 from mlagents.trainers.settings import NetworkSettings
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
-from mlagents.trainers.torch.layers import lstm_layer
+from mlagents.trainers.torch.layers import AMRLMax

 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
 EncoderFunction = Callable[

         )
         if self.use_lstm:
-            self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
+            self.lstm = AMRLMax(self.h_size, self.m_size // 2, batch_first=True)
         else:
             self.lstm = None  # type: ignore
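
Note the arithmetic this constructor sets up: the layer is built with hidden_size = m_size // 2, but its packed memory (per the memory_size property added in layers.py) is hidden_size // 2 + 2 * hidden_size, which exceeds the configured m_size; presumably this is why torch_policy.py now reads m_size back from the model rather than trusting the settings value. A quick check of the numbers for an example m_size:

m_size = 256                # example value from the trainer settings
hidden_size = m_size // 2   # what the AMRLMax layer is constructed with
memory_size = hidden_size // 2 + 2 * hidden_size
assert memory_size == 320   # larger than the configured 256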

         if self.use_lstm:
             # Resize to (batch, sequence length, encoding size)
             encoding = encoding.reshape([-1, sequence_length, self.h_size])
-            memories = torch.split(memories, self.m_size // 2, dim=-1)
+            # memories = torch.split(memories, self.m_size // 2, dim=-1)
-            memories = torch.cat(memories, dim=-1)
+            # memories = torch.cat(memories, dim=-1)
         return encoding, memories
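
The commented-out lines were the old calling convention, in which the network body owned the memory layout: split the flat tensor into (h, c) for a plain torch.nn.LSTM, then flatten the returned state again. A hypothetical minimal version of that convention, for contrast with the AMRL-style layer that now consumes the flat tensor directly:

import torch

m_size = 32
lstm = torch.nn.LSTM(16, m_size // 2, batch_first=True)
encoding = torch.randn(4, 5, 16)       # (batch, seq, features)
memories = torch.zeros(1, 4, m_size)   # flat (h | c) state

# Old convention: the body splits and re-packs the state around the LSTM call.
h0, c0 = torch.split(memories, m_size // 2, dim=-1)
encoding, (h, c) = lstm(encoding, (h0.contiguous(), c0.contiguous()))
memories = torch.cat([h, c], dim=-1)

# New convention (sketch): the recurrent layer understands the flat tensor
# itself, so the body simply forwards it, e.g.
# encoding, memories = self.lstm(encoding, memories)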

         self.act_type = act_type
         self.act_size = act_size
         self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
-        self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
+        self.memory_size_param = torch.nn.Parameter(torch.Tensor([0]))
         self.is_continuous_int = torch.nn.Parameter(
             torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
         )

             self.encoding_size, act_size
         )

+    @property
+    def memory_size(self) -> int:
+        if self.network_body.lstm is not None:
+            return self.network_body.lstm.memory_size
+        else:
+            return 0
+
     def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
         self.network_body.update_normalization(vector_obs)

             sampled_actions,
             dists[0].pdf(sampled_actions),
             self.version_number,
-            self.memory_size,
+            self.memory_size_param,
             self.is_continuous_int,
             self.act_size_vector,
         )

         self.critic = ValueNetwork(
             stream_names, observation_shapes, use_network_settings
         )

+    @property
+    def memory_size(self) -> int:
+        if self.network_body.lstm is not None:
+            return 2 * self.network_body.lstm.memory_size
+        else:
+            return 0
+
     def critic_pass(
         self,