[feature] Add experimental PyTorch support (#4335)
* Begin porting work
* Add ResNet and distributions
* Dynamically construct actor and critic
* Initial optimizer port
* Refactoring policy and optimizer
* Resolving a few bugs
* Share more code between tf and torch policies
* Slightly closer to running model
* Training runs, but doesn't actually work
* Fix a couple additional bugs
* Add conditional sigma for distribution
* Fix normalization
* Support discrete actions as well
* Continuous and discrete now train
* Multi-discrete now working
* Visual observations now train as well
* GRU in-progress and dynamic cnns
* Fix for memories
* Remove unused arg
* Combine actor and critic classes. Initial export.
* Support tf and pytorch alongside one another
* Prepare model for onnx export
* Use LSTM and fix a few merge errors
* Fix bug in probs calculation
* Optimize np -> tensor operations
* Time action sample funct.../MLA-1734-demo-provider
GitHub
4 years ago
Current commit: 1955af9e
52 files changed: 5,374 insertions and 156 deletions
Changed files (changed lines per file):

- com.unity.ml-agents/CHANGELOG.md (4)
- ml-agents/mlagents/trainers/buffer.py (2)
- ml-agents/mlagents/trainers/cli_utils.py (7)
- ml-agents/mlagents/trainers/ghost/trainer.py (12)
- ml-agents/mlagents/trainers/policy/tf_policy.py (2)
- ml-agents/mlagents/trainers/ppo/trainer.py (76)
- ml-agents/mlagents/trainers/sac/trainer.py (120)
- ml-agents/mlagents/trainers/settings.py (14)
- ml-agents/mlagents/trainers/tests/test_ghost.py (7)
- ml-agents/mlagents/trainers/tests/test_rl_trainer.py (5)
- ml-agents/mlagents/trainers/tests/test_sac.py (3)
- ml-agents/mlagents/trainers/tests/test_simple_rl.py (2)
- ml-agents/mlagents/trainers/tests/torch/test_layers.py (19)
- ml-agents/mlagents/trainers/tests/torch/test_networks.py (17)
- ml-agents/mlagents/trainers/tests/torch/test_utils.py (6)
- ml-agents/mlagents/trainers/tf/model_serialization.py (20)
- ml-agents/mlagents/trainers/torch/encoders.py (17)
- ml-agents/mlagents/trainers/torch/layers.py (67)
- ml-agents/mlagents/trainers/torch/networks.py (115)
- ml-agents/mlagents/trainers/torch/utils.py (4)
- ml-agents/mlagents/trainers/trainer/rl_trainer.py (99)
- ml-agents/mlagents/trainers/trainer/trainer.py (5)
- ml-agents/mlagents/trainers/trainer_controller.py (7)
- ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (94)
- ml-agents/mlagents/trainers/policy/torch_policy.py (281)
- ml-agents/mlagents/trainers/ppo/optimizer_torch.py (203)
- ml-agents/mlagents/trainers/sac/optimizer_torch.py (561)
- ml-agents/mlagents/trainers/saver/torch_saver.py (118)
- ml-agents/mlagents/trainers/tests/torch/test.demo (1001)
- ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py (144)
- ml-agents/mlagents/trainers/tests/torch/test_ghost.py (177)
- ml-agents/mlagents/trainers/tests/torch/test_policy.py (150)
- ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (505)
- ml-agents/mlagents/trainers/tests/torch/testdcvis.demo (446)
- ml-agents/mlagents/trainers/torch/model_serialization.py (74)
- ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py (111)
- ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py (56)
- ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (138)
- ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py (32)
- ml-agents/mlagents/trainers/torch/components/__init__.py (0)
- ml-agents/mlagents/trainers/torch/components/bc/__init__.py (0)
- ml-agents/mlagents/trainers/torch/components/bc/module.py (183)
- ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py (15)
- ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py (72)
- ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py (15)
- ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py (43)
- ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (225)
- ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (256)
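The new torch/ package mirrors the existing TensorFlow trainer stack: a TorchPolicy, PPO and SAC optimizers, a saver, and behavioral-cloning / reward-provider components, plus a parallel test suite. Judging from the test fixtures in this diff (see test_ghost.py further down), the PyTorch code path appears to be selected through a framework field on TrainerSettings; a minimal sketch, mirroring that fixture (the exact option surface in a released version may differ):

# Sketch only: selecting the experimental PyTorch backend, as the new torch
# test suite's dummy_config fixture does.
from mlagents.trainers.settings import TrainerSettings, FrameworkType

config = TrainerSettings(framework=FrameworkType.PYTORCH)
# Trainers built from this config should construct TorchPolicy and the
# Torch* optimizers added in this change instead of their TF counterparts.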
|
|||
from typing import Dict, Optional, Tuple, List |
|||
import torch |
|||
import numpy as np |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.torch.components.bc.module import BCModule |
|||
from mlagents.trainers.torch.components.reward_providers import create_reward_provider |
|||
|
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer import Optimizer |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
|
|||
class TorchOptimizer(Optimizer): # pylint: disable=W0223 |
|||
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): |
|||
super().__init__() |
|||
self.policy = policy |
|||
self.trainer_settings = trainer_settings |
|||
self.update_dict: Dict[str, torch.Tensor] = {} |
|||
self.value_heads: Dict[str, torch.Tensor] = {} |
|||
self.memory_in: torch.Tensor = None |
|||
self.memory_out: torch.Tensor = None |
|||
self.m_size: int = 0 |
|||
self.global_step = torch.tensor(0) |
|||
self.bc_module: Optional[BCModule] = None |
|||
self.create_reward_signals(trainer_settings.reward_signals) |
|||
if trainer_settings.behavioral_cloning is not None: |
|||
self.bc_module = BCModule( |
|||
self.policy, |
|||
trainer_settings.behavioral_cloning, |
|||
policy_learning_rate=trainer_settings.hyperparameters.learning_rate, |
|||
default_batch_size=trainer_settings.hyperparameters.batch_size, |
|||
default_num_epoch=3, |
|||
) |
|||
|
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
pass |
|||
|
|||
def create_reward_signals(self, reward_signal_configs): |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
for reward_signal, settings in reward_signal_configs.items(): |
|||
# Name reward signals by string in case we have duplicates later |
|||
self.reward_signals[reward_signal.value] = create_reward_provider( |
|||
reward_signal, self.policy.behavior_spec, settings |
|||
) |
|||
|
|||
def get_trajectory_value_estimates( |
|||
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool |
|||
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]: |
|||
vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
if self.policy.use_vis_obs: |
|||
visual_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
visual_obs.append(visual_ob) |
|||
else: |
|||
visual_obs = [] |
|||
|
|||
memory = torch.zeros([1, 1, self.policy.m_size]) |
|||
|
|||
vec_vis_obs = SplitObservations.from_observations(next_obs) |
|||
next_vec_obs = [ |
|||
ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0) |
|||
] |
|||
next_vis_obs = [ |
|||
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0) |
|||
for _vis_ob in vec_vis_obs.visual_observations |
|||
] |
|||
|
|||
value_estimates, next_memory = self.policy.actor_critic.critic_pass( |
|||
vector_obs, visual_obs, memory, sequence_length=batch.num_experiences |
|||
) |
|||
|
|||
next_value_estimate, _ = self.policy.actor_critic.critic_pass( |
|||
next_vec_obs, next_vis_obs, next_memory, sequence_length=1 |
|||
) |
|||
|
|||
for name, estimate in value_estimates.items(): |
|||
value_estimates[name] = estimate.detach().cpu().numpy() |
|||
next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy() |
|||
|
|||
if done: |
|||
for k in next_value_estimate: |
|||
if not self.reward_signals[k].ignore_done: |
|||
next_value_estimate[k] = 0.0 |
|||
|
|||
return value_estimates, next_value_estimate |
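get_trajectory_value_estimates returns per-reward-stream value estimates for the trajectory plus a bootstrap estimate for the observation that follows it (zeroed on termination unless the signal ignores dones). For reference, the standard GAE(lambda) computation that such outputs typically feed into looks roughly like the sketch below; this is the textbook formula, not the trainer's own advantage code, which lives outside this file.

import numpy as np

def gae(rewards, value_estimates, bootstrap_value, gamma=0.99, lam=0.95):
    # Generalized Advantage Estimation over a single reward stream.
    values = np.append(value_estimates, bootstrap_value)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    returns = advantages + values[:-1]
    return advantages, returns

adv, ret = gae(
    rewards=np.array([1.0, 0.0, 1.0]),
    value_estimates=np.array([0.5, 0.4, 0.6]),
    bootstrap_value=0.3,
)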
|
|||
from typing import Any, Dict, List, Tuple, Optional |
|||
import numpy as np |
|||
import torch |
|||
import copy |
|||
|
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from mlagents.trainers.behavior_id_utils import get_global_agent_id |
|||
from mlagents.trainers.policy import Policy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents_envs.timers import timed |
|||
|
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.torch.networks import ( |
|||
SharedActorCritic, |
|||
SeparateActorCritic, |
|||
GlobalSteps, |
|||
) |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
EPSILON = 1e-7 # Small value to avoid divide by zero |
|||
|
|||
|
|||
class TorchPolicy(Policy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
behavior_spec: BehaviorSpec, |
|||
trainer_settings: TrainerSettings, |
|||
tanh_squash: bool = False, |
|||
reparameterize: bool = False, |
|||
separate_critic: bool = True, |
|||
condition_sigma_on_obs: bool = True, |
|||
): |
|||
""" |
|||
Policy that uses a multilayer perceptron to map the observations to actions. Could |
|||
also use a CNN to encode visual input prior to the MLP. Supports discrete and |
|||
continuous action spaces, as well as recurrent networks. |
|||
:param seed: Random seed. |
|||
:param brain: Assigned BrainParameters object. |
|||
:param trainer_settings: Defined training parameters. |
|||
:param load: Whether a pre-trained model will be loaded or a new one created. |
|||
:param tanh_squash: Whether to use a tanh function on the continuous output, |
|||
or a clipped output. |
|||
:param reparameterize: Whether we are using the resampling trick to update the policy |
|||
in continuous output. |
|||
""" |
|||
super().__init__( |
|||
seed, |
|||
behavior_spec, |
|||
trainer_settings, |
|||
tanh_squash, |
|||
reparameterize, |
|||
condition_sigma_on_obs, |
|||
) |
|||
self.global_step = ( |
|||
GlobalSteps() |
|||
) # could be much simpler if TorchPolicy is nn.Module |
|||
self.grads = None |
|||
|
|||
torch.set_default_tensor_type(torch.FloatTensor) |
|||
|
|||
reward_signal_configs = trainer_settings.reward_signals |
|||
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
if separate_critic: |
|||
ac_class = SeparateActorCritic |
|||
else: |
|||
ac_class = SharedActorCritic |
|||
self.actor_critic = ac_class( |
|||
observation_shapes=self.behavior_spec.observation_shapes, |
|||
network_settings=trainer_settings.network_settings, |
|||
act_type=behavior_spec.action_type, |
|||
act_size=self.act_size, |
|||
stream_names=reward_signal_names, |
|||
conditional_sigma=self.condition_sigma_on_obs, |
|||
tanh_squash=tanh_squash, |
|||
) |
|||
# Save the m_size needed for export |
|||
self._export_m_size = self.m_size |
|||
# m_size needed for training is determined by network, not trainer settings |
|||
self.m_size = self.actor_critic.memory_size |
|||
|
|||
self.actor_critic.to("cpu") |
|||
|
|||
@property |
|||
def export_memory_size(self) -> int: |
|||
""" |
|||
Returns the memory size of the exported ONNX policy. This only includes the memory
of the Actor and not any auxiliary networks.
|||
""" |
|||
return self._export_m_size |
|||
|
|||
def _split_decision_step( |
|||
self, decision_requests: DecisionSteps |
|||
) -> Tuple[SplitObservations, np.ndarray]: |
|||
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs) |
|||
mask = None |
|||
if not self.use_continuous_act: |
|||
mask = torch.ones([len(decision_requests), np.sum(self.act_size)]) |
|||
if decision_requests.action_mask is not None: |
|||
mask = torch.as_tensor( |
|||
1 - np.concatenate(decision_requests.action_mask, axis=1) |
|||
) |
|||
return vec_vis_obs, mask |
|||
|
|||
def update_normalization(self, vector_obs: np.ndarray) -> None: |
|||
""" |
|||
If this policy normalizes vector observations, this will update the norm values in the graph. |
|||
:param vector_obs: The vector observations to add to the running estimate of the distribution. |
|||
""" |
|||
vector_obs = [torch.as_tensor(vector_obs)] |
|||
if self.use_vec_obs and self.normalize: |
|||
self.actor_critic.update_normalization(vector_obs) |
|||
|
|||
@timed |
|||
def sample_actions( |
|||
self, |
|||
vec_obs: List[torch.Tensor], |
|||
vis_obs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
seq_len: int = 1, |
|||
all_log_probs: bool = False, |
|||
) -> Tuple[ |
|||
torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor |
|||
]: |
|||
""" |
|||
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action. |
|||
""" |
|||
dists, value_heads, memories = self.actor_critic.get_dist_and_value( |
|||
vec_obs, vis_obs, masks, memories, seq_len |
|||
) |
|||
action_list = self.actor_critic.sample_action(dists) |
|||
log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy( |
|||
action_list, dists |
|||
) |
|||
actions = torch.stack(action_list, dim=-1) |
|||
if self.use_continuous_act: |
|||
actions = actions[:, :, 0] |
|||
else: |
|||
actions = actions[:, 0, :] |
|||
|
|||
return ( |
|||
actions, |
|||
all_logs if all_log_probs else log_probs, |
|||
entropies, |
|||
value_heads, |
|||
memories, |
|||
) |
|||
|
|||
def evaluate_actions( |
|||
self, |
|||
vec_obs: torch.Tensor, |
|||
vis_obs: torch.Tensor, |
|||
actions: torch.Tensor, |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
seq_len: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: |
|||
dists, value_heads, _ = self.actor_critic.get_dist_and_value( |
|||
vec_obs, vis_obs, masks, memories, seq_len |
|||
) |
|||
action_list = [actions[..., i] for i in range(actions.shape[-1])] |
|||
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists) |
|||
|
|||
return log_probs, entropies, value_heads |
|||
|
|||
@timed |
|||
def evaluate( |
|||
self, decision_requests: DecisionSteps, global_agent_ids: List[str] |
|||
) -> Dict[str, Any]: |
|||
""" |
|||
Evaluates policy for the agent experiences provided. |
|||
:param global_agent_ids: The global (worker-qualified) ids of the agents in decision_requests.
:param decision_requests: DecisionSteps object containing inputs.
:return: Outputs from the network, collected in a run_out dictionary.
|||
""" |
|||
vec_vis_obs, masks = self._split_decision_step(decision_requests) |
|||
vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)] |
|||
vis_obs = [ |
|||
torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations |
|||
] |
|||
memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze( |
|||
0 |
|||
) |
|||
|
|||
run_out = {} |
|||
with torch.no_grad(): |
|||
action, log_probs, entropy, value_heads, memories = self.sample_actions( |
|||
vec_obs, vis_obs, masks=masks, memories=memories |
|||
) |
|||
run_out["action"] = action.detach().cpu().numpy() |
|||
run_out["pre_action"] = action.detach().cpu().numpy() |
|||
# TODO: make pre_action actually differ from action (currently identical)
|||
run_out["log_probs"] = log_probs.detach().cpu().numpy() |
|||
run_out["entropy"] = entropy.detach().cpu().numpy() |
|||
run_out["value_heads"] = { |
|||
name: t.detach().cpu().numpy() for name, t in value_heads.items() |
|||
} |
|||
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0) |
|||
run_out["learning_rate"] = 0.0 |
|||
if self.use_recurrent: |
|||
run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0) |
|||
return run_out |
|||
|
|||
def get_action( |
|||
self, decision_requests: DecisionSteps, worker_id: int = 0 |
|||
) -> ActionInfo: |
|||
""" |
|||
Decides actions given observations information, and takes them in environment. |
|||
:param worker_id: The ID of the worker the DecisionSteps came from.
:param decision_requests: A DecisionSteps object containing the observations from the environment.
|||
:return: an ActionInfo containing action, memories, values and an object |
|||
to be passed to add experiences |
|||
""" |
|||
if len(decision_requests) == 0: |
|||
return ActionInfo.empty() |
|||
|
|||
global_agent_ids = [ |
|||
get_global_agent_id(worker_id, int(agent_id)) |
|||
for agent_id in decision_requests.agent_id |
|||
] # For 1-D array, the iterator order is correct. |
|||
|
|||
run_out = self.evaluate( |
|||
decision_requests, global_agent_ids |
|||
) # pylint: disable=assignment-from-no-return |
|||
self.save_memories(global_agent_ids, run_out.get("memory_out")) |
|||
return ActionInfo( |
|||
action=run_out.get("action"), |
|||
value=run_out.get("value"), |
|||
outputs=run_out, |
|||
agent_ids=list(decision_requests.agent_id), |
|||
) |
|||
|
|||
@property |
|||
def use_vis_obs(self): |
|||
return self.vis_obs_size > 0 |
|||
|
|||
@property |
|||
def use_vec_obs(self): |
|||
return self.vec_obs_size > 0 |
|||
|
|||
def get_current_step(self): |
|||
""" |
|||
Gets current model step. |
|||
:return: current model step. |
|||
""" |
|||
return self.global_step.current_step |
|||
|
|||
def set_step(self, step: int) -> int: |
|||
""" |
|||
Sets current model step to step without creating additional ops. |
|||
:param step: Step to set the current model step to. |
|||
:return: The step the model was set to. |
|||
""" |
|||
self.global_step.current_step = step |
|||
return step |
|||
|
|||
def increment_step(self, n_steps): |
|||
""" |
|||
Increments model step. |
|||
""" |
|||
self.global_step.increment(n_steps) |
|||
return self.get_current_step() |
|||
|
|||
def load_weights(self, values: List[np.ndarray]) -> None: |
|||
self.actor_critic.load_state_dict(values) |
|||
|
|||
def init_load_weights(self) -> None: |
|||
pass |
|||
|
|||
def get_weights(self) -> List[np.ndarray]: |
|||
return copy.deepcopy(self.actor_critic.state_dict()) |
|||
|
|||
def get_modules(self): |
|||
return {"Policy": self.actor_critic, "global_step": self.global_step} |
|
|||
from typing import Dict, cast |
|||
import torch |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
|
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.settings import TrainerSettings, PPOSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
|
|||
class TorchPPOOptimizer(TorchOptimizer): |
|||
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): |
|||
""" |
|||
Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy. |
|||
The PPO optimizer has a value estimator and a loss function. |
|||
:param policy: A TFPolicy object that will be updated by this PPO Optimizer. |
|||
:param trainer_params: Trainer parameters dictionary that specifies the |
|||
properties of the trainer. |
|||
""" |
|||
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|||
|
|||
super().__init__(policy, trainer_settings) |
|||
params = list(self.policy.actor_critic.parameters()) |
|||
self.hyperparameters: PPOSettings = cast( |
|||
PPOSettings, trainer_settings.hyperparameters |
|||
) |
|||
self.decay_learning_rate = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.learning_rate, |
|||
1e-10, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.decay_epsilon = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.epsilon, |
|||
0.1, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.decay_beta = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.beta, |
|||
1e-5, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
|
|||
self.optimizer = torch.optim.Adam( |
|||
params, lr=self.trainer_settings.hyperparameters.learning_rate |
|||
) |
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
|
|||
def ppo_value_loss( |
|||
self, |
|||
values: Dict[str, torch.Tensor], |
|||
old_values: Dict[str, torch.Tensor], |
|||
returns: Dict[str, torch.Tensor], |
|||
epsilon: float, |
|||
loss_masks: torch.Tensor, |
|||
) -> torch.Tensor: |
|||
""" |
|||
Evaluates value loss for PPO. |
|||
:param values: Value output of the current network. |
|||
:param old_values: Value stored with experiences in buffer. |
|||
:param returns: Computed returns. |
|||
:param epsilon: Clipping value for value estimate. |
|||
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
|||
""" |
|||
value_losses = [] |
|||
for name, head in values.items(): |
|||
old_val_tensor = old_values[name] |
|||
returns_tensor = returns[name] |
|||
clipped_value_estimate = old_val_tensor + torch.clamp( |
|||
head - old_val_tensor, -1 * epsilon, epsilon |
|||
) |
|||
v_opt_a = (returns_tensor - head) ** 2 |
|||
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 |
|||
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks) |
|||
value_losses.append(value_loss) |
|||
value_loss = torch.mean(torch.stack(value_losses)) |
|||
return value_loss |
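# Worked example (editorial note, not part of this diff): ppo_value_loss clamps the
# change in the value estimate and keeps the larger of the clipped / unclipped
# squared errors, mirroring PPO's pessimistic clipping of the policy objective.
# With old_v = 1.0, new_v = 2.0, ret = 1.4, eps = 0.2:
#     clipped_v = old_v + clamp(new_v - old_v, -eps, eps) = 1.2
#     loss term = max((ret - new_v) ** 2, (ret - clipped_v) ** 2) = max(0.36, 0.04) = 0.36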
|||
|
|||
def ppo_policy_loss( |
|||
self, |
|||
advantages: torch.Tensor, |
|||
log_probs: torch.Tensor, |
|||
old_log_probs: torch.Tensor, |
|||
loss_masks: torch.Tensor, |
|||
) -> torch.Tensor: |
|||
""" |
|||
Evaluate PPO policy loss. |
|||
:param advantages: Computed advantages. |
|||
:param log_probs: Current policy probabilities |
|||
:param old_log_probs: Past policy probabilities |
|||
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences. |
|||
""" |
|||
advantage = advantages.unsqueeze(-1) |
|||
|
|||
decay_epsilon = self.hyperparameters.epsilon |
|||
|
|||
r_theta = torch.exp(log_probs - old_log_probs) |
|||
p_opt_a = r_theta * advantage |
|||
p_opt_b = ( |
|||
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage |
|||
) |
|||
policy_loss = -1 * ModelUtils.masked_mean( |
|||
torch.min(p_opt_a, p_opt_b), loss_masks |
|||
) |
|||
return policy_loss |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Performs update on model. |
|||
:param batch: Batch of experiences. |
|||
:param num_sequences: Number of sequences to process. |
|||
:return: Results of update. |
|||
""" |
|||
# Get decayed parameters |
|||
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) |
|||
decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step()) |
|||
decay_bet = self.decay_beta.get_value(self.policy.get_current_step()) |
|||
returns = {} |
|||
old_values = {} |
|||
for name in self.reward_signals: |
|||
old_values[name] = ModelUtils.list_to_tensor( |
|||
batch[f"{name}_value_estimates"] |
|||
) |
|||
returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"]) |
|||
|
|||
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|||
if self.policy.use_continuous_act: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|||
else: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|||
|
|||
memories = [ |
|||
ModelUtils.list_to_tensor(batch["memory"][i]) |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
if len(memories) > 0: |
|||
memories = torch.stack(memories).unsqueeze(0) |
|||
|
|||
if self.policy.use_vis_obs: |
|||
vis_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
vis_obs.append(vis_ob) |
|||
else: |
|||
vis_obs = [] |
|||
log_probs, entropy, values = self.policy.evaluate_actions( |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=act_masks, |
|||
actions=actions, |
|||
memories=memories, |
|||
seq_len=self.policy.sequence_length, |
|||
) |
|||
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) |
|||
value_loss = self.ppo_value_loss( |
|||
values, old_values, returns, decay_eps, loss_masks |
|||
) |
|||
policy_loss = self.ppo_policy_loss( |
|||
ModelUtils.list_to_tensor(batch["advantages"]), |
|||
log_probs, |
|||
ModelUtils.list_to_tensor(batch["action_probs"]), |
|||
loss_masks, |
|||
) |
|||
loss = ( |
|||
policy_loss |
|||
+ 0.5 * value_loss |
|||
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks) |
|||
) |
|||
|
|||
# Set optimizer learning rate |
|||
ModelUtils.update_learning_rate(self.optimizer, decay_lr) |
|||
self.optimizer.zero_grad() |
|||
loss.backward() |
|||
|
|||
self.optimizer.step() |
|||
update_stats = { |
|||
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), |
|||
"Losses/Value Loss": value_loss.detach().cpu().numpy(), |
|||
"Policy/Learning Rate": decay_lr, |
|||
"Policy/Epsilon": decay_eps, |
|||
"Policy/Beta": decay_bet, |
|||
} |
|||
|
|||
for reward_provider in self.reward_signals.values(): |
|||
update_stats.update(reward_provider.update(batch)) |
|||
|
|||
return update_stats |
|||
|
|||
def get_modules(self): |
|||
return {"Optimizer": self.optimizer} |
|
|||
import numpy as np |
|||
from typing import Dict, List, Mapping, cast, Tuple, Optional |
|||
import torch |
|||
from torch import nn |
|||
import attr |
|||
|
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents.trainers.torch.networks import ValueNetwork |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.settings import TrainerSettings, SACSettings |
|||
|
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TorchSACOptimizer(TorchOptimizer): |
|||
class PolicyValueNetwork(nn.Module): |
|||
def __init__( |
|||
self, |
|||
stream_names: List[str], |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
act_type: ActionType, |
|||
act_size: List[int], |
|||
): |
|||
super().__init__() |
|||
if act_type == ActionType.CONTINUOUS: |
|||
num_value_outs = 1 |
|||
num_action_ins = sum(act_size) |
|||
else: |
|||
num_value_outs = sum(act_size) |
|||
num_action_ins = 0 |
|||
self.q1_network = ValueNetwork( |
|||
stream_names, |
|||
observation_shapes, |
|||
network_settings, |
|||
num_action_ins, |
|||
num_value_outs, |
|||
) |
|||
self.q2_network = ValueNetwork( |
|||
stream_names, |
|||
observation_shapes, |
|||
network_settings, |
|||
num_action_ins, |
|||
num_value_outs, |
|||
) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: |
|||
q1_out, _ = self.q1_network( |
|||
vec_inputs, |
|||
vis_inputs, |
|||
actions=actions, |
|||
memories=memories, |
|||
sequence_length=sequence_length, |
|||
) |
|||
q2_out, _ = self.q2_network( |
|||
vec_inputs, |
|||
vis_inputs, |
|||
actions=actions, |
|||
memories=memories, |
|||
sequence_length=sequence_length, |
|||
) |
|||
return q1_out, q2_out |
|||
|
|||
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): |
|||
super().__init__(policy, trainer_params) |
|||
hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) |
|||
self.tau = hyperparameters.tau |
|||
self.init_entcoef = hyperparameters.init_entcoef |
|||
|
|||
self.policy = policy |
|||
self.act_size = policy.act_size |
|||
policy_network_settings = policy.network_settings |
|||
|
|||
self.tau = hyperparameters.tau |
|||
self.burn_in_ratio = 0.0 |
|||
|
|||
# Non-exposed SAC parameters |
|||
self.discrete_target_entropy_scale = 0.2 # Roughly equal to e-greedy 0.05 |
|||
self.continuous_target_entropy_scale = 1.0 |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
# Used to reduce the "survivor bonus" when using Curiosity or GAIL.
|||
self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()] |
|||
self.use_dones_in_backup = { |
|||
name: int(not self.reward_signals[name].ignore_done) |
|||
for name in self.stream_names |
|||
} |
|||
|
|||
# Critics should have 1/2 of the memory of the policy |
|||
critic_memory = policy_network_settings.memory |
|||
if critic_memory is not None: |
|||
critic_memory = attr.evolve( |
|||
critic_memory, memory_size=critic_memory.memory_size // 2 |
|||
) |
|||
value_network_settings = attr.evolve( |
|||
policy_network_settings, memory=critic_memory |
|||
) |
|||
|
|||
self.value_network = TorchSACOptimizer.PolicyValueNetwork( |
|||
self.stream_names, |
|||
self.policy.behavior_spec.observation_shapes, |
|||
value_network_settings, |
|||
self.policy.behavior_spec.action_type, |
|||
self.act_size, |
|||
) |
|||
|
|||
self.target_network = ValueNetwork( |
|||
self.stream_names, |
|||
self.policy.behavior_spec.observation_shapes, |
|||
value_network_settings, |
|||
) |
|||
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) |
|||
|
|||
self._log_ent_coef = torch.nn.Parameter( |
|||
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))), |
|||
requires_grad=True, |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
self.target_entropy = torch.as_tensor( |
|||
-1 |
|||
* self.continuous_target_entropy_scale |
|||
* np.prod(self.act_size[0]).astype(np.float32) |
|||
) |
|||
else: |
|||
self.target_entropy = [ |
|||
self.discrete_target_entropy_scale * np.log(i).astype(np.float32) |
|||
for i in self.act_size |
|||
] |
|||
|
|||
policy_params = list(self.policy.actor_critic.network_body.parameters()) + list( |
|||
self.policy.actor_critic.distribution.parameters() |
|||
) |
|||
value_params = list(self.value_network.parameters()) + list( |
|||
self.policy.actor_critic.critic.parameters() |
|||
) |
|||
|
|||
logger.debug("value_vars") |
|||
for param in value_params: |
|||
logger.debug(param.shape) |
|||
logger.debug("policy_vars") |
|||
for param in policy_params: |
|||
logger.debug(param.shape) |
|||
|
|||
self.decay_learning_rate = ModelUtils.DecayedValue( |
|||
hyperparameters.learning_rate_schedule, |
|||
hyperparameters.learning_rate, |
|||
1e-10, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.policy_optimizer = torch.optim.Adam( |
|||
policy_params, lr=hyperparameters.learning_rate |
|||
) |
|||
self.value_optimizer = torch.optim.Adam( |
|||
value_params, lr=hyperparameters.learning_rate |
|||
) |
|||
self.entropy_optimizer = torch.optim.Adam( |
|||
[self._log_ent_coef], lr=hyperparameters.learning_rate |
|||
) |
|||
|
|||
def sac_q_loss( |
|||
self, |
|||
q1_out: Dict[str, torch.Tensor], |
|||
q2_out: Dict[str, torch.Tensor], |
|||
target_values: Dict[str, torch.Tensor], |
|||
dones: torch.Tensor, |
|||
rewards: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|||
q1_losses = [] |
|||
q2_losses = [] |
|||
# One Q1/Q2 loss per reward stream
|||
for i, name in enumerate(q1_out.keys()): |
|||
q1_stream = q1_out[name].squeeze() |
|||
q2_stream = q2_out[name].squeeze() |
|||
with torch.no_grad(): |
|||
q_backup = rewards[name] + ( |
|||
(1.0 - self.use_dones_in_backup[name] * dones) |
|||
* self.gammas[i] |
|||
* target_values[name] |
|||
) |
|||
_q1_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks |
|||
) |
|||
_q2_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks |
|||
) |
|||
|
|||
q1_losses.append(_q1_loss) |
|||
q2_losses.append(_q2_loss) |
|||
q1_loss = torch.mean(torch.stack(q1_losses)) |
|||
q2_loss = torch.mean(torch.stack(q2_losses)) |
|||
return q1_loss, q2_loss |
|||
|
|||
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None: |
|||
for source_param, target_param in zip(source.parameters(), target.parameters()): |
|||
target_param.data.copy_( |
|||
target_param.data * (1.0 - tau) + source_param.data * tau |
|||
) |
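# Editorial illustration (not part of this diff): soft_update above is a Polyak
# average, target <- (1 - tau) * target + tau * source. With tau = 1.0 the target
# becomes an exact copy, which is how __init__ seeds self.target_network from the
# policy's critic. Standalone sketch using the nn import at the top of this file:
def _soft_update_sketch(tau: float = 0.005) -> None:
    source, target = nn.Linear(2, 2), nn.Linear(2, 2)
    for s, t in zip(source.parameters(), target.parameters()):
        t.data.copy_(t.data * (1.0 - tau) + s.data * tau)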
|||
|
|||
def sac_value_loss( |
|||
self, |
|||
log_probs: torch.Tensor, |
|||
values: Dict[str, torch.Tensor], |
|||
q1p_out: Dict[str, torch.Tensor], |
|||
q2p_out: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
discrete: bool, |
|||
) -> torch.Tensor: |
|||
min_policy_qs = {} |
|||
with torch.no_grad(): |
|||
_ent_coef = torch.exp(self._log_ent_coef) |
|||
for name in values.keys(): |
|||
if not discrete: |
|||
min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name]) |
|||
else: |
|||
action_probs = log_probs.exp() |
|||
_branched_q1p = ModelUtils.break_into_branches( |
|||
q1p_out[name] * action_probs, self.act_size |
|||
) |
|||
_branched_q2p = ModelUtils.break_into_branches( |
|||
q2p_out[name] * action_probs, self.act_size |
|||
) |
|||
_q1p_mean = torch.mean( |
|||
torch.stack( |
|||
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p] |
|||
), |
|||
dim=0, |
|||
) |
|||
_q2p_mean = torch.mean( |
|||
torch.stack( |
|||
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p] |
|||
), |
|||
dim=0, |
|||
) |
|||
|
|||
min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean) |
|||
|
|||
value_losses = [] |
|||
if not discrete: |
|||
for name in values.keys(): |
|||
with torch.no_grad(): |
|||
v_backup = min_policy_qs[name] - torch.sum( |
|||
_ent_coef * log_probs, dim=1 |
|||
) |
|||
value_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks |
|||
) |
|||
value_losses.append(value_loss) |
|||
else: |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * log_probs.exp(), self.act_size |
|||
) |
|||
# We have to do entropy bonus per action branch |
|||
branched_ent_bonus = torch.stack( |
|||
[ |
|||
torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True) |
|||
for i, _lp in enumerate(branched_per_action_ent) |
|||
] |
|||
) |
|||
for name in values.keys(): |
|||
with torch.no_grad(): |
|||
v_backup = min_policy_qs[name] - torch.mean( |
|||
branched_ent_bonus, axis=0 |
|||
) |
|||
value_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(values[name], v_backup.squeeze()), |
|||
loss_masks, |
|||
) |
|||
value_losses.append(value_loss) |
|||
value_loss = torch.mean(torch.stack(value_losses)) |
|||
if torch.isinf(value_loss).any() or torch.isnan(value_loss).any(): |
|||
raise UnityTrainerException("Inf found") |
|||
return value_loss |
|||
|
|||
def sac_policy_loss( |
|||
self, |
|||
log_probs: torch.Tensor, |
|||
q1p_outs: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
discrete: bool, |
|||
) -> torch.Tensor: |
|||
_ent_coef = torch.exp(self._log_ent_coef) |
|||
mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0) |
|||
if not discrete: |
|||
mean_q1 = mean_q1.unsqueeze(1) |
|||
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) |
|||
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) |
|||
else: |
|||
action_probs = log_probs.exp() |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * action_probs, self.act_size |
|||
) |
|||
branched_q_term = ModelUtils.break_into_branches( |
|||
mean_q1 * action_probs, self.act_size |
|||
) |
|||
branched_policy_loss = torch.stack( |
|||
[ |
|||
torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True) |
|||
for i, (_lp, _qt) in enumerate( |
|||
zip(branched_per_action_ent, branched_q_term) |
|||
) |
|||
] |
|||
) |
|||
batch_policy_loss = torch.squeeze(branched_policy_loss) |
|||
policy_loss = torch.mean(loss_masks * batch_policy_loss) |
|||
return policy_loss |
|||
|
|||
def sac_entropy_loss( |
|||
self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool |
|||
) -> torch.Tensor: |
|||
if not discrete: |
|||
with torch.no_grad(): |
|||
target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1) |
|||
entropy_loss = -torch.mean( |
|||
self._log_ent_coef * loss_masks * target_current_diff |
|||
) |
|||
else: |
|||
with torch.no_grad(): |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * log_probs.exp(), self.act_size |
|||
) |
|||
target_current_diff_branched = torch.stack( |
|||
[ |
|||
torch.sum(_lp, axis=1, keepdim=True) + _te |
|||
for _lp, _te in zip( |
|||
branched_per_action_ent, self.target_entropy |
|||
) |
|||
], |
|||
axis=1, |
|||
) |
|||
target_current_diff = torch.squeeze( |
|||
target_current_diff_branched, axis=2 |
|||
) |
|||
entropy_loss = -1 * ModelUtils.masked_mean( |
|||
torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks |
|||
) |
|||
|
|||
return entropy_loss |
|||
|
|||
def _condense_q_streams( |
|||
self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor |
|||
) -> Dict[str, torch.Tensor]: |
|||
condensed_q_output = {} |
|||
onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size) |
|||
for key, item in q_output.items(): |
|||
branched_q = ModelUtils.break_into_branches(item, self.act_size) |
|||
only_action_qs = torch.stack( |
|||
[ |
|||
torch.sum(_act * _q, dim=1, keepdim=True) |
|||
for _act, _q in zip(onehot_actions, branched_q) |
|||
] |
|||
) |
|||
|
|||
condensed_q_output[key] = torch.mean(only_action_qs, dim=0) |
|||
return condensed_q_output |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Updates model using buffer. |
|||
:param num_sequences: Number of trajectories in batch. |
|||
:param batch: Experience mini-batch. |
|||
:param update_target: Whether or not to update target value network |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
:return: Output from update process. |
|||
""" |
|||
rewards = {} |
|||
for name in self.reward_signals: |
|||
rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"]) |
|||
|
|||
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])] |
|||
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|||
if self.policy.use_continuous_act: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|||
else: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|||
|
|||
memories_list = [ |
|||
ModelUtils.list_to_tensor(batch["memory"][i]) |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. |
|||
offset = 1 if self.policy.sequence_length > 1 else 0 |
|||
next_memories_list = [ |
|||
ModelUtils.list_to_tensor( |
|||
batch["memory"][i][self.policy.m_size // 2 :] |
|||
) # only pass value part of memory to target network |
|||
for i in range(offset, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
|
|||
if len(memories_list) > 0: |
|||
memories = torch.stack(memories_list).unsqueeze(0) |
|||
next_memories = torch.stack(next_memories_list).unsqueeze(0) |
|||
else: |
|||
memories = None |
|||
next_memories = None |
|||
# Q network memories are 0'ed out, since we don't have them during inference. |
|||
q_memories = ( |
|||
torch.zeros_like(next_memories) if next_memories is not None else None |
|||
) |
|||
|
|||
vis_obs: List[torch.Tensor] = [] |
|||
next_vis_obs: List[torch.Tensor] = [] |
|||
if self.policy.use_vis_obs: |
|||
vis_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
vis_obs.append(vis_ob) |
|||
next_vis_ob = ModelUtils.list_to_tensor( |
|||
batch["next_visual_obs%d" % idx] |
|||
) |
|||
next_vis_obs.append(next_vis_ob) |
|||
|
|||
# Copy normalizers from policy |
|||
self.value_network.q1_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
self.value_network.q2_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
self.target_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
( |
|||
sampled_actions, |
|||
log_probs, |
|||
entropies, |
|||
sampled_values, |
|||
_, |
|||
) = self.policy.sample_actions( |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=act_masks, |
|||
memories=memories, |
|||
seq_len=self.policy.sequence_length, |
|||
all_log_probs=not self.policy.use_continuous_act, |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
squeezed_actions = actions.squeeze(-1) |
|||
q1p_out, q2p_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
sampled_actions, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_out, q2_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
squeezed_actions, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_stream, q2_stream = q1_out, q2_out |
|||
else: |
|||
with torch.no_grad(): |
|||
q1p_out, q2p_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_out, q2_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_stream = self._condense_q_streams(q1_out, actions) |
|||
q2_stream = self._condense_q_streams(q2_out, actions) |
|||
|
|||
with torch.no_grad(): |
|||
target_values, _ = self.target_network( |
|||
next_vec_obs, |
|||
next_vis_obs, |
|||
memories=next_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) |
|||
use_discrete = not self.policy.use_continuous_act |
|||
dones = ModelUtils.list_to_tensor(batch["done"]) |
|||
|
|||
q1_loss, q2_loss = self.sac_q_loss( |
|||
q1_stream, q2_stream, target_values, dones, rewards, masks |
|||
) |
|||
value_loss = self.sac_value_loss( |
|||
log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete |
|||
) |
|||
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete) |
|||
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete) |
|||
|
|||
total_value_loss = q1_loss + q2_loss + value_loss |
|||
|
|||
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) |
|||
ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) |
|||
self.policy_optimizer.zero_grad() |
|||
policy_loss.backward() |
|||
self.policy_optimizer.step() |
|||
|
|||
ModelUtils.update_learning_rate(self.value_optimizer, decay_lr) |
|||
self.value_optimizer.zero_grad() |
|||
total_value_loss.backward() |
|||
self.value_optimizer.step() |
|||
|
|||
ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr) |
|||
self.entropy_optimizer.zero_grad() |
|||
entropy_loss.backward() |
|||
self.entropy_optimizer.step() |
|||
|
|||
# Update target network |
|||
self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau) |
|||
update_stats = { |
|||
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), |
|||
"Losses/Value Loss": value_loss.detach().cpu().numpy(), |
|||
"Losses/Q1 Loss": q1_loss.detach().cpu().numpy(), |
|||
"Losses/Q2 Loss": q2_loss.detach().cpu().numpy(), |
|||
"Policy/Entropy Coeff": torch.exp(self._log_ent_coef) |
|||
.detach() |
|||
.cpu() |
|||
.numpy(), |
|||
"Policy/Learning Rate": decay_lr, |
|||
} |
|||
|
|||
for signal in self.reward_signals.values(): |
|||
signal.update(batch) |
|||
|
|||
return update_stats |
|||
|
|||
def update_reward_signals( |
|||
self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
return {} |
|||
|
|||
def get_modules(self): |
|||
return { |
|||
"Optimizer:value_network": self.value_network, |
|||
"Optimizer:target_network": self.target_network, |
|||
"Optimizer:policy_optimizer": self.policy_optimizer, |
|||
"Optimizer:value_optimizer": self.value_optimizer, |
|||
"Optimizer:entropy_optimizer": self.entropy_optimizer, |
|||
} |
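A note on _condense_q_streams above: for each discrete branch it keeps only the Q-value of the action that was actually taken, then averages across branches. The same selection written with plain torch ops (standard torch calls, not the ModelUtils helpers used in the file):

import torch
import torch.nn.functional as F

act_size = [3, 2]                            # two discrete branches
q = torch.randn(4, sum(act_size))            # concatenated per-branch Q-values, batch of 4
actions = torch.stack([torch.randint(0, n, (4,)) for n in act_size], dim=1)

branch_qs = torch.split(q, act_size, dim=1)
selected = [
    torch.sum(F.one_hot(actions[:, i], n).float() * bq, dim=1, keepdim=True)
    for i, (n, bq) in enumerate(zip(act_size, branch_qs))
]
condensed = torch.mean(torch.stack(selected), dim=0)   # shape [4, 1]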
|
|||
import os |
|||
import shutil |
|||
import torch |
|||
from typing import Dict, Union, Optional, cast |
|||
from mlagents_envs.exception import UnityPolicyException |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.trainers.saver.saver import BaseSaver |
|||
from mlagents.trainers.settings import TrainerSettings, SerializationSettings |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.torch.model_serialization import ModelSerializer |
|||
|
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TorchSaver(BaseSaver): |
|||
""" |
|||
Saver class for PyTorch |
|||
""" |
|||
|
|||
def __init__( |
|||
self, trainer_settings: TrainerSettings, model_path: str, load: bool = False |
|||
): |
|||
super().__init__() |
|||
self.model_path = model_path |
|||
self.initialize_path = trainer_settings.init_path |
|||
self._keep_checkpoints = trainer_settings.keep_checkpoints |
|||
self.load = load |
|||
|
|||
self.policy: Optional[TorchPolicy] = None |
|||
self.exporter: Optional[ModelSerializer] = None |
|||
self.modules: Dict[str, torch.nn.Module] = {}
|||
|
|||
def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None: |
|||
if isinstance(module, TorchPolicy) or isinstance(module, TorchOptimizer): |
|||
self.modules.update(module.get_modules()) # type: ignore |
|||
else: |
|||
raise UnityPolicyException( |
|||
"Registering Object of unsupported type {} to Saver ".format( |
|||
type(module) |
|||
) |
|||
) |
|||
if self.policy is None and isinstance(module, TorchPolicy): |
|||
self.policy = module |
|||
self.exporter = ModelSerializer(self.policy) |
|||
|
|||
def save_checkpoint(self, brain_name: str, step: int) -> str: |
|||
if not os.path.exists(self.model_path): |
|||
os.makedirs(self.model_path) |
|||
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}") |
|||
state_dict = { |
|||
name: module.state_dict() for name, module in self.modules.items() |
|||
} |
|||
torch.save(state_dict, f"{checkpoint_path}.pt") |
|||
torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt")) |
|||
self.export(checkpoint_path, brain_name) |
|||
return checkpoint_path |
|||
|
|||
def export(self, output_filepath: str, brain_name: str) -> None: |
|||
if self.exporter is not None: |
|||
self.exporter.export_policy_model(output_filepath) |
|||
|
|||
def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None: |
|||
# Initialize/Load registered self.policy by default. |
|||
# If given input argument policy, use the input policy instead. |
|||
# This argument is mainly for initialization of the ghost trainer's fixed policy. |
|||
reset_steps = not self.load |
|||
if self.initialize_path is not None: |
|||
self._load_model( |
|||
self.initialize_path, policy, reset_global_steps=reset_steps |
|||
) |
|||
elif self.load: |
|||
self._load_model(self.model_path, policy, reset_global_steps=reset_steps) |
|||
|
|||
def _load_model( |
|||
self, |
|||
load_path: str, |
|||
policy: Optional[TorchPolicy] = None, |
|||
reset_global_steps: bool = False, |
|||
) -> None: |
|||
model_path = os.path.join(load_path, "checkpoint.pt") |
|||
saved_state_dict = torch.load(model_path) |
|||
if policy is None: |
|||
modules = self.modules |
|||
policy = self.policy |
|||
else: |
|||
modules = policy.get_modules() |
|||
policy = cast(TorchPolicy, policy) |
|||
|
|||
for name, mod in modules.items(): |
|||
mod.load_state_dict(saved_state_dict[name]) |
|||
|
|||
if reset_global_steps: |
|||
policy.set_step(0) |
|||
logger.info( |
|||
"Starting training from step 0 and saving to {}.".format( |
|||
self.model_path |
|||
) |
|||
) |
|||
else: |
|||
logger.info(f"Resuming training from step {policy.get_current_step()}.") |
|||
|
|||
def copy_final_model(self, source_nn_path: str) -> None: |
|||
""" |
|||
Copy the .nn file at the given source to the destination. |
|||
Also copies the corresponding .onnx file if it exists. |
|||
""" |
|||
final_model_name = os.path.splitext(source_nn_path)[0] |
|||
|
|||
if SerializationSettings.convert_to_onnx: |
|||
try: |
|||
source_path = f"{final_model_name}.onnx" |
|||
destination_path = f"{self.model_path}.onnx" |
|||
shutil.copyfile(source_path, destination_path) |
|||
logger.info(f"Copied {source_path} to {destination_path}.") |
|||
except OSError: |
|||
pass |
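TorchSaver above keeps one state_dict per registered module, keyed by the names returned from get_modules ("Policy", "Optimizer:value_network", ...), in a single .pt file. A minimal round trip with the same pattern, using stand-in modules and an illustrative path:

import torch

modules = {"Policy": torch.nn.Linear(4, 2), "Optimizer:value_network": torch.nn.Linear(2, 1)}

# Save: one state_dict per registered module, keyed by name.
torch.save({name: m.state_dict() for name, m in modules.items()}, "checkpoint.pt")

# Load: restore each module from its entry, as _load_model does.
saved_state_dict = torch.load("checkpoint.pt")
for name, m in modules.items():
    m.load_state_dict(saved_state_dict[name])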
ml-agents/mlagents/trainers/tests/torch/test.demo (1001 changes): binary demo file, diff too large to display.
(A second file diff was also too large to display.)
|
|||
from unittest.mock import MagicMock |
|||
import pytest |
|||
import mlagents.trainers.tests.mock_brain as mb |
|||
|
|||
import numpy as np |
|||
import os |
|||
|
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.torch.components.bc.module import BCModule |
|||
from mlagents.trainers.settings import ( |
|||
TrainerSettings, |
|||
BehavioralCloningSettings, |
|||
NetworkSettings, |
|||
) |
|||
|
|||
|
|||
def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample): |
|||
# model_path = env.external_brain_names[0] |
|||
trainer_config = TrainerSettings() |
|||
trainer_config.network_settings.memory = ( |
|||
NetworkSettings.MemorySettings() if use_rnn else None |
|||
) |
|||
policy = TorchPolicy( |
|||
0, mock_behavior_specs, trainer_config, tanhresample, tanhresample |
|||
) |
|||
bc_module = BCModule( |
|||
policy, |
|||
settings=bc_settings, |
|||
policy_learning_rate=trainer_config.hyperparameters.learning_rate, |
|||
default_batch_size=trainer_config.hyperparameters.batch_size, |
|||
default_num_epoch=3, |
|||
) |
|||
return bc_module |
|||
|
|||
|
|||
# Test default values |
|||
def test_bcmodule_defaults(): |
|||
# See if default values match |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, False) |
|||
assert bc_module.num_epoch == 3 |
|||
assert bc_module.batch_size == TrainerSettings().hyperparameters.batch_size |
|||
# Assign strange values and see if it overrides properly |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo", |
|||
num_epoch=100, |
|||
batch_size=10000, |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, False) |
|||
assert bc_module.num_epoch == 100 |
|||
assert bc_module.batch_size == 10000 |
|||
|
|||
|
|||
# Test with continuous control env and vector actions |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
# Test with constant pretraining learning rate |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_constant_lr_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo", |
|||
steps=0, |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
old_learning_rate = bc_module.current_lr |
|||
|
|||
_ = bc_module.update() |
|||
assert old_learning_rate == bc_module.current_lr |
|||
|
|||
|
|||
# Test with linearly decaying pretraining learning rate
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_linear_lr_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo", |
|||
steps=100, |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
# Should decay by 10/100 * 0.0003 = 0.00003 |
|||
bc_module.policy.get_current_step = MagicMock(return_value=10) |
|||
old_learning_rate = bc_module.current_lr |
|||
_ = bc_module.update() |
|||
assert old_learning_rate - 0.00003 == pytest.approx(bc_module.current_lr, abs=0.01) |
|||
|
|||
|
|||
# Test with RNN |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_rnn_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
# Test with discrete control and visual observations |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_dc_visual_update(is_sac): |
|||
mock_specs = mb.create_mock_banana_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
# Test with discrete control, visual observations and RNN |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_rnn_dc_update(is_sac): |
|||
mock_specs = mb.create_mock_banana_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
pytest.main() |
|
|||
import pytest

import numpy as np

from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
from mlagents.trainers.settings import TrainerSettings, SelfPlaySettings, FrameworkType


@pytest.fixture
def dummy_config():
    return TrainerSettings(
        self_play=SelfPlaySettings(), framework=FrameworkType.PYTORCH
    )


VECTOR_ACTION_SPACE = 1
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 513
NUM_AGENTS = 12


@pytest.mark.parametrize("use_discrete", [True, False])
def test_load_and_set(dummy_config, use_discrete):
    mock_specs = mb.setup_test_behavior_specs(
        use_discrete,
        False,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    trainer_params = dummy_config
    trainer = PPOTrainer("test", 0, trainer_params, True, False, 0, "0")
    trainer.seed = 1
    policy = trainer.create_policy("test", mock_specs)
    trainer.seed = 20  # otherwise graphs are the same
    to_load_policy = trainer.create_policy("test", mock_specs)

    weights = policy.get_weights()
    load_weights = to_load_policy.get_weights()
    try:
        for w, lw in zip(weights, load_weights):
            np.testing.assert_array_equal(w, lw)
    except AssertionError:
        pass

    to_load_policy.load_weights(weights)
    load_weights = to_load_policy.get_weights()

    for w, lw in zip(weights, load_weights):
        np.testing.assert_array_equal(w, lw)


def test_process_trajectory(dummy_config):
    mock_specs = mb.setup_test_behavior_specs(
        True, False, vector_action_space=[2], vector_obs_space=1
    )
    behavior_id_team0 = "test_brain?team=0"
    behavior_id_team1 = "test_brain?team=1"
    brain_name = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0).brain_name

    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
    )

    # first policy encountered becomes policy trained by wrapped PPO
    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)
    policy = trainer.create_policy(parsed_behavior_id0, mock_specs)
    trainer.add_policy(parsed_behavior_id0, policy)
    trajectory_queue0 = AgentManagerQueue(behavior_id_team0)
    trainer.subscribe_trajectory_queue(trajectory_queue0)

    # Ghost trainer should ignore this queue because it is off-policy
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
    policy = trainer.create_policy(parsed_behavior_id1, mock_specs)
    trainer.add_policy(parsed_behavior_id1, policy)
    trajectory_queue1 = AgentManagerQueue(behavior_id_team1)
    trainer.subscribe_trajectory_queue(trajectory_queue1)

    time_horizon = 15
    trajectory = make_fake_trajectory(
        length=time_horizon,
        max_step_complete=True,
        observation_shapes=[(1,)],
        action_space=[2],
    )
    trajectory_queue0.put(trajectory)
    trainer.advance()

    # Check that trainer put trajectory in update buffer
    assert trainer.trainer.update_buffer.num_experiences == 15

    trajectory_queue1.put(trajectory)
    trainer.advance()

    # Check that ghost trainer ignored off policy queue
    assert trainer.trainer.update_buffer.num_experiences == 15
    # Check that it emptied the queue
    assert trajectory_queue1.empty()


def test_publish_queue(dummy_config):
    mock_specs = mb.setup_test_behavior_specs(
        True, False, vector_action_space=[1], vector_obs_space=8
    )

    behavior_id_team0 = "test_brain?team=0"
    behavior_id_team1 = "test_brain?team=1"

    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)

    brain_name = parsed_behavior_id0.brain_name

    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
    )

    # First policy encountered becomes policy trained by wrapped PPO
    # This queue should remain empty after swap snapshot
    policy = trainer.create_policy(parsed_behavior_id0, mock_specs)
    trainer.add_policy(parsed_behavior_id0, policy)
    policy_queue0 = AgentManagerQueue(behavior_id_team0)
    trainer.publish_policy_queue(policy_queue0)

    # Ghost trainer should use this queue for ghost policy swap
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
    policy = trainer.create_policy(parsed_behavior_id1, mock_specs)
    trainer.add_policy(parsed_behavior_id1, policy)
    policy_queue1 = AgentManagerQueue(behavior_id_team1)
    trainer.publish_policy_queue(policy_queue1)

    # check ghost trainer swap pushes to ghost queue and not trainer
    assert policy_queue0.empty() and policy_queue1.empty()
    trainer._swap_snapshots()
    assert policy_queue0.empty() and not policy_queue1.empty()
    # clear
    policy_queue1.get_nowait()

    mock_specs = mb.setup_test_behavior_specs(
        False,
        False,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
    # Mock out reward signal eval
    buffer["extrinsic_rewards"] = buffer["environment_rewards"]
    buffer["extrinsic_returns"] = buffer["environment_rewards"]
    buffer["extrinsic_value_estimates"] = buffer["environment_rewards"]
    buffer["curiosity_rewards"] = buffer["environment_rewards"]
    buffer["curiosity_returns"] = buffer["environment_rewards"]
    buffer["curiosity_value_estimates"] = buffer["environment_rewards"]
    buffer["advantages"] = buffer["environment_rewards"]
    trainer.trainer.update_buffer = buffer

    # When the ghost trainer advances and the wrapped trainer's buffer is full,
    # the wrapped trainer pushes the updated policy to the correct queue
    assert policy_queue0.empty() and policy_queue1.empty()
    trainer.advance()
    assert not policy_queue0.empty() and policy_queue1.empty()


if __name__ == "__main__":
    pytest.main()

import pytest

import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils

VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32
NUM_AGENTS = 12
EPSILON = 1e-7


def create_policy_mock(
    dummy_config: TrainerSettings,
    use_rnn: bool = False,
    use_discrete: bool = True,
    use_visual: bool = False,
    seed: int = 0,
) -> TorchPolicy:
    mock_spec = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    trainer_settings = dummy_config
    trainer_settings.keep_checkpoints = 3
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings() if use_rnn else None
    )
    policy = TorchPolicy(seed, mock_spec, trainer_settings)
    return policy


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_policy_evaluate(rnn, visual, discrete):
    # Test evaluate
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    decision_step, terminal_step = mb.create_steps_from_behavior_spec(
        policy.behavior_spec, num_agents=NUM_AGENTS
    )

    run_out = policy.evaluate(decision_step, list(decision_step.agent_id))
    if discrete:
        assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
    else:
        assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_evaluate_actions(rnn, visual, discrete):
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    if policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(buffer["actions"]).unsqueeze(-1)
    else:
        actions = ModelUtils.list_to_tensor(buffer["actions"], dtype=torch.long)
    vis_obs = []
    for idx, _ in enumerate(policy.actor_critic.network_body.visual_encoders):
        vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        vis_obs.append(vis_ob)

    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][i])
        for i in range(0, len(buffer["memory"]), policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=policy.sequence_length,
    )
    assert log_probs.shape == (64, policy.behavior_spec.action_size)
    assert entropy.shape == (64, policy.behavior_spec.action_size)
    for val in values.values():
        assert val.shape == (64,)


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_sample_actions(rnn, visual, discrete):
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])

    vis_obs = []
    for idx, _ in enumerate(policy.actor_critic.network_body.visual_encoders):
        vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        vis_obs.append(vis_ob)

    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][i])
        for i in range(0, len(buffer["memory"]), policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    (
        sampled_actions,
        log_probs,
        entropies,
        sampled_values,
        memories,
    ) = policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=policy.sequence_length,
        all_log_probs=not policy.use_continuous_act,
    )
    if discrete:
        assert log_probs.shape == (
            64,
            sum(policy.behavior_spec.discrete_action_branches),
        )
    else:
        assert log_probs.shape == (64, policy.behavior_spec.action_shape)
    assert entropies.shape == (64, policy.behavior_spec.action_size)
    for val in sampled_values.values():
        assert val.shape == (64,)

    if rnn:
        assert memories.shape == (1, 1, policy.m_size)

import math
import tempfile
import pytest
import numpy as np
import attr
from typing import Dict

from mlagents.trainers.tests.simple_test_envs import (
    SimpleEnvironment,
    MemoryEnvironment,
    RecordEnvironment,
)
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.simple_env_manager import SimpleEnvManager
from mlagents.trainers.demo_loader import write_demo
from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary
from mlagents.trainers.settings import (
    TrainerSettings,
    PPOSettings,
    SACSettings,
    NetworkSettings,
    SelfPlaySettings,
    BehavioralCloningSettings,
    GAILSettings,
    TrainerType,
    RewardSignalType,
    EncoderType,
    ScheduleType,
    FrameworkType,
)
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents_envs.side_channel.environment_parameters_channel import (
    EnvironmentParametersChannel,
)
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
    DemonstrationMetaProto,
)
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.space_type_pb2 import discrete, continuous

BRAIN_NAME = "1D"


PPO_CONFIG = TrainerSettings(
    trainer_type=TrainerType.PPO,
    hyperparameters=PPOSettings(
        learning_rate=5.0e-3,
        learning_rate_schedule=ScheduleType.CONSTANT,
        batch_size=16,
        buffer_size=64,
    ),
    network_settings=NetworkSettings(num_layers=1, hidden_units=32),
    summary_freq=500,
    max_steps=3000,
    threaded=False,
    framework=FrameworkType.PYTORCH,
)

SAC_CONFIG = TrainerSettings(
    trainer_type=TrainerType.SAC,
    hyperparameters=SACSettings(
        learning_rate=5.0e-3,
        learning_rate_schedule=ScheduleType.CONSTANT,
        batch_size=8,
        buffer_init_steps=100,
        buffer_size=5000,
        tau=0.01,
        init_entcoef=0.01,
    ),
    network_settings=NetworkSettings(num_layers=1, hidden_units=16),
    summary_freq=100,
    max_steps=1000,
    threaded=False,
)


# The reward processor is passed as an argument to _check_environment_trains.
# It is applied to the list of all final rewards for each brain individually.
# This is so that we can process all final rewards in different ways for different algorithms.
# Custom reward processors should be built within the test function and passed to _check_environment_trains
# Default is average over the last 5 final rewards
def default_reward_processor(rewards, last_n_rewards=5):
    rewards_to_use = rewards[-last_n_rewards:]
    # For debugging tests
    print(f"Last {last_n_rewards} rewards:", rewards_to_use)
    return np.array(rewards[-last_n_rewards:], dtype=np.float32).mean()
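

# Illustrative sketch (not used by the tests in this file): a custom reward processor
# can summarize the final rewards differently and be passed to _check_environment_trains
# via its reward_processor argument. The median-based processor below is a hypothetical
# example that assumes the same rewards-list input as default_reward_processor.
def median_reward_processor(rewards, last_n_rewards=5):
    # Median of the last N final rewards; less sensitive to a single outlier episode.
    return float(np.median(np.array(rewards[-last_n_rewards:], dtype=np.float32)))


# Within a test it would be used as, e.g.:
#     _check_environment_trains(env, {BRAIN_NAME: config}, reward_processor=median_reward_processor)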


class DebugWriter(StatsWriter):
    """
    Print to stdout so stats can be viewed in pytest
    """

    def __init__(self):
        self._last_reward_summary: Dict[str, float] = {}

    def get_last_rewards(self):
        return self._last_reward_summary

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        for val, stats_summary in values.items():
            if val == "Environment/Cumulative Reward":
                print(step, val, stats_summary.mean)
                self._last_reward_summary[category] = stats_summary.mean
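
# Illustrative usage sketch (not exercised directly here): the writer is registered with
# the StatsReporter, as _check_environment_trains does below, and the captured reward
# means can be read back after training, e.g.:
#     debug_writer = DebugWriter()
#     StatsReporter.add_writer(debug_writer)
#     ...  # run training via TrainerController.start_learning
#     last_rewards = debug_writer.get_last_rewards()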


def _check_environment_trains(
    env,
    trainer_config,
    reward_processor=default_reward_processor,
    env_parameter_manager=None,
    success_threshold=0.9,
    env_manager=None,
):
    if env_parameter_manager is None:
        env_parameter_manager = EnvironmentParameterManager()
    # Create controller and begin training.
    with tempfile.TemporaryDirectory() as dir:
        run_id = "id"
        seed = 1337
        StatsReporter.writers.clear()  # Clear StatsReporters so we don't write to file
        debug_writer = DebugWriter()
        StatsReporter.add_writer(debug_writer)
        if env_manager is None:
            env_manager = SimpleEnvManager(env, EnvironmentParametersChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            output_path=dir,
            train_model=True,
            load_model=False,
            seed=seed,
            param_manager=env_parameter_manager,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            output_path=dir,
            run_id=run_id,
            param_manager=env_parameter_manager,
            train=True,
            training_seed=seed,
        )

        # Begin training
        tc.start_learning(env_manager)
        if (
            success_threshold is not None
        ):  # For tests where we are just checking setup and not reward
            processed_rewards = [
                reward_processor(rewards) for rewards in env.final_rewards.values()
            ]
            assert all(not math.isnan(reward) for reward in processed_rewards)
            assert all(reward > success_threshold for reward in processed_rewards)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_ppo(use_discrete):
    env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete)
    config = attr.evolve(PPO_CONFIG)
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("use_discrete", [True, False])
def test_2d_ppo(use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME], use_discrete=use_discrete, action_size=2, step_size=0.8
    )
    new_hyperparams = attr.evolve(
        PPO_CONFIG.hyperparameters, batch_size=64, buffer_size=640
    )
    config = attr.evolve(PPO_CONFIG, hyperparameters=new_hyperparams, max_steps=10000)
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("use_discrete", [True, False])
@pytest.mark.parametrize("num_visual", [1, 2])
def test_visual_ppo(num_visual, use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME],
        use_discrete=use_discrete,
        num_visual=num_visual,
        num_vector=0,
        step_size=0.2,
    )
    new_hyperparams = attr.evolve(PPO_CONFIG.hyperparameters, learning_rate=3.0e-4)
    config = attr.evolve(PPO_CONFIG, hyperparameters=new_hyperparams)
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("num_visual", [1, 2])
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
def test_visual_advanced_ppo(vis_encode_type, num_visual):
    env = SimpleEnvironment(
        [BRAIN_NAME],
        use_discrete=True,
        num_visual=num_visual,
        num_vector=0,
        step_size=0.5,
        vis_obs_size=(36, 36, 3),
    )
    new_networksettings = attr.evolve(
        SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
    )
    new_hyperparams = attr.evolve(PPO_CONFIG.hyperparameters, learning_rate=3.0e-4)
    config = attr.evolve(
        PPO_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_networksettings,
        max_steps=500,
        summary_freq=100,
    )
    # The number of steps is pretty small for these encoders
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.5)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_ppo(use_discrete):
    env = MemoryEnvironment([BRAIN_NAME], use_discrete=use_discrete)
    new_network_settings = attr.evolve(
        PPO_CONFIG.network_settings,
        memory=NetworkSettings.MemorySettings(memory_size=16),
    )
    new_hyperparams = attr.evolve(
        PPO_CONFIG.hyperparameters, learning_rate=1.0e-3, batch_size=64, buffer_size=128
    )
    config = attr.evolve(
        PPO_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_network_settings,
        max_steps=5000,
    )
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_sac(use_discrete):
    env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete)
    config = attr.evolve(SAC_CONFIG)
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("use_discrete", [True, False])
def test_2d_sac(use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME], use_discrete=use_discrete, action_size=2, step_size=0.8
    )
    new_hyperparams = attr.evolve(SAC_CONFIG.hyperparameters, buffer_init_steps=2000)
    config = attr.evolve(SAC_CONFIG, hyperparameters=new_hyperparams, max_steps=10000)
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.8)


@pytest.mark.parametrize("use_discrete", [True, False])
@pytest.mark.parametrize("num_visual", [1, 2])
def test_visual_sac(num_visual, use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME],
        use_discrete=use_discrete,
        num_visual=num_visual,
        num_vector=0,
        step_size=0.2,
    )
    new_hyperparams = attr.evolve(
        SAC_CONFIG.hyperparameters, batch_size=16, learning_rate=3e-4
    )
    config = attr.evolve(SAC_CONFIG, hyperparameters=new_hyperparams)
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("num_visual", [1, 2])
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
def test_visual_advanced_sac(vis_encode_type, num_visual):
    env = SimpleEnvironment(
        [BRAIN_NAME],
        use_discrete=True,
        num_visual=num_visual,
        num_vector=0,
        step_size=0.5,
        vis_obs_size=(36, 36, 3),
    )
    new_networksettings = attr.evolve(
        SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
    )
    new_hyperparams = attr.evolve(
        SAC_CONFIG.hyperparameters,
        batch_size=16,
        learning_rate=3e-4,
        buffer_init_steps=0,
    )
    config = attr.evolve(
        SAC_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_networksettings,
        max_steps=100,
    )
    # The number of steps is pretty small for these encoders
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.5)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_sac(use_discrete):
    step_size = 0.5 if use_discrete else 0.2
    env = MemoryEnvironment(
        [BRAIN_NAME], use_discrete=use_discrete, step_size=step_size
    )
    new_networksettings = attr.evolve(
        SAC_CONFIG.network_settings,
        memory=NetworkSettings.MemorySettings(memory_size=16, sequence_length=16),
    )
    new_hyperparams = attr.evolve(
        SAC_CONFIG.hyperparameters,
        batch_size=128,
        learning_rate=1e-3,
        buffer_init_steps=1000,
        steps_per_update=2,
    )
    config = attr.evolve(
        SAC_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_networksettings,
        max_steps=5000,
    )
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_ghost(use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", BRAIN_NAME + "?team=1"], use_discrete=use_discrete
    )
    self_play_settings = SelfPlaySettings(
        play_against_latest_model_ratio=1.0, save_steps=2000, swap_steps=2000
    )
    config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=2500)
    _check_environment_trains(env, {BRAIN_NAME: config})


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_ghost_fails(use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", BRAIN_NAME + "?team=1"], use_discrete=use_discrete
    )
    # This config should fail because the ghosted policy is never swapped with a competent policy.
    # Swap occurs after max step is reached.
    self_play_settings = SelfPlaySettings(
        play_against_latest_model_ratio=1.0, save_steps=2000, swap_steps=4000
    )
    config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=2500)
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=None)
    processed_rewards = [
        default_reward_processor(rewards) for rewards in env.final_rewards.values()
    ]
    success_threshold = 0.9
    assert any(reward > success_threshold for reward in processed_rewards) and any(
        reward < success_threshold for reward in processed_rewards
    )


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost(use_discrete):
    # Make opponent for asymmetric case
    brain_name_opp = BRAIN_NAME + "Opp"
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"], use_discrete=use_discrete
    )
    self_play_settings = SelfPlaySettings(
        play_against_latest_model_ratio=1.0,
        save_steps=10000,
        swap_steps=10000,
        team_change=400,
    )
    config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=4000)
    _check_environment_trains(env, {BRAIN_NAME: config, brain_name_opp: config})


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost_fails(use_discrete):
    # Make opponent for asymmetric case
    brain_name_opp = BRAIN_NAME + "Opp"
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"], use_discrete=use_discrete
    )
    # This config should fail because the team that is not learning when both have reached
    # max step should be executing the initial, untrained policy.
    self_play_settings = SelfPlaySettings(
        play_against_latest_model_ratio=0.0,
        save_steps=5000,
        swap_steps=5000,
        team_change=2000,
    )
    config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=3000)
    _check_environment_trains(
        env, {BRAIN_NAME: config, brain_name_opp: config}, success_threshold=None
    )
    processed_rewards = [
        default_reward_processor(rewards) for rewards in env.final_rewards.values()
    ]
    success_threshold = 0.9
    assert any(reward > success_threshold for reward in processed_rewards) and any(
        reward < success_threshold for reward in processed_rewards
    )


@pytest.fixture(scope="session")
def simple_record(tmpdir_factory):
    def record_demo(use_discrete, num_visual=0, num_vector=1):
        env = RecordEnvironment(
            [BRAIN_NAME],
            use_discrete=use_discrete,
            num_visual=num_visual,
            num_vector=num_vector,
            n_demos=100,
        )
        # If we want to use true demos, we can solve the env in the usual way
        # Otherwise, we can just call solve to execute the optimal policy
        env.solve()
        agent_info_protos = env.demonstration_protos[BRAIN_NAME]
        meta_data_proto = DemonstrationMetaProto()
        brain_param_proto = BrainParametersProto(
            vector_action_size=[2] if use_discrete else [1],
            vector_action_descriptions=[""],
            vector_action_space_type=discrete if use_discrete else continuous,
            brain_name=BRAIN_NAME,
            is_training=True,
        )
        action_type = "Discrete" if use_discrete else "Continuous"
        demo_path_name = "1DTest" + action_type + ".demo"
        demo_path = str(tmpdir_factory.mktemp("tmp_demo").join(demo_path_name))
        write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos)
        return demo_path

    return record_demo


@pytest.mark.parametrize("use_discrete", [True, False])
@pytest.mark.parametrize("trainer_config", [PPO_CONFIG, SAC_CONFIG])
def test_gail(simple_record, use_discrete, trainer_config):
    demo_path = simple_record(use_discrete)
    env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete, step_size=0.2)
    bc_settings = BehavioralCloningSettings(demo_path=demo_path, steps=1000)
    reward_signals = {
        RewardSignalType.GAIL: GAILSettings(encoding_size=32, demo_path=demo_path)
    }
    config = attr.evolve(
        trainer_config,
        reward_signals=reward_signals,
        behavioral_cloning=bc_settings,
        max_steps=500,
    )
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_gail_visual_ppo(simple_record, use_discrete):
    demo_path = simple_record(use_discrete, num_visual=1, num_vector=0)
    env = SimpleEnvironment(
        [BRAIN_NAME],
        num_visual=1,
        num_vector=0,
        use_discrete=use_discrete,
        step_size=0.2,
    )
    bc_settings = BehavioralCloningSettings(demo_path=demo_path, steps=1500)
    reward_signals = {
        RewardSignalType.GAIL: GAILSettings(encoding_size=32, demo_path=demo_path)
    }
    hyperparams = attr.evolve(PPO_CONFIG.hyperparameters, learning_rate=3e-4)
    config = attr.evolve(
        PPO_CONFIG,
        reward_signals=reward_signals,
        hyperparameters=hyperparams,
        behavioral_cloning=bc_settings,
        max_steps=1000,
    )
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_gail_visual_sac(simple_record, use_discrete):
    demo_path = simple_record(use_discrete, num_visual=1, num_vector=0)
    env = SimpleEnvironment(
        [BRAIN_NAME],
        num_visual=1,
        num_vector=0,
        use_discrete=use_discrete,
        step_size=0.2,
    )
    bc_settings = BehavioralCloningSettings(demo_path=demo_path, steps=1000)
    reward_signals = {
        RewardSignalType.GAIL: GAILSettings(encoding_size=32, demo_path=demo_path)
    }
    hyperparams = attr.evolve(
        SAC_CONFIG.hyperparameters, learning_rate=3e-4, batch_size=16
    )
    config = attr.evolve(
        SAC_CONFIG,
        reward_signals=reward_signals,
        hyperparameters=hyperparams,
        behavioral_cloning=bc_settings,
        max_steps=500,
    )
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


# (Binary contents of the VisualFoodCollector demonstration .demo test file, including embedded PNG observations, omitted.)