Andrew Cohen
4 years ago
Current commit
06e4356c
64 files changed, 977 insertions and 561 deletions
- com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs (2 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs (12 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs (53 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs (30 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs (18 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs (2 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs (13 changes)
- com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs (34 changes)
- com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs (22 changes)
- com.unity.ml-agents/CHANGELOG.md (5 changes)
- docs/Getting-Started.md (2 changes)
- docs/Training-Configuration-File.md (2 changes)
- docs/Training-on-Microsoft-Azure.md (2 changes)
- docs/Using-Docker.md (4 changes)
- docs/Using-Tensorboard.md (2 changes)
- docs/localized/zh-CN/docs/Getting-Started-with-Balance-Ball.md (4 changes)
- ml-agents-envs/setup.py (2 changes)
- ml-agents/mlagents/model_serialization.py (2 changes)
- ml-agents/mlagents/trainers/agent_processor.py (5 changes)
- ml-agents/mlagents/trainers/components/reward_signals/__init__.py (2 changes)
- ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (2 changes)
- ml-agents/mlagents/trainers/components/reward_signals/gail/model.py (2 changes)
- ml-agents/mlagents/trainers/env_manager.py (29 changes)
- ml-agents/mlagents/trainers/ghost/trainer.py (11 changes)
- ml-agents/mlagents/trainers/optimizer/optimizer.py (3 changes)
- ml-agents/mlagents/trainers/optimizer/tf_optimizer.py (2 changes)
- ml-agents/mlagents/trainers/policy/policy.py (166 changes)
- ml-agents/mlagents/trainers/policy/tf_policy.py (334 changes)
- ml-agents/mlagents/trainers/ppo/optimizer.py (2 changes)
- ml-agents/mlagents/trainers/ppo/trainer.py (21 changes)
- ml-agents/mlagents/trainers/sac/network.py (4 changes)
- ml-agents/mlagents/trainers/sac/optimizer.py (6 changes)
- ml-agents/mlagents/trainers/sac/trainer.py (17 changes)
- ml-agents/mlagents/trainers/settings.py (12 changes)
- ml-agents/mlagents/trainers/stats.py (17 changes)
- ml-agents/mlagents/trainers/tests/test_barracuda_converter.py (2 changes)
- ml-agents/mlagents/trainers/tests/test_bcmodule.py (7 changes)
- ml-agents/mlagents/trainers/tests/test_distributions.py (2 changes)
- ml-agents/mlagents/trainers/tests/test_models.py (2 changes)
- ml-agents/mlagents/trainers/tests/test_nn_policy.py (17 changes)
- ml-agents/mlagents/trainers/tests/test_ppo.py (15 changes)
- ml-agents/mlagents/trainers/tests/test_reward_signals.py (6 changes)
- ml-agents/mlagents/trainers/tests/test_sac.py (13 changes)
- ml-agents/mlagents/trainers/tests/test_simple_rl.py (3 changes)
- ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (2 changes)
- ml-agents/mlagents/trainers/tests/test_trainer_controller.py (3 changes)
- ml-agents/mlagents/trainers/trainer/rl_trainer.py (4 changes)
- ml-agents/mlagents/trainers/trainer/trainer.py (9 changes)
- ml-agents/mlagents/trainers/trainer_controller.py (36 changes)
- ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py (2 changes)
- ml-agents/mlagents/trainers/tf/models.py (13 changes)
- ml-agents/mlagents/trainers/tf/distributions.py (2 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs (147 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta (11 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs (27 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta (11 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs (62 changes)
- com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta (11 changes)
- ml-agents/mlagents/trainers/tf/__init__.py (0 changes)
- ml-agents/mlagents/trainers/policy/nn_policy.py (285 changes)
- ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py (0 changes)
- ml-agents/mlagents/trainers/tf/models.py (0 changes)
- ml-agents/mlagents/trainers/tf/distributions.py (0 changes)
ml-agents/mlagents/trainers/policy/policy.py (lines prefixed with "-" were removed from the previous version; all other lines are the file as of this commit):

```diff
-from abc import ABC, abstractmethod
 from abc import abstractmethod
 from typing import Dict, List, Optional
 import numpy as np
 from mlagents_envs.exception import UnityException

 from mlagents.model_serialization import SerializationSettings
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.settings import TrainerSettings, NetworkSettings

-class Policy(ABC):
-    @abstractmethod

 class UnityPolicyException(UnityException):
     """
     Related to errors with the Trainer.
     """

     pass


 class Policy:
     def __init__(
         self,
         seed: int,
         behavior_spec: BehaviorSpec,
         trainer_settings: TrainerSettings,
         model_path: str,
         load: bool = False,
         tanh_squash: bool = False,
         reparameterize: bool = False,
         condition_sigma_on_obs: bool = True,
     ):
         self.behavior_spec = behavior_spec
         self.trainer_settings = trainer_settings
         self.network_settings: NetworkSettings = trainer_settings.network_settings
         self.seed = seed
         self.act_size = (
             list(behavior_spec.discrete_action_branches)
             if behavior_spec.is_action_discrete()
             else [behavior_spec.action_size]
         )
         self.vec_obs_size = sum(
             shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
         )
         self.vis_obs_size = sum(
             1 for shape in behavior_spec.observation_shapes if len(shape) == 3
         )
         self.model_path = model_path
         self.initialize_path = self.trainer_settings.init_path
         self._keep_checkpoints = self.trainer_settings.keep_checkpoints
         self.use_continuous_act = behavior_spec.is_action_continuous()
         self.num_branches = self.behavior_spec.action_size
         self.previous_action_dict: Dict[str, np.array] = {}
         self.memory_dict: Dict[str, np.ndarray] = {}
         self.normalize = trainer_settings.network_settings.normalize
         self.use_recurrent = self.network_settings.memory is not None
         self.load = load
         self.h_size = self.network_settings.hidden_units
         num_layers = self.network_settings.num_layers
         if num_layers < 1:
             num_layers = 1
         self.num_layers = num_layers

         self.vis_encode_type = self.network_settings.vis_encode_type
         self.tanh_squash = tanh_squash
         self.reparameterize = reparameterize
         self.condition_sigma_on_obs = condition_sigma_on_obs

         self.m_size = 0
         self.sequence_length = 1
         if self.network_settings.memory is not None:
             self.m_size = self.network_settings.memory.memory_size
             self.sequence_length = self.network_settings.memory.sequence_length

         # Non-exposed parameters; these aren't exposed because they don't have a
         # good explanation and usually shouldn't be touched.
         self.log_std_min = -20
         self.log_std_max = 2

     def make_empty_memory(self, num_agents):
         """
         Creates empty memory for use with RNNs
         :param num_agents: Number of agents.
         :return: Numpy array of zeros.
         """
         return np.zeros((num_agents, self.m_size), dtype=np.float32)

     def save_memories(
         self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
     ) -> None:
         if memory_matrix is None:
             return
         for index, agent_id in enumerate(agent_ids):
             self.memory_dict[agent_id] = memory_matrix[index, :]

     def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
         memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
         for index, agent_id in enumerate(agent_ids):
             if agent_id in self.memory_dict:
                 memory_matrix[index, :] = self.memory_dict[agent_id]
         return memory_matrix

     def remove_memories(self, agent_ids):
         for agent_id in agent_ids:
             if agent_id in self.memory_dict:
                 self.memory_dict.pop(agent_id)

     def make_empty_previous_action(self, num_agents):
         """
         Creates empty previous action for use with RNNs and discrete control
         :param num_agents: Number of agents.
         :return: Numpy array of zeros.
         """
         return np.zeros((num_agents, self.num_branches), dtype=np.int)

     def save_previous_action(
         self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
     ) -> None:
         if action_matrix is None:
             return
         for index, agent_id in enumerate(agent_ids):
             self.previous_action_dict[agent_id] = action_matrix[index, :]

     def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
         action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
         for index, agent_id in enumerate(agent_ids):
             if agent_id in self.previous_action_dict:
                 action_matrix[index, :] = self.previous_action_dict[agent_id]
         return action_matrix

     def remove_previous_action(self, agent_ids):
         for agent_id in agent_ids:
             if agent_id in self.previous_action_dict:
                 self.previous_action_dict.pop(agent_id)

-        raise NotImplementedError

     @abstractmethod
     def update_normalization(self, vector_obs: np.ndarray) -> None:
         pass

     @abstractmethod
     def increment_step(self, n_steps):
         pass

     @abstractmethod
     def get_current_step(self):
         pass

     @abstractmethod
     def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
         pass

     @abstractmethod
     def save(self, output_filepath: str, settings: SerializationSettings) -> None:
         pass

     @abstractmethod
     def load_weights(self, values: List[np.ndarray]) -> None:
         pass

     @abstractmethod
     def get_weights(self) -> List[np.ndarray]:
         return []

     @abstractmethod
     def init_load_weights(self) -> None:
         pass
```
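To make the new base class easier to follow, here is a minimal usage sketch. It is not part of the commit: the spec, settings, and model_path values below are hypothetical, and it assumes the ml-agents Python API at roughly this revision (a BehaviorSpec tuple of observation_shapes, action_type, and action_shape; NetworkSettings.MemorySettings with a default memory_size; attrs defaults on TrainerSettings). Note that because Policy no longer inherits from ABC, its @abstractmethod markers are not enforced, so the base class itself can be instantiated here purely to exercise the bookkeeping helpers.

```python
from mlagents_envs.base_env import ActionType, BehaviorSpec
from mlagents.trainers.policy.policy import Policy
from mlagents.trainers.settings import NetworkSettings, TrainerSettings

# Hypothetical behavior: one size-8 vector observation, two continuous actions.
spec = BehaviorSpec(
    observation_shapes=[(8,)],
    action_type=ActionType.CONTINUOUS,
    action_shape=2,
)

# Enable the optional memory block so m_size is non-zero
# (MemorySettings defaults are assumed here).
settings = TrainerSettings(
    network_settings=NetworkSettings(memory=NetworkSettings.MemorySettings())
)

# Policy.__init__ only records settings; nothing is written to model_path.
policy = Policy(
    seed=0,
    behavior_spec=spec,
    trainer_settings=settings,
    model_path="./tmp_model",
)

# Round-trip per-agent recurrent state. Ids never saved come back as zero rows.
mem = policy.make_empty_memory(num_agents=2)   # zeros, shape (2, m_size)
mem[0, :] = 1.0                                # pretend the RNN produced state for agent-0
policy.save_memories(["agent-0", "agent-1"], mem)
restored = policy.retrieve_memories(["agent-0", "agent-9"])
assert restored[0, 0] == 1.0
assert restored[1].sum() == 0.0
policy.remove_memories(["agent-0"])            # drop state when an episode ends
```

The previous-action helpers (make_empty_previous_action, save_previous_action, retrieve_previous_action, remove_previous_action) follow the same keyed-by-agent-id pattern for discrete control.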