浏览代码

[add-fire] Memory class abstraction (#4375)

/develop/add-fire
GitHub 4 年前
当前提交
6a1d993f
共有 7 个文件被更改,包括 132 次插入34 次删除
  1. 12
      ml-agents/mlagents/trainers/policy/torch_policy.py
  2. 19
      ml-agents/mlagents/trainers/tests/torch/test_layers.py
  3. 6
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  4. 4
      ml-agents/mlagents/trainers/torch/components/bc/module.py
  5. 65
      ml-agents/mlagents/trainers/torch/layers.py
  6. 2
      ml-agents/mlagents/trainers/torch/model_serialization.py
  7. 58
      ml-agents/mlagents/trainers/torch/networks.py

12
ml-agents/mlagents/trainers/policy/torch_policy.py


conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
# Save the m_size needed for export
self._export_m_size = self.m_size
# m_size needed for training is determined by network, not trainer settings
self.m_size = self.actor_critic.memory_size
@property
def export_memory_size(self) -> int:
"""
Returns the memory size of the exported ONNX policy. This only includes the memory
of the Actor and not any auxillary networks.
"""
return self._export_m_size
def _split_decision_step(
self, decision_requests: DecisionSteps

19
ml-agents/mlagents/trainers/tests/torch/test_layers.py


linear_layer,
lstm_layer,
Initialization,
LSTM,
)

assert torch.all(
torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
)
def test_lstm_class():
torch.manual_seed(0)
input_size = 12
memory_size = 64
batch_size = 8
seq_len = 16
lstm = LSTM(input_size, memory_size)
assert lstm.memory_size == memory_size
sample_input = torch.ones((batch_size, seq_len, input_size))
sample_memories = torch.ones((1, batch_size, memory_size))
out, mem = lstm(sample_input, sample_memories)
# Hidden size should be half of memory_size
assert out.shape == (batch_size, seq_len, memory_size // 2)
assert mem.shape == (1, batch_size, memory_size)

6
ml-agents/mlagents/trainers/tests/torch/test_networks.py


if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(
(
1,
network_settings.memory.sequence_length,
network_settings.memory.memory_size,
)
(1, network_settings.memory.sequence_length, actor.memory_size)
)
else:
sample_obs = torch.ones((1, obs_size))

4
ml-agents/mlagents/trainers/torch/components/bc/module.py


memories = []
if self.policy.use_recurrent:
memories = torch.zeros(
1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2
)
memories = torch.zeros(1, self.n_sequences, self.policy.m_size)
if self.policy.use_vis_obs:
vis_obs = []

65
ml-agents/mlagents/trainers/torch/layers.py


import torch
import abc
from typing import Tuple
from enum import Enum

forget_bias
)
return lstm
class MemoryModule(torch.nn.Module):
@abc.abstractproperty
def memory_size(self) -> int:
"""
Size of memory that is required at the start of a sequence.
"""
pass
@abc.abstractmethod
def forward(
self, input_tensor: torch.Tensor, memories: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Pass a sequence to the memory module.
:input_tensor: Tensor of shape (batch_size, seq_length, size) that represents the input.
:memories: Tensor of initial memories.
:return: Tuple of output, final memories.
"""
pass
class LSTM(MemoryModule):
"""
Memory module that implements LSTM.
"""
def __init__(
self,
input_size: int,
memory_size: int,
num_layers: int = 1,
forget_bias: float = 1.0,
kernel_init: Initialization = Initialization.XavierGlorotUniform,
bias_init: Initialization = Initialization.Zero,
):
super().__init__()
# We set hidden size to half of memory_size since the initial memory
# will be divided between the hidden state and initial cell state.
self.hidden_size = memory_size // 2
self.lstm = lstm_layer(
input_size,
self.hidden_size,
num_layers,
True,
forget_bias,
kernel_init,
bias_init,
)
@property
def memory_size(self) -> int:
return 2 * self.hidden_size
def forward(
self, input_tensor: torch.Tensor, memories: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
hidden = (h0, c0)
lstm_out, hidden_out = self.lstm(input_tensor, hidden)
output_mem = torch.cat(hidden_out, dim=-1)
return lstm_out, output_mem

2
ml-agents/mlagents/trainers/torch/model_serialization.py


else []
)
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size])
# Need to pass all posslible inputs since currently keyword arguments is not
# supported by torch.nn.export()

58
ml-agents/mlagents/trainers/torch/networks.py


from typing import Callable, List, Dict, Tuple, Optional
import attr
import abc
import torch

from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import lstm_layer
from mlagents.trainers.torch.layers import LSTM
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

)
if self.use_lstm:
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
self.lstm = LSTM(self.h_size, self.m_size)
self.lstm = None
self.lstm = None # type: ignore
def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):

for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
n1.copy_normalization(n2)
@property
def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0
def forward(
self,
vec_inputs: List[torch.Tensor],

if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])
memories = torch.split(memories, self.m_size // 2, dim=-1)
memories = torch.cat(memories, dim=-1)
return encoding, memories

encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
@property
def memory_size(self) -> int:
return self.network_body.memory_size
def forward(
self,
vec_inputs: List[torch.Tensor],

"""
pass
@abc.abstractproperty
def memory_size(self):
"""
Returns the size of the memory (same size used as input and output in the other
methods) used by this Actor.
"""
pass
class SimpleActor(nn.Module, Actor):
def __init__(

self.act_type = act_type
self.act_size = act_size
self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
self.is_continuous_int = torch.nn.Parameter(
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
)

self.encoding_size = network_settings.memory.memory_size // 2
else:
self.encoding_size = network_settings.hidden_units
if self.act_type == ActionType.CONTINUOUS:
self.distribution = GaussianDistribution(
self.encoding_size,

self.encoding_size, act_size
)
@property
def memory_size(self) -> int:
return self.network_body.memory_size
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
self.network_body.update_normalization(vector_obs)

sampled_actions,
log_probs,
self.version_number,
self.memory_size,
torch.Tensor([self.network_body.memory_size]),
self.is_continuous_int,
self.act_size_vector,
)

# Give the Actor only half the memories. Note we previously validate
# that memory_size must be a multiple of 4.
self.use_lstm = network_settings.memory is not None
if network_settings.memory is not None:
self.half_mem_size = network_settings.memory.memory_size // 2
new_memory_settings = attr.evolve(
network_settings.memory, memory_size=self.half_mem_size
)
use_network_settings = attr.evolve(
network_settings, memory=new_memory_settings
)
else:
use_network_settings = network_settings
self.half_mem_size = 0
use_network_settings,
network_settings,
act_type,
act_size,
conditional_sigma,

self.critic = ValueNetwork(
stream_names, observation_shapes, use_network_settings
)
self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)
@property
def memory_size(self) -> int:
return self.network_body.memory_size + self.critic.memory_size
def critic_pass(
self,

actor_mem, critic_mem = None, None
if self.use_lstm:
# Use only the back half of memories for critic
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1)
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
value_outputs, critic_mem_out = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)

) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None

正在加载...
取消
保存