
[add-fire] Add LSTM to SAC, LSTM fixes and initializations (#4324)

/develop/add-fire
GitHub, 4 years ago
Current commit: f374f87a
10 files changed, 263 insertions and 108 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (53 lines changed)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (2 lines changed)
  3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (44 lines changed)
  4. ml-agents/mlagents/trainers/sac/optimizer_torch.py (125 lines changed)
  5. ml-agents/mlagents/trainers/tests/torch/test_layers.py (22 lines changed)
  6. ml-agents/mlagents/trainers/tests/torch/test_networks.py (17 lines changed)
  7. ml-agents/mlagents/trainers/tests/torch/test_utils.py (16 lines changed)
  8. ml-agents/mlagents/trainers/torch/layers.py (36 lines changed)
  9. ml-agents/mlagents/trainers/torch/networks.py (46 lines changed)
  10. ml-agents/mlagents/trainers/torch/utils.py (10 lines changed)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (53 lines changed)

from typing import Dict, Optional, Tuple, List
import torch
import numpy as np
from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.utils import ModelUtils

reward_signal, self.policy.behavior_spec, settings
)
def get_value_estimates(
self, decision_requests: DecisionSteps, idx: int, done: bool
) -> Dict[str, float]:
"""
Generates value estimates for bootstrapping.
:param decision_requests:
:param idx: Index in BrainInfo of agent.
:param done: Whether or not this is the last element of the episode,
in which case the value estimate will be 0.
:return: The value estimate dictionary with key being the name of the reward signal
and the value the corresponding value estimate.
"""
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
value_estimates = self.policy.actor_critic.critic_pass(
np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
)
value_estimates = {k: float(v) for k, v in value_estimates.items()}
# If we're done, reassign all of the value estimates that need terminal states.
if done:
for k in value_estimates:
if not self.reward_signals[k].ignore_done:
value_estimates[k] = 0.0
return value_estimates
def get_trajectory_value_estimates(
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:

else:
visual_obs = []
memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size])
memory = torch.zeros([1, 1, self.policy.m_size])
next_obs = np.concatenate(next_obs, axis=-1)
next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
next_memory = torch.zeros([1, 1, self.policy.m_size])
vec_vis_obs = SplitObservations.from_observations(next_obs)
next_vec_obs = [
ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
]
next_vis_obs = [
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
for _vis_ob in vec_vis_obs.visual_observations
]
value_estimates = self.policy.actor_critic.critic_pass(
vector_obs, visual_obs, memory
value_estimates, next_memory = self.policy.actor_critic.critic_pass(
vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
next_value_estimate = self.policy.actor_critic.critic_pass(
next_obs, next_obs, next_memory
next_value_estimate, _ = self.policy.actor_critic.critic_pass(
next_vec_obs, next_vis_obs, next_memory, sequence_length=1
)
for name, estimate in value_estimates.items():
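As a rough, stand-alone illustration of what the new critic_pass plumbing does (plain PyTorch, not the ML-Agents API; every name and size below is made up): a whole stored trajectory is pushed through an LSTM critic starting from zero-initialized memories, and the final (h, c) pair is kept as the memory used for the bootstrap value of the next observation.

import torch

# Illustrative sizes only.
obs_size, hidden_size, seq_len = 10, 64, 128   # hidden_size plays the role of m_size // 2
lstm = torch.nn.LSTM(obs_size, hidden_size, batch_first=True)
value_head = torch.nn.Linear(hidden_size, 1)

obs = torch.randn(1, seq_len, obs_size)        # one trajectory, batch_first layout
h0 = torch.zeros(1, 1, hidden_size)            # zeroed initial memories, as above
c0 = torch.zeros(1, 1, hidden_size)

encoding, (h_n, c_n) = lstm(obs, (h0, c0))
values = value_head(encoding).squeeze(-1)      # one value estimate per step
next_memory = torch.cat([h_n, c_n], dim=-1)    # packed like the m_size memory tensor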

ml-agents/mlagents/trainers/policy/torch_policy.py (2 lines changed)

run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
run_out["learning_rate"] = 0.0
if self.use_recurrent:
run_out["memories"] = memories.detach().cpu().numpy()
run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0)
return run_out
def get_action(

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (44 lines changed)

old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
Creates training-specific Tensorflow ops for PPO models.
:param returns:
:param old_values:
:param values:
Evaluates value loss for PPO.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():

)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = torch.mean(torch.max(v_opt_a, v_opt_b))
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks):
def ppo_policy_loss(
self,
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
) -> torch.Tensor:
Creates training-specific Tensorflow ops for PPO models.
:param masks:
:param advantages:
Evaluate PPO policy loss.
:param advantages: Computed advantages.
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)

p_opt_b = (
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
policy_loss = -torch.mean(torch.min(p_opt_a, p_opt_b))
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
)
return policy_loss
@timed

memories=memories,
seq_len=self.policy.sequence_length,
)
value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps)
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
value_loss = self.ppo_value_loss(
values, old_values, returns, decay_eps, loss_masks
)
ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32),
loss_masks,
)
loss = (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks)
loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy)
# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
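Both PPO losses are now averaged only over unmasked experiences. A self-contained sketch of the masked, clipped value loss with toy numbers (masked_mean is inlined here so the snippet runs on its own; the values are invented):

import torch

def masked_mean(tensor, masks):
    return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)

returns = torch.tensor([1.0, 2.0, 3.0, 0.0])       # last step is LSTM padding
values = torch.tensor([0.5, 2.5, 2.0, 9.0])
old_values = torch.tensor([0.6, 2.4, 2.1, 9.0])
loss_masks = torch.tensor([True, True, True, False])
epsilon = 0.2

# Clipped value estimate, then the max of clipped/unclipped squared error,
# averaged only over unmasked steps.
clipped = old_values + torch.clamp(values - old_values, -epsilon, epsilon)
v_opt_a = (returns - values) ** 2
v_opt_b = (returns - clipped) ** 2
value_loss = masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)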

ml-agents/mlagents/trainers/sac/optimizer_torch.py (125 lines changed)

import numpy as np
from typing import Dict, List, Mapping, cast, Tuple
from typing import Dict, List, Mapping, cast, Tuple, Optional
import attr
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType

self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
actions: torch.Tensor = None,
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions)
q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions)
q1_out, _ = self.q1_network(
vec_inputs,
vis_inputs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
q2_out, _ = self.q2_network(
vec_inputs,
vis_inputs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
return q1_out, q2_out
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):

for name in self.stream_names
}
# Critics should have 1/2 of the memory of the policy
critic_memory = policy_network_settings.memory
if critic_memory is not None:
critic_memory = attr.evolve(
critic_memory, memory_size=critic_memory.memory_size // 2
)
value_network_settings = attr.evolve(
policy_network_settings, memory=critic_memory
)
policy_network_settings,
value_network_settings,
policy_network_settings,
value_network_settings,
)
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
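The critic only carries the value half of the memory, so its network settings are rebuilt with attr.evolve and half the memory_size. A small sketch of the same pattern using a stand-in settings class (the real NetworkSettings/MemorySettings live in mlagents.trainers.settings):

import attr

@attr.s(auto_attribs=True)
class MemorySettings:          # stand-in for NetworkSettings.MemorySettings
    memory_size: int = 128
    sequence_length: int = 64

policy_memory = MemorySettings(memory_size=256)
# Same object, but with half the memory; other fields are untouched.
critic_memory = attr.evolve(policy_memory, memory_size=policy_memory.memory_size // 2)
assert critic_memory.memory_size == 128 and critic_memory.sequence_length == 64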

* self.gammas[i]
* target_values[name]
)
_q1_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
_q1_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
_q2_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
_q2_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
)
q1_losses.append(_q1_loss)

v_backup = min_policy_qs[name] - torch.sum(
_ent_coef * log_probs, dim=1
)
# print(log_probs, v_backup, _ent_coef, loss_masks)
value_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
)
value_losses.append(value_loss)
else:

v_backup = min_policy_qs[name] - torch.mean(
branched_ent_bonus, axis=0
)
value_loss = 0.5 * torch.mean(
loss_masks
* torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
loss_masks,
)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))

if not discrete:
mean_q1 = mean_q1.unsqueeze(1)
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
policy_loss = torch.mean(loss_masks * batch_policy_loss)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
else:
action_probs = log_probs.exp()
branched_per_action_ent = ModelUtils.break_into_branches(

target_current_diff = torch.squeeze(
target_current_diff_branched, axis=2
)
entropy_loss = -torch.mean(
loss_masks
* torch.mean(self._log_ent_coef * target_current_diff, axis=1)
entropy_loss = -1 * ModelUtils.masked_mean(
torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
)
return entropy_loss

else:
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
memories = [
memories_list = [
if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
offset = 1 if self.policy.sequence_length > 1 else 0
next_memories_list = [
ModelUtils.list_to_tensor(
batch["memory"][i][self.policy.m_size // 2 :]
) # only pass value part of memory to target network
for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
]
if len(memories_list) > 0:
memories = torch.stack(memories_list).unsqueeze(0)
next_memories = torch.stack(next_memories_list).unsqueeze(0)
else:
memories = None
next_memories = None
# Q network memories are 0'ed out, since we don't have them during inference.
q_memories = torch.zeros_like(next_memories)
vis_obs: List[torch.Tensor] = []
next_vis_obs: List[torch.Tensor] = []
if self.policy.use_vis_obs:

)
if self.policy.use_continuous_act:
squeezed_actions = actions.squeeze(-1)
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions)
q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
sampled_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
squeezed_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs)
q1_out, q2_out = self.value_network(vec_obs, vis_obs)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
target_values, _ = self.target_network(next_vec_obs, next_vis_obs)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
target_values, _ = self.target_network(
next_vec_obs,
next_vis_obs,
memories=next_memories,
sequence_length=self.policy.sequence_length,
)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
use_discrete = not self.policy.use_continuous_act
dones = ModelUtils.list_to_tensor(batch["done"])
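Memories are stored once per experience in the replay buffer, so the update has to recover one initial memory per sequence, hand the target network only the value half of the next step's memory, and zero out the Q-network memories. A sketch with hypothetical shapes:

import torch

m_size = 8
sequence_length = 4
batch_memory = [torch.randn(m_size) for _ in range(12)]   # 3 stored sequences

# First memory of every sequence seeds the policy/value networks ...
memories = torch.stack(
    [batch_memory[i] for i in range(0, len(batch_memory), sequence_length)]
).unsqueeze(0)

# ... while the target network sees only the value half of the *next*
# step's memory (offset by 1 when sequences are longer than 1).
offset = 1 if sequence_length > 1 else 0
next_memories = torch.stack(
    [batch_memory[i][m_size // 2:]
     for i in range(offset, len(batch_memory), sequence_length)]
).unsqueeze(0)

# Q networks never see a memory at inference time, so they get zeros.
q_memories = torch.zeros_like(next_memories)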

ml-agents/mlagents/trainers/tests/torch/test_layers.py (22 lines changed)

import torch
from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization
from mlagents.trainers.torch.layers import (
Swish,
linear_layer,
lstm_layer,
Initialization,
)
def test_swish():

)
assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data)))
assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data)))
def test_lstm_layer():
torch.manual_seed(0)
# Test zero for LSTM
layer = lstm_layer(
4, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
)
for name, param in layer.named_parameters():
if "weight" in name:
assert torch.all(torch.eq(param.data, torch.zeros_like(param.data)))
elif "bias" in name:
assert torch.all(
torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
)

ml-agents/mlagents/trainers/tests/torch/test_networks.py (17 lines changed)

obs_size = 4
seq_len = 16
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
for _ in range(100):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4))
for _ in range(200):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12))
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()

# memories isn't always set to None, the network should be able to
# deal with that.
# Test critic pass
value_out = actor.critic_pass([sample_obs], [], memories=memories)
value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories)
assert memories_out.shape == memories.shape
dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
dists, value_out, mem_out = actor.get_dist_and_value(
[sample_obs], [], memories=memories
)
if mem_out is not None:
assert mem_out.shape == memories.shape
for dist in dists:
assert isinstance(dist, GaussianDistInstance)
for stream in stream_names:
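The network test follows a common pattern for checking that gradients flow through the recurrent state: repeatedly fit the LSTM body's output to a constant target. A stripped-down version of that pattern with a bare torch.nn.LSTM (sizes only loosely mirror the test):

import torch

seq_len, obs_size, hidden = 16, 4, 6           # hidden plays the role of memory_size // 2
lstm = torch.nn.LSTM(obs_size, hidden, batch_first=True)
optimizer = torch.optim.Adam(lstm.parameters(), lr=3e-4)
sample_obs = torch.ones(1, seq_len, obs_size)

for _ in range(200):
    encoded, _ = lstm(sample_obs)              # default zero (h, c) state
    loss = torch.nn.functional.mse_loss(encoded, torch.ones_like(encoded))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# After training, encoded should be close to all ones.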

ml-agents/mlagents/trainers/tests/torch/test_utils.py (16 lines changed)

assert entropies.shape == (1, len(dist_list))
# Make sure the first action has higher probability than the others.
assert log_probs.flatten()[0] > log_probs.flatten()[1]
def test_masked_mean():
test_input = torch.tensor([1, 2, 3, 4, 5])
masks = torch.ones_like(test_input).bool()
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 3.0
masks = torch.tensor([False, False, True, True, True])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 4.0
# Make sure it works if all masks are off
masks = torch.tensor([False, False, False, False, False])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 0.0

ml-agents/mlagents/trainers/torch/layers.py (36 lines changed)

layer.weight.data *= kernel_gain
_init_methods[bias_init](layer.bias.data)
return layer
def lstm_layer(
input_size: int,
hidden_size: int,
num_layers: int = 1,
batch_first: bool = True,
forget_bias: float = 1.0,
kernel_init: Initialization = Initialization.XavierGlorotUniform,
bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
"""
Creates a torch.nn.LSTM and initializes its weights and biases. Provides a
forget_bias offset, as is done in TensorFlow.
"""
lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
# Add forget_bias to forget gate bias
for name, param in lstm.named_parameters():
# Each weight and bias is a concatenation of 4 matrices
if "weight" in name:
for idx in range(4):
block_size = param.shape[0] // 4
_init_methods[kernel_init](
param.data[idx * block_size : (idx + 1) * block_size]
)
if "bias" in name:
for idx in range(4):
block_size = param.shape[0] // 4
_init_methods[bias_init](
param.data[idx * block_size : (idx + 1) * block_size]
)
if idx == 1:
param.data[idx * block_size : (idx + 1) * block_size].add_(
forget_bias
)
return lstm
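For reference, the same forget-bias offset can be applied to a vanilla torch.nn.LSTM. PyTorch packs each bias as four gate blocks in the order (input, forget, cell, output), so the second block of size hidden_size is the forget gate. This helper is illustrative only, not part of the change:

import torch

def add_forget_bias(lstm: torch.nn.LSTM, forget_bias: float = 1.0) -> None:
    # bias_ih_l* and bias_hh_l* each hold 4 * hidden_size entries,
    # ordered (input, forget, cell, output).
    for name, param in lstm.named_parameters():
        if "bias" in name:
            hidden_size = param.shape[0] // 4
            param.data[hidden_size : 2 * hidden_size].add_(forget_bias)

lstm = torch.nn.LSTM(8, 8, batch_first=True)
add_forget_bias(lstm)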

ml-agents/mlagents/trainers/torch/networks.py (46 lines changed)

from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import lstm_layer
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

)
if self.use_lstm:
self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1)
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
else:
self.lstm = None

raise Exception("No valid inputs to network.")
if self.use_lstm:
encoding = encoding.view([sequence_length, -1, self.h_size])
# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])
encoding, memories = self.lstm(
encoding.contiguous(),
(memories[0].contiguous(), memories[1].contiguous()),
)
encoding = encoding.view([-1, self.m_size // 2])
encoding, memories = self.lstm(encoding, memories)
encoding = encoding.reshape([-1, self.m_size // 2])
memories = torch.cat(memories, dim=-1)
return encoding, memories

vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Get value outputs for the given obs.
:param vec_inputs: List of vector inputs as tensors.

vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
encoding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
return self.value_heads(encoding)
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
encoding, memories_out = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
return self.value_heads(encoding), memories_out
def get_dist_and_value(
self,

vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
actor_mem, critic_mem = None, None
_, critic_mem = torch.split(memories, self.half_mem_size, -1)
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1)
value_outputs, critic_mem_out = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
if actor_mem is not None:
# Make memories with the actor mem unchanged
memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
critic_mem = None
value_outputs, _memories = self.critic(
vec_inputs, vis_inputs, memories=critic_mem
)
return value_outputs
memories_out = None
return value_outputs, memories_out
def get_dist_and_value(
self,

vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
if self.use_lstm:
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1)
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
else:
mem_out = None
return dists, value_outputs, mem_out
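With a separate actor and critic, the combined memory tensor is laid out as [actor memory | critic memory], and each half is itself one LSTM's (h, c) pair concatenated along the last dimension. A sketch of the split/re-pack steps used by critic_pass and get_dist_and_value (sizes are assumed):

import torch

m_size = 16                      # total memory per agent (assumed)
half_mem_size = m_size // 2
memories = torch.randn(1, 1, m_size)

# Split the combined memory into the actor and critic halves.
actor_mem, critic_mem = torch.split(memories, half_mem_size, dim=-1)

# Each half is itself (h, c) for one LSTM with hidden size m_size // 4.
h, c = torch.split(critic_mem, half_mem_size // 2, dim=-1)

# Re-pack the critic memory and re-join it with the untouched actor half
# (in the real code the re-packed half would be the LSTM's updated state).
critic_mem_out = torch.cat([h, c], dim=-1)
memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)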

ml-agents/mlagents/trainers/torch/utils.py (10 lines changed)

else:
all_probs = torch.cat(all_probs_list, dim=-1)
return log_probs, entropies, all_probs
@staticmethod
def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""
Returns the mean of the tensor but ignoring the values specified by masks.
Used for masking out loss functions.
:param tensor: Tensor which needs mean computation.
:param masks: Boolean tensor of masks with same dimension as tensor.
"""
return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)
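One way to read why masked_mean replaces expressions like torch.mean(loss_masks * loss) throughout this change: multiplying by the mask and taking a plain mean still divides by the full batch size, so padded steps drag the loss toward zero, while masked_mean divides by the number of unmasked entries. A toy comparison:

import torch

loss = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])   # last 3 steps are padding
masks = torch.tensor([True, True, True, False, False, False])

biased = torch.mean(masks * loss)                      # 0.5, shrunk by the padding
unbiased = (loss * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)  # 1.0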