[feature] Add experimental PyTorch support (#4335)
* Begin porting work
* Add ResNet and distributions
* Dynamically construct actor and critic
* Initial optimizer port
* Refactoring policy and optimizer
* Resolving a few bugs
* Share more code between tf and torch policies
* Slightly closer to running model
* Training runs, but doesn’t actually work
* Fix a couple additional bugs
* Add conditional sigma for distribution
* Fix normalization
* Support discrete actions as well
* Continuous and discrete now train
* Multi-discrete now working
* Visual observations now train as well
* GRU in-progress and dynamic cnns
* Fix for memories
* Remove unused arg
* Combine actor and critic classes. Initial export.
* Support tf and pytorch alongside one another
* Prepare model for onnx export
* Use LSTM and fix a few merge errors
* Fix bug in probs calculation
* Optimize np -> tensor operations
* Time action sample funct.../MLA-1734-demo-provider
GitHub
4 years ago
Current commit
1955af9e
52 files changed, with 5,374 insertions and 156 deletions
4     com.unity.ml-agents/CHANGELOG.md
2     ml-agents/mlagents/trainers/buffer.py
7     ml-agents/mlagents/trainers/cli_utils.py
12    ml-agents/mlagents/trainers/ghost/trainer.py
2     ml-agents/mlagents/trainers/policy/tf_policy.py
76    ml-agents/mlagents/trainers/ppo/trainer.py
120   ml-agents/mlagents/trainers/sac/trainer.py
14    ml-agents/mlagents/trainers/settings.py
7     ml-agents/mlagents/trainers/tests/test_ghost.py
5     ml-agents/mlagents/trainers/tests/test_rl_trainer.py
3     ml-agents/mlagents/trainers/tests/test_sac.py
2     ml-agents/mlagents/trainers/tests/test_simple_rl.py
19    ml-agents/mlagents/trainers/tests/torch/test_layers.py
17    ml-agents/mlagents/trainers/tests/torch/test_networks.py
6     ml-agents/mlagents/trainers/tests/torch/test_utils.py
20    ml-agents/mlagents/trainers/tf/model_serialization.py
17    ml-agents/mlagents/trainers/torch/encoders.py
67    ml-agents/mlagents/trainers/torch/layers.py
115   ml-agents/mlagents/trainers/torch/networks.py
4     ml-agents/mlagents/trainers/torch/utils.py
99    ml-agents/mlagents/trainers/trainer/rl_trainer.py
5     ml-agents/mlagents/trainers/trainer/trainer.py
7     ml-agents/mlagents/trainers/trainer_controller.py
94    ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
281   ml-agents/mlagents/trainers/policy/torch_policy.py
203   ml-agents/mlagents/trainers/ppo/optimizer_torch.py
561   ml-agents/mlagents/trainers/sac/optimizer_torch.py
118   ml-agents/mlagents/trainers/saver/torch_saver.py
1001  ml-agents/mlagents/trainers/tests/torch/test.demo
144   ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py
177   ml-agents/mlagents/trainers/tests/torch/test_ghost.py
150   ml-agents/mlagents/trainers/tests/torch/test_policy.py
505   ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
446   ml-agents/mlagents/trainers/tests/torch/testdcvis.demo
74    ml-agents/mlagents/trainers/torch/model_serialization.py
111   ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
56    ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
138   ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
32    ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
0     ml-agents/mlagents/trainers/torch/components/__init__.py
0     ml-agents/mlagents/trainers/torch/components/bc/__init__.py
183   ml-agents/mlagents/trainers/torch/components/bc/module.py
15    ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
72    ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py
15    ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py
43    ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py
225   ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
256   ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
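One of the commit notes reads "Prepare model for onnx export", and the change adds ml-agents/mlagents/trainers/torch/model_serialization.py for that step. The sketch below is only a generic illustration of exporting a small PyTorch actor with torch.onnx.export; TinyActor, its sizes, and the output file name are invented for the example and are not the serialization code in this PR.

# Minimal, self-contained sketch of exporting a PyTorch module to ONNX.
# TinyActor, its layer sizes, and "actor.onnx" are placeholders; the PR's real
# exporter lives in mlagents/trainers/torch/model_serialization.py.
import torch
import torch.nn as nn


class TinyActor(nn.Module):
    def __init__(self, obs_size: int = 8, act_size: int = 2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_size, 32), nn.ReLU())
        self.mu = nn.Linear(32, act_size)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Tanh-squashed continuous action, just for the export example.
        return torch.tanh(self.mu(self.body(obs)))


if __name__ == "__main__":
    actor = TinyActor()
    dummy_obs = torch.zeros(1, 8)  # batch of one observation as the export trace input
    torch.onnx.export(
        actor,
        dummy_obs,
        "actor.onnx",
        input_names=["vector_observation"],
        output_names=["action"],
        opset_version=9,
        dynamic_axes={"vector_observation": {0: "batch"}, "action": {0: "batch"}},
    )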
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

from typing import Dict, Optional, Tuple, List
import torch
import numpy as np

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.torch.utils import ModelUtils


class TorchOptimizer(Optimizer):  # pylint: disable=W0223
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        super().__init__()
        self.policy = policy
        self.trainer_settings = trainer_settings
        self.update_dict: Dict[str, torch.Tensor] = {}
        self.value_heads: Dict[str, torch.Tensor] = {}
        self.memory_in: torch.Tensor = None
        self.memory_out: torch.Tensor = None
        self.m_size: int = 0
        self.global_step = torch.tensor(0)
        self.bc_module: Optional[BCModule] = None
        self.create_reward_signals(trainer_settings.reward_signals)
        if trainer_settings.behavioral_cloning is not None:
            self.bc_module = BCModule(
                self.policy,
                trainer_settings.behavioral_cloning,
                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
                default_batch_size=trainer_settings.hyperparameters.batch_size,
                default_num_epoch=3,
            )

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        # Implemented by the trainer-specific optimizers
        # (ppo/optimizer_torch.py and sac/optimizer_torch.py in this change).
        pass

    def create_reward_signals(self, reward_signal_configs):
        """
        Create reward signals
        :param reward_signal_configs: Reward signal config.
        """
        for reward_signal, settings in reward_signal_configs.items():
            # Name reward signals by string in case we have duplicates later
            self.reward_signals[reward_signal.value] = create_reward_provider(
                reward_signal, self.policy.behavior_spec, settings
            )

    def get_trajectory_value_estimates(
        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
        vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        if self.policy.use_vis_obs:
            visual_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_encoders
            ):
                visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                visual_obs.append(visual_ob)
        else:
            visual_obs = []

        # Critic memories start zeroed for the value pass over the trajectory.
        memory = torch.zeros([1, 1, self.policy.m_size])

        vec_vis_obs = SplitObservations.from_observations(next_obs)
        next_vec_obs = [
            ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
        ]
        next_vis_obs = [
            ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
            for _vis_ob in vec_vis_obs.visual_observations
        ]

        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
            vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
        )

        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
            next_vec_obs, next_vis_obs, next_memory, sequence_length=1
        )

        for name, estimate in value_estimates.items():
            value_estimates[name] = estimate.detach().cpu().numpy()
            next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy()

        if done:
            # Terminal states bootstrap from zero unless the reward signal ignores dones.
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0

        return value_estimates, next_value_estimate
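get_trajectory_value_estimates above runs the critic over the whole collected trajectory and once more on the final observation to get a bootstrap value per reward stream, then zeroes that bootstrap when the episode ended (unless the reward signal ignores dones). Below is a stripped-down, self-contained sketch of the same pattern in plain PyTorch; MultiStreamCritic and the "extrinsic"/"curiosity" stream names are invented for the illustration and are not part of this PR.

# Illustration of the trajectory/bootstrap value-estimate pattern used above.
# MultiStreamCritic and the stream names are made up for this sketch.
from typing import Dict
import numpy as np
import torch
import torch.nn as nn


class MultiStreamCritic(nn.Module):
    """One value head per reward stream, mirroring the multi-head critic idea."""

    def __init__(self, obs_size: int, stream_names):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_size, 32), nn.ReLU())
        self.heads = nn.ModuleDict({name: nn.Linear(32, 1) for name in stream_names})

    def forward(self, obs: torch.Tensor) -> Dict[str, torch.Tensor]:
        hidden = self.body(obs)
        return {name: head(hidden).squeeze(-1) for name, head in self.heads.items()}


def trajectory_value_estimates(critic, obs, next_obs, done, ignore_done):
    # Per-step values over the trajectory, as numpy arrays keyed by stream name.
    values = {k: v.detach().cpu().numpy() for k, v in critic(obs).items()}
    # Bootstrap value of the state after the last step (batch of one).
    next_values = {
        k: v.detach().cpu().numpy()[0] for k, v in critic(next_obs.unsqueeze(0)).items()
    }
    if done:
        # Terminal states contribute no future return unless the signal ignores dones.
        for k in next_values:
            if not ignore_done.get(k, False):
                next_values[k] = 0.0
    return values, next_values


critic = MultiStreamCritic(obs_size=4, stream_names=["extrinsic", "curiosity"])
obs = torch.randn(10, 4)   # 10-step trajectory of 4-dim observations
next_obs = torch.randn(4)  # observation after the last step
print(trajectory_value_estimates(critic, obs, next_obs, done=True,
                                 ignore_done={"curiosity": True}))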
ml-agents/mlagents/trainers/policy/torch_policy.py

from typing import Any, Dict, List, Tuple, Optional
import numpy as np
import torch
import copy

from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.policy import Policy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents_envs.timers import timed

from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.networks import (
    SharedActorCritic,
    SeparateActorCritic,
    GlobalSteps,
)
from mlagents.trainers.torch.utils import ModelUtils

EPSILON = 1e-7  # Small value to avoid divide by zero


class TorchPolicy(Policy):
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        separate_critic: bool = True,
        condition_sigma_on_obs: bool = True,
    ):
        """
        Policy that uses a multilayer perceptron to map the observations to actions. Could
        also use a CNN to encode visual input prior to the MLP. Supports discrete and
        continuous action spaces, as well as recurrent networks.
        :param seed: Random seed.
        :param behavior_spec: Assigned BehaviorSpec object.
        :param trainer_settings: Defined training parameters.
        :param tanh_squash: Whether to use a tanh function on the continuous output,
        or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy
        in continuous output.
        :param separate_critic: Whether the critic is a separate network from the actor.
        :param condition_sigma_on_obs: Whether the standard deviation of the continuous
        distribution is conditioned on the observations.
        """
        super().__init__(
            seed,
            behavior_spec,
            trainer_settings,
            tanh_squash,
            reparameterize,
            condition_sigma_on_obs,
        )
        self.global_step = (
            GlobalSteps()
        )  # could be much simpler if TorchPolicy is nn.Module
        self.grads = None

        torch.set_default_tensor_type(torch.FloatTensor)

        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }
        if separate_critic:
            ac_class = SeparateActorCritic
        else:
            ac_class = SharedActorCritic
        self.actor_critic = ac_class(
            observation_shapes=self.behavior_spec.observation_shapes,
            network_settings=trainer_settings.network_settings,
            act_type=behavior_spec.action_type,
            act_size=self.act_size,
            stream_names=reward_signal_names,
            conditional_sigma=self.condition_sigma_on_obs,
            tanh_squash=tanh_squash,
        )
        # Save the m_size needed for export
        self._export_m_size = self.m_size
        # m_size needed for training is determined by network, not trainer settings
        self.m_size = self.actor_critic.memory_size

        self.actor_critic.to("cpu")

    @property
    def export_memory_size(self) -> int:
        """
        Returns the memory size of the exported ONNX policy. This only includes the memory
        of the Actor and not any auxiliary networks.
        """
        return self._export_m_size

    def _split_decision_step(
        self, decision_requests: DecisionSteps
    ) -> Tuple[SplitObservations, np.ndarray]:
        vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
        mask = None
        if not self.use_continuous_act:
            # Default to an all-ones mask (every discrete action available).
            mask = torch.ones([len(decision_requests), np.sum(self.act_size)])
            if decision_requests.action_mask is not None:
                # Flip the environment's mask so 1 = available and 0 = masked out.
                mask = torch.as_tensor(
                    1 - np.concatenate(decision_requests.action_mask, axis=1)
                )
        return vec_vis_obs, mask
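_split_decision_step converts the per-branch action masks delivered in DecisionSteps (where True marks an unavailable action) into a single multiplicative mask of ones and zeros by concatenating the branches and computing 1 - mask. The following sketch illustrates just that conversion; the branch sizes and mask values are made up, and it is not ml-agents code.

# Stand-alone illustration of the action-mask conversion in _split_decision_step.
import numpy as np
import torch

act_size = [3, 2]  # two discrete branches with 3 and 2 actions
num_agents = 2

# One boolean array per branch, shaped (num_agents, branch_size);
# True marks an action the environment has masked out.
action_mask = [
    np.array([[False, True, False],
              [False, False, False]]),
    np.array([[True, False],
              [False, False]]),
]

# Default: everything allowed.
mask = torch.ones([num_agents, int(np.sum(act_size))])
if action_mask is not None:
    # Concatenate the branches and flip so that 1 = allowed, 0 = masked out.
    mask = torch.as_tensor(1 - np.concatenate(action_mask, axis=1))

print(mask)
# tensor([[1, 0, 1, 0, 1],
#         [1, 1, 1, 1, 1]])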