Compare commits

...
This merge request has changes that conflict with the target branch.
/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
/ml-agents/mlagents/trainers/agent_processor.py
/ml-agents/mlagents/trainers/policy/policy.py
/ml-agents/mlagents/trainers/torch/layers.py
/ml-agents/mlagents/trainers/torch/networks.py

3 commits

Author      SHA1      Message                                            Date
Ervin Teng  f3a2a81f  Merge branch 'develop-fix-lstms' into develop-gru  4 years ago
Ervin Teng  e9025079  Properly use MemoryModule abstraction              4 years ago
Ervin Teng  7c826fb1  Working GRU                                        4 years ago
5 files changed, with 134 additions and 18 deletions

  1. ml-agents/mlagents/trainers/agent_processor.py (2 changes)
  2. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (10 changes)
  3. ml-agents/mlagents/trainers/policy/policy.py (15 changes)
  4. ml-agents/mlagents/trainers/torch/layers.py (92 changes)
  5. ml-agents/mlagents/trainers/torch/networks.py (33 changes)

ml-agents/mlagents/trainers/agent_processor.py (2 changes)

  if stored_decision_step is not None and stored_take_action_outputs is not None:
      obs = stored_decision_step.obs
      if self.policy.use_recurrent:
-         memory = self.policy.retrieve_memories([global_id])[0, :]
+         memory = self.policy.retrieve_previous_memories([global_id])[0, :]
      else:
          memory = None
      done = terminated  # Since this is an ongoing step

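Why the switch to retrieve_previous_memories: save_memories() stores a step's *output* memory as soon as the action is taken, so by the time agent_processor assembles the experience for step t, memory_dict already holds the step-t output rather than the step-t input. Below is a minimal sketch (toy code, not the ML-Agents implementation) of that one-step-delay bookkeeping:

import numpy as np

m_size = 4
memory_dict = {}           # agent_id -> latest (output) memory
previous_memory_dict = {}  # agent_id -> memory from one step earlier

def save_memories(agent_ids, memory_matrix):
    # Same bookkeeping as the diff: shift current memories to "previous"
    # before overwriting them with this step's outputs.
    for index, agent_id in enumerate(agent_ids):
        if agent_id in memory_dict:
            previous_memory_dict[agent_id] = memory_dict[agent_id]
        memory_dict[agent_id] = memory_matrix[index, :]

save_memories(["agent-0"], np.full((1, m_size), 1.0))  # output of step 0
save_memories(["agent-0"], np.full((1, m_size), 2.0))  # output of step 1

# Building the experience for step 1: its recurrent *input* was the step-0
# output, which only previous_memory_dict still holds.
assert previous_memory_dict["agent-0"][0] == 1.0
assert memory_dict["agent-0"][0] == 2.0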
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (10 changes)

  from mlagents.torch_utils import torch
  import numpy as np
- from mlagents.trainers.buffer import AgentBuffer
+ from mlagents.trainers.buffer import AgentBuffer, BufferKey
  from mlagents.trainers.trajectory import ObsUtil
  from mlagents.trainers.torch.components.bc.module import BCModule
  from mlagents.trainers.torch.components.reward_providers import create_reward_provider

  current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
  next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
- memory = torch.zeros([1, 1, self.policy.m_size])
+ memory = (
+     ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][0])
+     .unsqueeze(0)
+     .unsqueeze(0)
+     if self.policy.use_recurrent
+     else None
+ )
  next_obs = [obs.unsqueeze(0) for obs in next_obs]

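For reference, a shape sketch of the new memory initialization, assuming the buffer stores one flat memory vector of length m_size per step: the first stored vector is unsqueezed twice into the (num_layers=1, batch=1, m_size) shape that torch.nn.GRU (and torch.nn.LSTM) expect for their initial recurrent state, replacing the previous all-zeros tensor:

import torch

m_size = 8
first_stored_memory = torch.arange(m_size, dtype=torch.float32)  # shape (m_size,)
memory = first_stored_memory.unsqueeze(0).unsqueeze(0)           # shape (1, 1, m_size)

gru = torch.nn.GRU(input_size=4, hidden_size=m_size, num_layers=1, batch_first=True)
seq = torch.zeros(1, 5, 4)  # (batch=1, seq_len=5, input_size=4)
out, h_n = gru(seq, memory.contiguous())
print(out.shape, h_n.shape)  # torch.Size([1, 5, 8]) torch.Size([1, 1, 8])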
ml-agents/mlagents/trainers/policy/policy.py (15 changes)

      self.network_settings: NetworkSettings = trainer_settings.network_settings
      self.seed = seed
      self.previous_action_dict: Dict[str, np.ndarray] = {}
+     self.previous_memory_dict: Dict[str, np.ndarray] = {}
      self.memory_dict: Dict[str, np.ndarray] = {}
      self.normalize = trainer_settings.network_settings.normalize
      self.use_recurrent = self.network_settings.memory is not None

      if memory_matrix is None:
          return
+     # Pass old memories into previous_memory_dict
+     for agent_id in agent_ids:
+         if agent_id in self.memory_dict:
+             self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]
      for index, agent_id in enumerate(agent_ids):
          self.memory_dict[agent_id] = memory_matrix[index, :]

          memory_matrix[index, :] = self.memory_dict[agent_id]
      return memory_matrix

+ def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
+     memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
+     for index, agent_id in enumerate(agent_ids):
+         if agent_id in self.previous_memory_dict:
+             memory_matrix[index, :] = self.previous_memory_dict[agent_id]
+     return memory_matrix

+     if agent_id in self.previous_memory_dict:
+         self.previous_memory_dict.pop(agent_id)

  def make_empty_previous_action(self, num_agents: int) -> np.ndarray:
      """

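The retrieve_previous_memories contract mirrors retrieve_memories: agents with no recorded previous memory get a zero row of width m_size, so freshly spawned agents start from an empty recurrent state. A standalone sketch of that default (hypothetical agent ids, not the real Policy class):

import numpy as np
from typing import Dict, List

m_size = 4
previous_memory_dict: Dict[str, np.ndarray] = {
    "agent-0": np.ones(m_size, dtype=np.float32)
}

def retrieve_previous_memories(agent_ids: List[str]) -> np.ndarray:
    # Unseen agents keep their zero-initialized row.
    memory_matrix = np.zeros((len(agent_ids), m_size), dtype=np.float32)
    for index, agent_id in enumerate(agent_ids):
        if agent_id in previous_memory_dict:
            memory_matrix[index, :] = previous_memory_dict[agent_id]
    return memory_matrix

print(retrieve_previous_memories(["agent-0", "agent-1"]))
# [[1. 1. 1. 1.]
#  [0. 0. 0. 0.]]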
ml-agents/mlagents/trainers/torch/layers.py (92 changes)

  return lstm

+ def gru_layer(
+     input_size: int,
+     hidden_size: int,
+     num_layers: int = 1,
+     batch_first: bool = True,
+     forget_bias: float = 1.0,
+     kernel_init: Initialization = Initialization.XavierGlorotUniform,
+     bias_init: Initialization = Initialization.Zero,
+ ) -> torch.nn.Module:
+     """
+     Creates a torch.nn.GRU and initializes its weights and biases. Provides a
+     forget_bias offset on the update gate, analogous to TensorFlow's LSTM forget_bias.
+     """
+     gru = torch.nn.GRU(input_size, hidden_size, num_layers, batch_first=batch_first)
+     # Add forget_bias to the update gate bias
+     for name, param in gru.named_parameters():
+         # Each weight and bias is a concatenation of 3 matrices (reset, update, new)
+         if "weight" in name:
+             for idx in range(3):
+                 block_size = param.shape[0] // 3
+                 _init_methods[kernel_init](
+                     param.data[idx * block_size : (idx + 1) * block_size]
+                 )
+         if "bias" in name:
+             for idx in range(3):
+                 block_size = param.shape[0] // 3
+                 _init_methods[bias_init](
+                     param.data[idx * block_size : (idx + 1) * block_size]
+                 )
+                 if idx == 1:
+                     param.data[idx * block_size : (idx + 1) * block_size].add_(
+                         forget_bias
+                     )
+     return gru

      """
      pass

+ @abc.abstractproperty
+ def output_size(self) -> int:
+     """
+     Size of output per timestep of this memory module.
+     """
+     pass

  return self.seq_layers(input_tensor)

+ class GRU(MemoryModule):
+     """
+     Memory module that implements GRU.
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         memory_size: int,
+         num_layers: int = 1,
+         forget_bias: float = 1.0,
+         kernel_init: Initialization = Initialization.XavierGlorotUniform,
+         bias_init: Initialization = Initialization.Zero,
+     ):
+         super().__init__()
+         # Unlike LSTM, GRU has no cell state, so the entire memory is the
+         # hidden state and hidden_size equals memory_size.
+         self.hidden_size = memory_size
+         self.gru = gru_layer(
+             input_size,
+             self.hidden_size,
+             num_layers,
+             True,
+             forget_bias,
+             kernel_init,
+             bias_init,
+         )
+
+     @property
+     def memory_size(self) -> int:
+         return self.hidden_size
+
+     @property
+     def output_size(self) -> int:
+         return self.hidden_size
+
+     def forward(
+         self, input_tensor: torch.Tensor, memories: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # The memory is the hidden state directly; unlike LSTM, there is no
+         # split into hidden and cell state.
+         h0 = memories.contiguous()
+         gru_out, hidden_out = self.gru(input_tensor, h0)
+         return gru_out, hidden_out

  class LSTM(MemoryModule):
      """
      Memory module that implements LSTM.
      """

      @property
      def memory_size(self) -> int:
          return 2 * self.hidden_size

+     @property
+     def output_size(self) -> int:
+         return self.hidden_size

      def forward(
          self, input_tensor: torch.Tensor, memories: torch.Tensor

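A note on what gru_layer assumes about PyTorch's parameter layout: torch.nn.GRU stores each weight and bias as a concatenation of the reset (r), update (z), and new (n) gate blocks, so each per-gate block is param.shape[0] // 3 rows, and the forget_bias analog lands on block index 1, the update gate. Since PyTorch's GRU computes h_t = (1 - z_t) * n_t + z_t * h_{t-1}, a positive update-gate bias pushes z_t toward 1 early in training, so the unit initially carries its hidden state forward, much like TensorFlow's LSTM forget_bias. A small verification sketch:

import torch

hidden_size = 16
gru = torch.nn.GRU(input_size=8, hidden_size=hidden_size, num_layers=1)
for name, param in gru.named_parameters():
    # Weights and biases stack the r, z, and n gate blocks along dim 0.
    assert param.shape[0] == 3 * hidden_size
    if "bias" in name:
        block_size = param.shape[0] // 3
        with torch.no_grad():
            # Offset only the update-gate block (index 1).
            param.data[1 * block_size : 2 * block_size].add_(1.0)

Note also that GRU.memory_size equals hidden_size while LSTM.memory_size is 2 * hidden_size (hidden plus cell state), which is why the network code below stops assuming the encoding width is m_size // 2 and instead asks the memory module for its output_size.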
ml-agents/mlagents/trainers/torch/networks.py (33 changes)

  from mlagents.trainers.settings import NetworkSettings
  from mlagents.trainers.torch.utils import ModelUtils
  from mlagents.trainers.torch.decoders import ValueHeads
- from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
+ from mlagents.trainers.torch.layers import GRU, LinearEncoder, Initialization
  from mlagents.trainers.torch.encoders import VectorInput
  from mlagents.trainers.buffer import AgentBuffer
  from mlagents.trainers.trajectory import ObsUtil

  ):
      super().__init__()
      self.normalize = network_settings.normalize
-     self.use_lstm = network_settings.memory is not None
+     self.use_memory = network_settings.memory is not None
      self.h_size = network_settings.hidden_units
      self.m_size = (
          network_settings.memory.memory_size

          total_enc_size, network_settings.num_layers, self.h_size
      )
-     if self.use_lstm:
-         self.lstm = LSTM(self.h_size, self.m_size)
+     if self.use_memory:
+         self.memory = GRU(self.h_size, self.m_size)
      else:
-         self.lstm = None  # type: ignore
+         self.memory = None  # type: ignore

  def update_normalization(self, buffer: AgentBuffer) -> None:
      obs = ObsUtil.from_buffer(buffer, len(self.processors))

  @property
  def memory_size(self) -> int:
-     return self.lstm.memory_size if self.use_lstm else 0
+     return self.memory.memory_size if self.use_memory else 0

  def forward(
      self,

  encoded_self = torch.cat([encoded_self, actions], dim=1)
  encoding = self.linear_encoder(encoded_self)
- if self.use_lstm:
-     encoding, memories = self.lstm(encoding, memories)
-     encoding = encoding.reshape([-1, self.m_size // 2])
+ if self.use_memory:
+     encoding, memories = self.memory(encoding, memories)
+     encoding = encoding.reshape([-1, self.memory.output_size])
  return encoding, memories

  self.network_body = NetworkBody(
      observation_specs, network_settings, encoded_act_size=encoded_act_size
  )
- if network_settings.memory is not None:
-     encoding_size = network_settings.memory.memory_size // 2
+ if self.network_body.memory is not None:
+     encoding_size = self.network_body.memory.output_size
  else:
      encoding_size = network_settings.hidden_units
  self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

      requires_grad=False,
  )
  self.network_body = NetworkBody(observation_specs, network_settings)
- if network_settings.memory is not None:
-     self.encoding_size = network_settings.memory.memory_size // 2
+ if self.network_body.memory is not None:
+     self.encoding_size = self.network_body.memory.output_size
  else:
      self.encoding_size = network_settings.hidden_units
  self.memory_size_vector = torch.nn.Parameter(

      inputs, masks=masks, memories=actor_mem, sequence_length=sequence_length
  )
  if critic_mem is not None:
-     # Make memories with the actor mem unchanged
-     memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1)
+     # Get value memories with the actor mem unchanged
+     _, critic_mem_outs = self.critic(
+         inputs, memories=critic_mem, sequence_length=sequence_length
+     )
+     memories_out = torch.cat([actor_mem_out, critic_mem_outs], dim=-1)
  else:
      memories_out = None
  return action, log_probs, entropies, memories_out

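A shape sketch of the combined recurrent state in the last hunk, assuming the actor and critic each keep a memory of size m: the externally visible memory is their concatenation on the last axis, and the fix advances the critic half through the critic network instead of passing the input critic_mem through unchanged (the tensor names below are illustrative):

import torch

m = 8
actor_mem_out = torch.randn(1, 1, m)    # actor memory after this step
critic_mem_outs = torch.randn(1, 1, m)  # critic memory after this step
memories_out = torch.cat([actor_mem_out, critic_mem_outs], dim=-1)
print(memories_out.shape)  # torch.Size([1, 1, 16])
# Before the fix, the second half was the unchanged input critic_mem, so the
# critic's recurrent state never advanced between steps.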