
Pytorch port of SAC (#4219)

/develop/add-fire
GitHub · 4 years ago
Current commit 45154f52
8 changed files with 871 additions and 85 deletions
  1. experiment_torch.py (43 changed lines)
  2. ml-agents/mlagents/trainers/distributions_torch.py (50 changed lines)
  3. ml-agents/mlagents/trainers/models_torch.py (230 changed lines)
  4. ml-agents/mlagents/trainers/policy/torch_policy.py (59 changed lines)
  5. ml-agents/mlagents/trainers/ppo/trainer.py (14 changed lines)
  6. ml-agents/mlagents/trainers/sac/trainer.py (63 changed lines)
  7. ml-agents/mlagents/trainers/trainer/rl_trainer.py (16 changed lines)
  8. ml-agents/mlagents/trainers/sac/optimizer_torch.py (481 changed lines)

experiment_torch.py (43 changed lines)


name: str,
steps: int,
use_torch: bool,
algo: str,
num_torch_threads: int,
use_gpu: bool,
num_envs: int = 1,

name,
str(steps),
str(use_torch),
algo,
str(num_torch_threads),
str(num_envs),
str(use_gpu),

if config_name is None:
config_name = name
run_options = parse_command_line(
[f"config/ppo/{config_name}.yaml", "--num-envs", f"{num_envs}"]
[f"config/{algo}/{config_name}.yaml", "--num-envs", f"{num_envs}"]
)
run_options.checkpoint_settings.run_id = (
f"{name}_test_" + str(steps) + "_" + ("torch" if use_torch else "tf")

tc_advance_total = tc_advance["total"]
tc_advance_count = tc_advance["count"]
if use_torch:
    if algo == "ppo":
        update_total = update["TorchPPOOptimizer.update"]["total"]
        update_count = update["TorchPPOOptimizer.update"]["count"]
    else:
        update_total = update["SACTrainer._update_policy"]["total"]
        update_count = update["SACTrainer._update_policy"]["count"]
else:
    if algo == "ppo":
        update_total = update["TFPPOOptimizer.update"]["total"]
        update_count = update["TFPPOOptimizer.update"]["count"]
    else:
        update_total = update["SACTrainer._update_policy"]["total"]
        update_count = update["SACTrainer._update_policy"]["count"]
evaluate_count = evaluate["NNPolicy.evaluate"]["count"]
# todo: do total / count
return (

algo,
str(num_torch_threads),
str(num_envs),
str(use_gpu),

action="store_true",
help="If true, will only do 3dball",
)
parser.add_argument(
"--sac",
default=False,
action="store_true",
help="If true, will run sac instead of ppo",
)
args = parser.parse_args()
if args.gpu:

algo = "ppo"
if args.sac:
algo = "sac"
("Hallway", "Hallway"),
("VisualHallway", "VisualHallway"),
if algo == "ppo":
envs_config_tuples += [("Hallway", "Hallway"),
("VisualHallway", "VisualHallway")]
"algorithm",
"num_torch_threads",
"num_envs",
"use_gpu",

results = []
results.append(labels)
f = open(
f"result_data_steps_{args.steps}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
f"result_data_steps_{args.steps}_algo_{algo}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
"w",
)
f.write(" ".join(labels) + "\n")

name=env_config[0],
steps=args.steps,
use_torch=True,
algo=algo,
num_torch_threads=1,
use_gpu=args.gpu,
num_envs=args.num_envs,

name=env_config[0],
steps=args.steps,
use_torch=True,
algo=algo,
num_torch_threads=8,
use_gpu=args.gpu,
num_envs=args.num_envs,

name=env_config[0],
steps=args.steps,
use_torch=False,
algo=algo,
num_torch_threads=1,
use_gpu=args.gpu,
num_envs=args.num_envs,
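Each environment in the sweep is run three ways (Torch with 1 thread, Torch with 8 threads, and TensorFlow), and the timings are appended to the result_data_*.txt file named above. A hypothetical invocation of the benchmark, assuming the argparse flags mirror the attribute names used here (only --sac appears verbatim in this diff; --steps, --num-envs, --gpu and --threads are inferred):

    # time TF vs. Torch PPO on the default environment set
    python experiment_torch.py --steps 25000 --num-envs 1

    # same sweep, but benchmark the new Torch SAC implementation instead
    python experiment_torch.py --steps 25000 --num-envs 1 --sac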

ml-agents/mlagents/trainers/distributions_torch.py (50 changed lines)


self.std = std
def sample(self):
sample = self.mean + torch.randn_like(self.mean) * self.std
return sample
log_scale = torch.log(self.std + EPSILON)
-((value - self.mean) ** 2) / (2 * var + EPSILON)
- log_scale
- math.log(math.sqrt(2 * math.pi))
)

return torch.exp(log_prob)
def entropy(self):
return torch.log(2 * math.pi * math.e * self.std + EPSILON)
class TanhGaussianDistInstance(GaussianDistInstance):
def __init__(self, mean, std):
super().__init__(mean, std)
self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)
def sample(self):
unsquashed_sample = super().sample()
squashed = self.transform(unsquashed_sample)
return squashed
def _inverse_tanh(self, value):
capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)
def log_prob(self, value):
unsquashed = self.transform.inv(value)
return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
unsquashed, value
)
class CategoricalDistInstance(nn.Module):

def log_prob(self, value):
return torch.log(self.pdf(value))
def all_log_prob(self):
return torch.log(self.probs)
def __init__(
    self,
    hidden_size,
    num_outputs,
    conditional_sigma=False,
    tanh_squash=False,
    **kwargs
):
self.tanh_squash = tanh_squash
nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
if conditional_sigma:
self.log_sigma = nn.Linear(hidden_size, num_outputs)

def forward(self, inputs):
mu = self.mu(inputs)
if self.conditional_sigma:
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
if self.tanh_squash:
    return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
else:
    return [GaussianDistInstance(mu, torch.exp(log_sigma))]
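The tanh-squashed distribution above draws an unsquashed Gaussian sample u, returns a = tanh(u), and corrects log_prob by the log-determinant of the tanh Jacobian. A standalone sketch of the same correction (illustrative only, not the ml-agents class; the 1e-6 term stands in for EPSILON):

    import torch

    def tanh_gaussian_log_prob(mean, std, unsquashed):
        # Per dimension: log pi(a_i) = log N(u_i | mean, std) - log(1 - tanh(u_i)^2);
        # TanhTransform.log_abs_det_jacobian computes the same term more stably.
        gaussian = torch.distributions.Normal(mean, std)
        log_prob_u = gaussian.log_prob(unsquashed)
        correction = torch.log(1.0 - torch.tanh(unsquashed) ** 2 + 1e-6)
        return log_prob_u - correction

    mean, std = torch.zeros(3), torch.ones(3)
    u = mean + torch.randn_like(mean) * std   # pre-squash sample, as in sample()
    a = torch.tanh(u)                         # squashed action in (-1, 1)
    print(tanh_gaussian_log_prob(mean, std, u))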
class MultiCategoricalDistribution(nn.Module):

ml-agents/mlagents/trainers/models_torch.py (230 changed lines)


from enum import Enum
from typing import Callable, NamedTuple, List, Optional, Dict, Tuple
import torch
from torch import nn

)
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.models import EncoderType
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.brain import CameraResolution
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

running_variance: torch.Tensor
def break_into_branches(
concatenated_logits: torch.Tensor, action_size: List[int]
) -> List[torch.Tensor]:
"""
Takes a concatenated set of logits that represent multiple discrete action branches
and breaks it up into one Tensor per branch.
:param concatenated_logits: Tensor that represents the concatenated action branches
:param action_size: List of ints containing the number of possible actions for each branch.
:return: A List of Tensors containing one tensor per branch.
"""
action_idx = [0] + list(np.cumsum(action_size))
branched_logits = [
concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
for i in range(len(action_size))
]
return branched_logits
def actions_to_onehot(
discrete_actions: torch.Tensor, action_size: List[int]
) -> List[torch.Tensor]:
onehot_branches = [
torch.nn.functional.one_hot(_act.T, action_size[i])
for i, _act in enumerate(discrete_actions.T)
]
return onehot_branches
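A usage sketch of the two helpers with made-up sizes (two discrete branches of 3 and 2 actions, batch of 2):

    import torch

    logits = torch.arange(10.0).reshape(2, 5)       # concatenated logits, 3 + 2 columns
    branches = break_into_branches(logits, [3, 2])  # [tensor of shape (2, 3), tensor of shape (2, 2)]

    actions = torch.tensor([[0, 1], [2, 0]])        # one chosen action per branch, per batch entry
    onehots = actions_to_onehot(actions, [3, 2])    # [(2, 3) one-hot, (2, 2) one-hot]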
class NetworkBody(nn.Module):
def __init__(
self,

h_size,
)
)
self.vector_normalizers = nn.ModuleList(self.vector_normalizers)
self.vector_encoders = nn.ModuleList(self.vector_encoders)
self.visual_encoders = nn.ModuleList(self.visual_encoders)
if use_lstm:

for idx, vec_input in enumerate(vec_inputs):
self.vector_normalizers[idx].update(vec_input)
def copy_normalization(self, other_network: "NetworkBody") -> None:
if self.normalize:
for n1, n2 in zip(
self.vector_normalizers, other_network.vector_normalizers
):
n1.copy_from(n2)
def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1):
vec_embeds = []
for idx, encoder in enumerate(self.vector_encoders):

vis_embeds.append(hidden)
# embedding = vec_embeds[0]
if len(vec_embeds) > 0 and len(vis_embeds) > 0:
    vec_embeds_tensor = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
    vis_embeds_tensor = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
    embedding = torch.stack([vec_embeds_tensor, vis_embeds_tensor], dim=-1).sum(
        dim=-1
    )
elif len(vec_embeds) > 0:
    embedding = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
elif len(vis_embeds) > 0:
    embedding = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
else:
    raise Exception("No valid inputs to network.")

embedding, memories = self.lstm(
    embedding.contiguous(),
    (memories[0].contiguous(), memories[1].contiguous()),
)
class QNetwork(NetworkBody):
def __init__( # pylint: disable=W0231
self,
stream_names: List[str],
vector_sizes: List[int],
visual_sizes: List[CameraResolution],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
):
# This is not a typo, we want to call __init__ of nn.Module
nn.Module.__init__(self)
self.normalize = network_settings.normalize
self.visual_encoders = []
self.vector_encoders = []
self.vector_normalizers = []
self.use_lstm = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None
else 0
)
visual_encoder = ModelUtils.get_encoder_for_type(
network_settings.vis_encode_type
)
for vector_size in vector_sizes:
if vector_size != 0:
self.vector_normalizers.append(Normalizer(vector_size))
input_size = (
vector_size + sum(act_size)
if not act_type == ActionType.DISCRETE
else vector_size
)
self.vector_encoders.append(
VectorEncoder(input_size, self.h_size, network_settings.num_layers)
)
for visual_size in visual_sizes:
self.visual_encoders.append(
visual_encoder(
visual_size.height,
visual_size.width,
visual_size.num_channels,
self.h_size,
)
)
self.vector_normalizers = nn.ModuleList(self.vector_normalizers)
self.vector_encoders = nn.ModuleList(self.vector_encoders)
self.visual_encoders = nn.ModuleList(self.visual_encoders)
if self.use_lstm:
self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1)
else:
self.lstm = None
if act_type == ActionType.DISCRETE:
self.q_heads = ValueHeads(
stream_names, network_settings.hidden_units, sum(act_size)
)
else:
self.q_heads = ValueHeads(stream_names, network_settings.hidden_units)
def forward( # pylint: disable=W0221
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
memories: torch.Tensor = None,
sequence_length: int = 1,
actions: torch.Tensor = None,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
vec_embeds = []
for i, (enc, norm) in enumerate(
zip(self.vector_encoders, self.vector_normalizers)
):
vec_input = vec_inputs[i]
if self.normalize:
vec_input = norm(vec_input)
if actions is not None:
hidden = enc(torch.cat([vec_input, actions], dim=-1))
else:
hidden = enc(vec_input)
vec_embeds.append(hidden)
vis_embeds = []
for idx, encoder in enumerate(self.visual_encoders):
vis_input = vis_inputs[idx]
vis_input = vis_input.permute([0, 3, 1, 2])
hidden = encoder(vis_input)
vis_embeds.append(hidden)
# embedding = vec_embeds[0]
if len(vec_embeds) > 0 and len(vis_embeds) > 0:
vec_embeds_tensor = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
vis_embeds_tensor = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
embedding = torch.stack([vec_embeds_tensor, vis_embeds_tensor], dim=-1).sum(
dim=-1
)
elif len(vec_embeds) > 0:
embedding = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
elif len(vis_embeds) > 0:
embedding = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
else:
raise Exception("No valid inputs to network.")
if self.lstm is not None:
embedding = embedding.view([sequence_length, -1, self.h_size])
memories_tensor = torch.split(memories, self.m_size // 2, dim=-1)
embedding, memories = self.lstm(embedding, memories_tensor)
embedding = embedding.view([-1, self.m_size // 2])
memories = torch.cat(memories_tensor, dim=-1)
output, _ = self.q_heads(embedding)
return output, memories
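A toy sketch (plain nn.Linear layers, not the real QNetwork/ValueHeads classes) of how the heads are sized by the two branches of the constructor above: continuous control concatenates the action to the observation and emits one Q(s, a) per reward stream, while discrete control emits one value per possible action, sum(act_size) in total, to be condensed later against the chosen actions:

    import torch
    from torch import nn

    hidden, act_size = 128, [3, 2]
    continuous_head = nn.Linear(hidden, 1)            # Q(s, a); the action was concatenated upstream
    discrete_head = nn.Linear(hidden, sum(act_size))  # one Q value per action across both branches

    embedding = torch.randn(4, hidden)                # fake batch of encoder outputs
    print(continuous_head(embedding).shape)           # torch.Size([4, 1])
    print(discrete_head(embedding).shape)             # torch.Size([4, 5])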
class ActorCritic(nn.Module):
def __init__(
self,

use_lstm,
stream_names,
separate_critic,
conditional_sigma=False,
tanh_squash=False,
):
super(ActorCritic, self).__init__()
self.act_type = ActionType.from_str(act_type)

else:
embedding_size = h_size
if self.act_type == ActionType.CONTINUOUS:
self.distribution = GaussianDistribution(
    embedding_size,
    act_size[0],
    conditional_sigma=conditional_sigma,
    tanh_squash=tanh_squash,
)
else:
self.distribution = MultiCategoricalDistribution(embedding_size, act_size)
if separate_critic:

for action_dist in dists:
action = action_dist.sample()
actions.append(action)
def get_probs_and_entropy(self, action_list, dists):
all_probs = []
for action, action_dist in zip(action_list, dists):
if self.act_type == ActionType.DISCRETE:
all_probs.append(action_dist.all_log_prob())
all_probs = None
else:
all_probs = torch.cat(all_probs, dim=-1)
return log_probs, entropies, all_probs
def get_dist_and_value(
self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1

dists, value_outputs, memories = self.get_dist_and_value(
vec_inputs, vis_inputs, masks, memories, sequence_length
)
action_list = self.sample_action(dists)
sampled_actions = torch.stack(action_list, dim=-1)
return (
sampled_actions,
dists[0].pdf(sampled_actions),

self.running_variance = new_variance
self.normalization_steps = total_new_steps
def copy_from(self, other_normalizer: "Normalizer") -> None:
self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)
self.running_mean.data.copy_(other_normalizer.running_mean.data)
self.running_variance.copy_(other_normalizer.running_variance.data)
def __init__(self, stream_names, input_size, output_size=1):
_value_heads = {}
value = nn.Linear(input_size, output_size)
_value_heads[name] = value
self.value_heads = nn.ModuleDict(_value_heads)
for stream_name, head in self.value_heads.items():
value_outputs[stream_name] = head(hidden).squeeze(-1)
return (
value_outputs,
torch.mean(torch.stack(list(value_outputs.values())), dim=0),

for _ in range(num_layers - 1):
self.layers.append(nn.Linear(hidden_size, hidden_size))
self.layers.append(nn.ReLU())
self.seq_layers = nn.Sequential(*self.layers)
return self.seq_layers(inputs)
def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):

conv_1 = torch.relu(self.conv1(visual_obs))
conv_2 = torch.relu(self.conv2(conv_1))
# hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
hidden = torch.relu(self.dense(torch.reshape(conv_2, (-1, self.final_flat))))
return hidden

ml-agents/mlagents/trainers/policy/torch_policy.py (59 changed lines)


from typing import Any, Dict, List, Optional
import os
from torch import onnx
from mlagents.trainers.action_info import ActionInfo

tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
separate_critic: Optional[bool] = None,
):
"""
Policy that uses a multilayer perceptron to map the observations to actions. Could

self.global_step = 0
self.m_size = 0
self.model_path = model_path
self.network_settings = trainer_settings.network_settings
self.act_size = brain.vector_action_space_size
self.act_type = brain.vector_action_space_type

torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
self.inference_dict: Dict[str, tf.Tensor] = {}
self.update_dict: Dict[str, tf.Tensor] = {}

visual_sizes=brain.camera_resolutions,
vis_encode_type=trainer_settings.network_settings.vis_encode_type,
stream_names=reward_signal_names,
separate_critic=separate_critic
if separate_critic is not None
else self.use_continuous_act,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
self.actor_critic.to(TestingConfiguration.device)

self.actor_critic.update_normalization(vector_obs)
@timed
def sample_actions(
    self,
    vec_obs,
    vis_obs,
    masks=None,
    memories=None,
    seq_len=1,
    all_log_probs=False,
):
"""
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
"""
dists, (
value_heads,
mean_value,

action_list = self.actor_critic.sample_action(dists)
log_probs, entropies, all_logs = self.actor_critic.get_probs_and_entropy(
    action_list, dists
)
actions = torch.stack(action_list, dim=-1)
return (
    actions,
    all_logs if all_log_probs else log_probs,
    entropies,
    value_heads,
    memories,
)
def evaluate_actions(
self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1

)
if len(actions.shape) <= 2:
    actions = actions.unsqueeze(-1)
action_list = [actions[..., i] for i in range(actions.shape[2])]
log_probs, entropies, _ = self.actor_critic.get_probs_and_entropy(
    action_list, dists
)
return log_probs, entropies, value_heads

run_out["learning_rate"] = 0.0
if self.use_recurrent:
run_out["memories"] = memories.detach().cpu().numpy()
self.actor_critic.update_normalization(vec_obs)
return run_out
def get_action(

def export_model(self, step=0):
try:
fake_vec_obs = [
    torch.zeros([1] + [self.brain.vector_observation_space_size])
]
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
fake_masks = torch.ones([1] + self.actor_critic.act_size)
# fake_memories = torch.zeros([1] + [self.m_size])

dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
dynamic_axes = {
"vector_observation": [0],
"action": [0],
"action_probs": [0],
}
onnx.export(
self.actor_critic,
(fake_vec_obs, fake_vis_obs, fake_masks),

ml-agents/mlagents/trainers/ppo/trainer.py (14 changed lines)


# ## ML-Agent Learning (PPO)
# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347
use_torch = True
from collections import defaultdict

from mlagents.trainers.settings import TrainerSettings, PPOSettings
logger = get_logger(__name__)
class PPOTrainer(RLTrainer):

self._stats_reporter.add_stat(stat, val)
self._clear_update_buffer()
return True
def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
) -> Policy:
if self.framework == "torch":
return self.create_torch_policy(parsed_behavior_id, brain_parameters)
else:
return self.create_tf_policy(parsed_behavior_id, brain_parameters)
def create_tf_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters

ml-agents/mlagents/trainers/sac/trainer.py (63 changed lines)


from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.settings import TrainerSettings, SACSettings

def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
) -> Policy:
policy = super().create_policy(parsed_behavior_id, brain_parameters)
# Load the replay buffer if load
if self.load and self.checkpoint_replay_buffer:
try:

)
)
return policy
def create_tf_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
) -> NNPolicy:
policy = NNPolicy(
self.seed,
brain_parameters,
self.trainer_settings,
self.artifact_path,
self.load,
tanh_squash=True,
reparameterize=True,
create_tf_graph=False,
)
return policy
def create_torch_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
) -> TorchPolicy:
"""
Creates a SAC TorchPolicy and adds it to the trainer's list of policies.
:param parsed_behavior_id:
:param brain_parameters: specifications for policy construction
:return policy
"""
policy = TorchPolicy(
self.seed,
brain_parameters,
self.trainer_settings,
self.artifact_path,
self.load,
condition_sigma_on_obs=True,
tanh_squash=True,
separate_critic=True,
)
return policy
def _update_sac_policy(self) -> bool:

self.__class__.__name__
)
)
if self.framework == "torch":
    self.optimizer = TorchSACOptimizer(  # type: ignore
        self.policy, self.trainer_settings  # type: ignore
    )  # type: ignore
else:
    if not isinstance(policy, NNPolicy):
        raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
    self.optimizer = SACOptimizer(  # type: ignore
        self.policy, self.trainer_settings  # type: ignore
    )  # type: ignore
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly

ml-agents/mlagents/trainers/trainer/rl_trainer.py (16 changed lines)


from mlagents.trainers.trainer import Trainer
from mlagents.trainers.components.reward_signals import RewardSignalResult
from mlagents_envs.timers import hierarchical_timer
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.policy.policy import Policy
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.ppo.trainer import TestingConfiguration
RewardSignalResults = Dict[str, RewardSignalResult]

self._stats_reporter.add_property(
StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
)
self.framework = "torch" if TestingConfiguration.use_torch else "tf"
if TestingConfiguration.max_steps > 0:
self.trainer_settings.max_steps = TestingConfiguration.max_steps
self._next_save_step = 0
self._next_summary_step = 0

:return: A boolean corresponding to whether or not update_model() can be run
"""
return False
def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
) -> Policy:
if self.framework == "torch":
return self.create_torch_policy(parsed_behavior_id, brain_parameters)
else:
return self.create_tf_policy(parsed_behavior_id, brain_parameters)
@abc.abstractmethod
def _update_policy(self) -> bool:

ml-agents/mlagents/trainers/sac/optimizer_torch.py (481 changed lines)


import numpy as np
from typing import Dict, List, Mapping, cast, Tuple
import torch
from torch import nn
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.brain import CameraResolution
from mlagents.trainers.models_torch import (
Critic,
QNetwork,
ActionType,
list_to_tensor,
break_into_branches,
actions_to_onehot,
)
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.timers import timed
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import TrainerSettings, SACSettings
EPSILON = 1e-6 # Small value to avoid divide by zero
logger = get_logger(__name__)
class TorchSACOptimizer(TorchOptimizer):
class PolicyValueNetwork(nn.Module):
def __init__(
self,
stream_names: List[str],
vector_sizes: List[int],
visual_sizes: List[CameraResolution],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
):
super().__init__()
self.q1_network = QNetwork(
stream_names,
vector_sizes,
visual_sizes,
network_settings,
act_type,
act_size,
)
self.q2_network = QNetwork(
stream_names,
vector_sizes,
visual_sizes,
network_settings,
act_type,
act_size,
)
def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
actions: torch.Tensor = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions)
q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions)
return q1_out, q2_out
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
super().__init__(policy, trainer_params)
hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)
lr = hyperparameters.learning_rate
# lr_schedule = hyperparameters.learning_rate_schedule
# max_step = trainer_params.max_steps
self.tau = hyperparameters.tau
self.init_entcoef = hyperparameters.init_entcoef
self.policy = policy
self.act_size = policy.act_size
policy_network_settings = policy.network_settings
# h_size = policy_network_settings.hidden_units
# num_layers = policy_network_settings.num_layers
# vis_encode_type = policy_network_settings.vis_encode_type
self.tau = hyperparameters.tau
self.burn_in_ratio = 0.0
# Non-exposed SAC parameters
self.discrete_target_entropy_scale = 0.2 # Roughly equal to e-greedy 0.05
self.continuous_target_entropy_scale = 1.0
self.stream_names = list(self.reward_signals.keys())
# Used to reduce the "survivor bonus" when using Curiosity or GAIL.
self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()]
self.use_dones_in_backup = {
name: int(self.reward_signals[name].use_terminal_states)
for name in self.stream_names
}
# self.disable_use_dones = {
# name: self.use_dones_in_backup[name].assign(0.0)
# for name in stream_names
# }
brain = policy.brain
self.value_network = TorchSACOptimizer.PolicyValueNetwork(
self.stream_names,
[brain.vector_observation_space_size],
brain.camera_resolutions,
policy_network_settings,
ActionType.from_str(policy.act_type),
self.act_size,
)
self.target_network = Critic(
self.stream_names,
policy_network_settings.hidden_units,
[brain.vector_observation_space_size],
brain.camera_resolutions,
policy_network_settings.normalize,
policy_network_settings.num_layers,
policy_network_settings.memory.memory_size
if policy_network_settings.memory is not None
else 0,
policy_network_settings.vis_encode_type,
)
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
self._log_ent_coef = torch.nn.Parameter(
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
requires_grad=True,
)
if self.policy.use_continuous_act:
self.target_entropy = torch.as_tensor(
-1
* self.continuous_target_entropy_scale
* np.prod(self.act_size[0]).astype(np.float32)
)
else:
self.target_entropy = [
self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
for i in self.act_size
]
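A worked example of these target-entropy defaults (sizes are made up): for a single 3-dimensional continuous action space, act_size = [3] and the target is -1 * 1.0 * 3 = -3.0; for two discrete branches with 3 and 2 actions, the per-branch targets are 0.2 * log(3) ≈ 0.22 and 0.2 * log(2) ≈ 0.14:

    import numpy as np

    continuous_target = -1 * 1.0 * np.prod(3).astype(np.float32)            # -3.0
    discrete_targets = [0.2 * np.log(i).astype(np.float32) for i in [3, 2]]  # [~0.22, ~0.14]
    print(continuous_target, discrete_targets)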
policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
self.policy.actor_critic.distribution.parameters()
)
value_params = list(self.value_network.parameters()) + list(
self.policy.actor_critic.critic.parameters()
)
logger.debug("value_vars")
for param in value_params:
logger.debug(param.shape)
logger.debug("policy_vars")
for param in policy_params:
logger.debug(param.shape)
self.policy_optimizer = torch.optim.Adam(policy_params, lr=lr)
self.value_optimizer = torch.optim.Adam(value_params, lr=lr)
self.entropy_optimizer = torch.optim.Adam([self._log_ent_coef], lr=lr)
def sac_q_loss(
self,
q1_out: Dict[str, torch.Tensor],
q2_out: Dict[str, torch.Tensor],
target_values: Dict[str, torch.Tensor],
dones: torch.Tensor,
rewards: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q1_losses = []
q2_losses = []
# Multiple q losses per stream
for i, name in enumerate(q1_out.keys()):
q1_stream = q1_out[name].squeeze()
q2_stream = q2_out[name].squeeze()
with torch.no_grad():
q_backup = rewards[name] + (
(1.0 - self.use_dones_in_backup[name] * dones)
* self.gammas[i]
* target_values[name]
)
_q1_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
)
_q2_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
)
q1_losses.append(_q1_loss)
q2_losses.append(_q2_loss)
q1_loss = torch.mean(torch.stack(q1_losses))
q2_loss = torch.mean(torch.stack(q2_losses))
return q1_loss, q2_loss
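Restated as equations, the backup computed above for each reward stream $i$ is

    y_i = r_i + \gamma_i \, (1 - d \cdot \mathrm{use\_dones}_i) \, V_i^{\mathrm{target}}(s')

and both Q networks are regressed onto it with the masked loss

    L_{Q_k} = \tfrac{1}{2} \, \mathbb{E}\big[ m \cdot (y_i - Q_k(s, a))^2 \big], \quad k \in \{1, 2\}

where $m$ is the sequence loss mask; the torch.no_grad() block keeps gradients from flowing into the backup target.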
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
for source_param, target_param in zip(source.parameters(), target.parameters()):
target_param.data.copy_(
target_param.data * (1.0 - tau) + source_param.data * tau
)
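soft_update implements the Polyak averaging rule theta_target <- tau * theta_source + (1 - tau) * theta_target. It is called with tau = 1.0 in __init__ to hard-copy the critic into the target network, and with tau = hyperparameters.tau at the end of every update() so the target slowly tracks the critic. A minimal standalone usage sketch (layer sizes and the small tau value are arbitrary):

    import torch
    from torch import nn

    def polyak_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        for s, t in zip(source.parameters(), target.parameters()):
            t.data.copy_(t.data * (1.0 - tau) + s.data * tau)

    src, tgt = nn.Linear(4, 4), nn.Linear(4, 4)
    polyak_update(src, tgt, 1.0)    # hard copy, as done at construction time
    assert torch.allclose(src.weight, tgt.weight)
    polyak_update(src, tgt, 0.005)  # slow target tracking with a small, illustrative tau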
def sac_value_loss(
self,
log_probs: torch.Tensor,
values: Dict[str, torch.Tensor],
q1p_out: Dict[str, torch.Tensor],
q2p_out: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
discrete: bool,
) -> torch.Tensor:
min_policy_qs = {}
with torch.no_grad():
_ent_coef = torch.exp(self._log_ent_coef)
for name in values.keys():
if not discrete:
min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
else:
action_probs = log_probs.exp()
_branched_q1p = break_into_branches(
q1p_out[name] * action_probs, self.act_size
)
_branched_q2p = break_into_branches(
q2p_out[name] * action_probs, self.act_size
)
_q1p_mean = torch.mean(
torch.stack(
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
),
dim=0,
)
_q2p_mean = torch.mean(
torch.stack(
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
),
dim=0,
)
min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
value_losses = []
if not discrete:
for name in values.keys():
with torch.no_grad():
v_backup = min_policy_qs[name] - torch.sum(
_ent_coef * log_probs, dim=1
)
# print(log_probs, v_backup, _ent_coef, loss_masks)
value_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
)
value_losses.append(value_loss)
else:
branched_per_action_ent = break_into_branches(
log_probs * log_probs.exp(), self.act_size
)
# We have to do entropy bonus per action branch
branched_ent_bonus = torch.stack(
[
torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
for i, _lp in enumerate(branched_per_action_ent)
]
)
for name in values.keys():
with torch.no_grad():
v_backup = min_policy_qs[name] - torch.mean(
branched_ent_bonus, axis=0
)
value_loss = 0.5 * torch.mean(
loss_masks
* torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
raise UnityTrainerException("Inf found")
return value_loss
def sac_policy_loss(
self,
log_probs: torch.Tensor,
q1p_outs: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
discrete: bool,
) -> torch.Tensor:
_ent_coef = torch.exp(self._log_ent_coef)
mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
if not discrete:
mean_q1 = mean_q1.unsqueeze(1)
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
policy_loss = torch.mean(loss_masks * batch_policy_loss)
else:
action_probs = log_probs.exp()
branched_per_action_ent = break_into_branches(
log_probs * action_probs, self.act_size
)
branched_q_term = break_into_branches(mean_q1 * action_probs, self.act_size)
branched_policy_loss = torch.stack(
[
torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True)
for i, (_lp, _qt) in enumerate(
zip(branched_per_action_ent, branched_q_term)
)
]
)
batch_policy_loss = torch.squeeze(branched_policy_loss)
policy_loss = torch.mean(loss_masks * batch_policy_loss)
return policy_loss
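In equation form, writing alpha = exp(_log_ent_coef) and $\bar{Q}_1$ for the mean of the Q1 heads over reward streams, the continuous branch above minimizes

    L_\pi = \mathbb{E}\big[ m \cdot (\alpha \log \pi(a|s) - \bar{Q}_1(s, a)) \big]

while the discrete branch evaluates the same expectation in closed form, summing over each branch $b$ and every action $a \in A_b$ weighted by its probability:

    L_\pi = \mathbb{E}\Big[ m \cdot \sum_b \sum_{a \in A_b} \pi(a|s) \, \big(\alpha_b \log \pi(a|s) - \bar{Q}_1(s, a)\big) \Big]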
def sac_entropy_loss(
self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
) -> torch.Tensor:
if not discrete:
with torch.no_grad():
target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
entropy_loss = -torch.mean(
self._log_ent_coef * loss_masks * target_current_diff
)
else:
with torch.no_grad():
branched_per_action_ent = break_into_branches(
log_probs * log_probs.exp(), self.act_size
)
target_current_diff_branched = torch.stack(
[
torch.sum(_lp, axis=1, keepdim=True) + _te
for _lp, _te in zip(
branched_per_action_ent, self.target_entropy
)
],
axis=1,
)
target_current_diff = torch.squeeze(
target_current_diff_branched, axis=2
)
entropy_loss = -torch.mean(
loss_masks
* torch.mean(self._log_ent_coef * target_current_diff, axis=1)
)
return entropy_loss
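This is the standard automatic entropy-coefficient objective: with the target entropy $\mathcal{H}_{\mathrm{target}}$ set in __init__,

    L_\alpha = -\mathbb{E}\big[ m \cdot \log\alpha \cdot (\log \pi(a|s) + \mathcal{H}_{\mathrm{target}}) \big]

for continuous actions, and the same objective applied per discrete branch (each with its own $\alpha_b$ and per-branch target entropy) and averaged over branches. The torch.no_grad() blocks ensure only $\log\alpha$ receives gradient.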
def _condense_q_streams(
self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor
) -> Dict[str, torch.Tensor]:
condensed_q_output = {}
onehot_actions = actions_to_onehot(discrete_actions, self.act_size)
for key, item in q_output.items():
branched_q = break_into_branches(item, self.act_size)
only_action_qs = torch.stack(
[
torch.sum(_act * _q, dim=1, keepdim=True)
for _act, _q in zip(onehot_actions, branched_q)
]
)
condensed_q_output[key] = torch.mean(only_action_qs, dim=0)
return condensed_q_output
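A toy walk-through of the condensation with made-up numbers (two branches of sizes 3 and 2, batch of 1), using the break_into_branches and actions_to_onehot helpers imported at the top of this file:

    import torch

    q_concat = torch.tensor([[1.0, 2.0, 3.0, 10.0, 20.0]])  # one Q value per possible action
    act_size = [3, 2]
    chosen = torch.tensor([[2, 0]])                          # action 2 in branch 0, action 0 in branch 1

    branched_q = break_into_branches(q_concat, act_size)     # [[1, 2, 3]], [[10, 20]]
    onehot = actions_to_onehot(chosen, act_size)             # [[0, 0, 1]], [[1, 0]]
    picked = torch.stack(
        [torch.sum(a * q, dim=1, keepdim=True) for a, q in zip(onehot, branched_q)]
    )                                                        # Q of the chosen action per branch: 3.0 and 10.0
    print(torch.mean(picked, dim=0))                         # tensor([[6.5]]), the condensed Q estimate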
@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
Updates model using buffer.
:param num_sequences: Number of trajectories in batch.
:param batch: Experience mini-batch.
:return: Output from update process.
"""
rewards = {}
for name in self.reward_signals:
rewards[name] = list_to_tensor(batch["{}_rewards".format(name)])
vec_obs = [list_to_tensor(batch["vector_obs"])]
next_vec_obs = [list_to_tensor(batch["next_vector_in"])]
act_masks = list_to_tensor(batch["action_mask"])
if self.policy.use_continuous_act:
actions = list_to_tensor(batch["actions"]).unsqueeze(-1)
else:
actions = list_to_tensor(batch["actions"], dtype=torch.long)
memories = [
list_to_tensor(batch["memory"][i])
for i in range(0, len(batch["memory"]), self.policy.sequence_length)
]
if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
vis_obs: List[torch.Tensor] = []
next_vis_obs: List[torch.Tensor] = []
if self.policy.use_vis_obs:
vis_obs = []
for idx, _ in enumerate(
self.policy.actor_critic.network_body.visual_encoders
):
vis_ob = list_to_tensor(batch["visual_obs%d" % idx])
vis_obs.append(vis_ob)
next_vis_ob = list_to_tensor(batch["next_visual_obs%d" % idx])
next_vis_obs.append(next_vis_ob)
# Copy normalizers from policy
self.value_network.q1_network.copy_normalization(
self.policy.actor_critic.network_body
)
self.value_network.q2_network.copy_normalization(
self.policy.actor_critic.network_body
)
self.target_network.network_body.copy_normalization(
self.policy.actor_critic.network_body
)
(
sampled_actions,
log_probs,
entropies,
sampled_values,
_,
) = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,
memories=memories,
seq_len=self.policy.sequence_length,
all_log_probs=not self.policy.use_continuous_act,
)
if self.policy.use_continuous_act:
squeezed_actions = actions.squeeze(-1)
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions)
q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions)
q1_stream, q2_stream = q1_out, q2_out
else:
with torch.no_grad():
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs)
q1_out, q2_out = self.value_network(vec_obs, vis_obs)
q1_stream = self._condense_q_streams(q1_out, actions)
q2_stream = self._condense_q_streams(q2_out, actions)
with torch.no_grad():
target_values, _ = self.target_network(next_vec_obs, next_vis_obs)
masks = list_to_tensor(batch["masks"], dtype=torch.int32)
use_discrete = not self.policy.use_continuous_act
dones = list_to_tensor(batch["done"])
q1_loss, q2_loss = self.sac_q_loss(
q1_stream, q2_stream, target_values, dones, rewards, masks
)
value_loss = self.sac_value_loss(
log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
)
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
total_value_loss = q1_loss + q2_loss + value_loss
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
self.value_optimizer.zero_grad()
total_value_loss.backward()
self.value_optimizer.step()
self.entropy_optimizer.zero_grad()
entropy_loss.backward()
self.entropy_optimizer.step()
# Update target network
self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
update_stats = {
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()),
"Losses/Value Loss": value_loss.detach().cpu().numpy(),
"Losses/Q1 Loss": q1_loss.detach().cpu().numpy(),
"Losses/Q2 Loss": q2_loss.detach().cpu().numpy(),
"Policy/Entropy Coeff": torch.exp(self._log_ent_coef)
.detach()
.cpu()
.numpy(),
}
return update_stats
def update_reward_signals(
self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
) -> Dict[str, float]:
return {}