浏览代码

Properly use MemoryModule abstraction

/develop/gru
Ervin Teng 4 年前
当前提交
e9025079
共有 2 个文件被更改,包括 28 次插入13 次删除
  1. 17
      ml-agents/mlagents/trainers/torch/layers.py
  2. 24
      ml-agents/mlagents/trainers/torch/networks.py

17
ml-agents/mlagents/trainers/torch/layers.py


"""
pass
@abc.abstractproperty
def output_size(self) -> int:
"""
Size of output per timestep of this memory module.
"""
pass
@abc.abstractmethod
def forward(
self, input_tensor: torch.Tensor, memories: torch.Tensor

super().__init__()
# We set hidden size to half of memory_size since the initial memory
# will be divided between the hidden state and initial cell state.
self.hidden_size = memory_size // 2
self.hidden_size = memory_size
self.gru = gru_layer(
input_size,
self.hidden_size,

@property
def memory_size(self) -> int:
return self.hidden_size
@property
def output_size(self) -> int:
return self.hidden_size
def forward(

@property
def memory_size(self) -> int:
return 2 * self.hidden_size
@property
def output_size(self) -> int:
return self.hidden_size
def forward(
self, input_tensor: torch.Tensor, memories: torch.Tensor

24
ml-agents/mlagents/trainers/torch/networks.py


):
super().__init__()
self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
self.use_memory = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size

total_enc_size, network_settings.num_layers, self.h_size
)
if self.use_lstm:
self.lstm = GRU(self.h_size, self.m_size)
if self.use_memory:
self.memory = GRU(self.h_size, self.m_size)
self.lstm = None # type: ignore
self.memory = None # type: ignore
def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))

@property
def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0
return self.memory.memory_size if self.use_memory else 0
def forward(
self,

encoded_self = torch.cat([encoded_self, actions], dim=1)
encoding = self.linear_encoder(encoded_self)
if self.use_lstm:
if self.use_memory:
encoding, memories = self.lstm(encoding, memories)
encoding = encoding.reshape([-1, self.lstm.hidden_size])
encoding, memories = self.memory(encoding, memories)
encoding = encoding.reshape([-1, self.memory.output_size])
return encoding, memories

self.network_body = NetworkBody(
observation_specs, network_settings, encoded_act_size=encoded_act_size
)
if network_settings.memory is not None:
encoding_size = network_settings.memory.memory_size // 2
if self.network_body.memory is not None:
encoding_size = self.network_body.memory.output_size
else:
encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

requires_grad=False,
)
self.network_body = NetworkBody(observation_specs, network_settings)
if network_settings.memory is not None:
self.encoding_size = network_settings.memory.memory_size // 2
if self.network_body.memory is not None:
self.encoding_size = self.network_body.memory.output_size
else:
self.encoding_size = network_settings.hidden_units
self.memory_size_vector = torch.nn.Parameter(

正在加载...
取消
保存