
Refactor TFPolicy and Policy (#4254)

* Refactor TFPolicy and Policy
/MLA-1734-demo-provider
GitHub, 4 years ago
Current commit: 1b098c9a
12 files changed, with 409 additions and 375 deletions
  1. ml-agents/mlagents/trainers/components/reward_signals/__init__.py (2 changes)
  2. ml-agents/mlagents/trainers/policy/policy.py (154 changes)
  3. ml-agents/mlagents/trainers/policy/tf_policy.py (275 changes)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (12 changes)
  5. ml-agents/mlagents/trainers/sac/optimizer.py (4 changes)
  6. ml-agents/mlagents/trainers/sac/trainer.py (8 changes)
  7. ml-agents/mlagents/trainers/tests/test_bcmodule.py (7 changes)
  8. ml-agents/mlagents/trainers/tests/test_nn_policy.py (13 changes)
  9. ml-agents/mlagents/trainers/tests/test_ppo.py (15 changes)
  10. ml-agents/mlagents/trainers/tests/test_reward_signals.py (6 changes)
  11. ml-agents/mlagents/trainers/tests/test_sac.py (13 changes)
  12. ml-agents/mlagents/trainers/policy/nn_policy.py (275 changes)

ml-agents/mlagents/trainers/components/reward_signals/__init__.py (2 changes)


"""
Initializes a reward signal. At minimum, you must pass in the policy it is being applied to,
the reward strength, and the gamma (discount factor.)
- :param policy: The Policy object (e.g. NNPolicy) that this Reward Signal will apply to.
+ :param policy: The Policy object (e.g. TFPolicy) that this Reward Signal will apply to.
:param settings: Settings parameters for this Reward Signal, including gamma and strength.
:return: A RewardSignal object.
"""

ml-agents/mlagents/trainers/policy/policy.py (154 changes)


- from abc import ABC, abstractmethod
+ from abc import abstractmethod
from typing import Dict, List, Optional
import numpy as np
from mlagents_envs.exception import UnityException
from mlagents.model_serialization import SerializationSettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
- class Policy(ABC):
- @abstractmethod
+ class UnityPolicyException(UnityException):
+ """
+ Related to errors with the Trainer.
+ """
+ pass
+ class Policy:
def __init__(
self,
seed: int,
behavior_spec: BehaviorSpec,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
):
self.behavior_spec = behavior_spec
self.trainer_settings = trainer_settings
self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed
self.act_size = (
list(behavior_spec.discrete_action_branches)
if behavior_spec.is_action_discrete()
else [behavior_spec.action_size]
)
self.vec_obs_size = sum(
shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
)
self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.model_path = model_path
self.initialize_path = self.trainer_settings.init_path
self._keep_checkpoints = self.trainer_settings.keep_checkpoints
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize
self.use_recurrent = self.network_settings.memory is not None
self.load = load
self.h_size = self.network_settings.hidden_units
num_layers = self.network_settings.num_layers
if num_layers < 1:
num_layers = 1
self.num_layers = num_layers
self.vis_encode_type = self.network_settings.vis_encode_type
self.tanh_squash = tanh_squash
self.reparameterize = reparameterize
self.condition_sigma_on_obs = condition_sigma_on_obs
self.m_size = 0
self.sequence_length = 1
if self.network_settings.memory is not None:
self.m_size = self.network_settings.memory.memory_size
self.sequence_length = self.network_settings.memory.sequence_length
# Non-exposed parameters; these aren't exposed because they don't have a
# good explanation and usually shouldn't be touched.
self.log_std_min = -20
self.log_std_max = 2
def make_empty_memory(self, num_agents):
"""
Creates empty memory for use with RNNs
:param num_agents: Number of agents.
:return: Numpy array of zeros.
"""
return np.zeros((num_agents, self.m_size), dtype=np.float32)
def save_memories(
self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
) -> None:
if memory_matrix is None:
return
for index, agent_id in enumerate(agent_ids):
self.memory_dict[agent_id] = memory_matrix[index, :]
def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.memory_dict:
memory_matrix[index, :] = self.memory_dict[agent_id]
return memory_matrix
def remove_memories(self, agent_ids):
for agent_id in agent_ids:
if agent_id in self.memory_dict:
self.memory_dict.pop(agent_id)
def make_empty_previous_action(self, num_agents):
"""
Creates empty previous action for use with RNNs and discrete control
:param num_agents: Number of agents.
:return: Numpy array of zeros.
"""
return np.zeros((num_agents, self.num_branches), dtype=np.int)
def save_previous_action(
self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
) -> None:
if action_matrix is None:
return
for index, agent_id in enumerate(agent_ids):
self.previous_action_dict[agent_id] = action_matrix[index, :]
def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_action_dict:
action_matrix[index, :] = self.previous_action_dict[agent_id]
return action_matrix
def remove_previous_action(self, agent_ids):
for agent_id in agent_ids:
if agent_id in self.previous_action_dict:
self.previous_action_dict.pop(agent_id)
raise NotImplementedError
@abstractmethod
def update_normalization(self, vector_obs: np.ndarray) -> None:
pass
@abstractmethod
def increment_step(self, n_steps):
pass
@abstractmethod
def get_current_step(self):
pass
@abstractmethod
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
pass
@abstractmethod
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
pass

ml-agents/mlagents/trainers/policy/tf_policy.py (275 changes)


from typing import Any, Dict, List, Optional, Tuple
import abc
from mlagents_envs.timers import timed
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.tf_utils import tf

from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.models import ModelUtils
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.models import ModelUtils, EncoderType
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
)
logger = get_logger(__name__)

# determines compatibility with inference in Barracuda.
MODEL_FORMAT_VERSION = 2
EPSILON = 1e-6 # Small value to avoid divide by zero
class UnityPolicyException(UnityException):

trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
create_tf_graph: bool = True,
):
"""
Initialized the policy.

:param model_path: Where to load/save the model.
:param load: If True, load model from model_path. Otherwise, create new model.
"""
self.m_size = 0
self.trainer_settings = trainer_settings
self.network_settings: NetworkSettings = trainer_settings.network_settings
super().__init__(
seed,
behavior_spec,
trainer_settings,
model_path,
load,
tanh_squash,
reparameterize,
condition_sigma_on_obs,
)
self.inference_dict: Dict[str, tf.Tensor] = {}
self.sequence_length = 1
self.seed = seed
self.behavior_spec = behavior_spec
self.inference_dict: Dict[str, tf.Tensor] = {}
self.act_size = (
list(behavior_spec.discrete_action_branches)
if behavior_spec.is_action_discrete()
else [behavior_spec.action_size]
)
self.vec_obs_size = sum(
shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
)
self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.use_recurrent = self.network_settings.memory is not None
self.memory_dict: Dict[str, np.ndarray] = {}
self.num_branches = self.behavior_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}
self.normalize = self.network_settings.normalize
self.use_continuous_act = behavior_spec.is_action_continuous()
self.model_path = model_path
self.initialize_path = self.trainer_settings.init_path
self.keep_checkpoints = self.trainer_settings.keep_checkpoints
self.seed = seed
if self.network_settings.memory is not None:
self.m_size = self.network_settings.memory.memory_size
self.sequence_length = self.network_settings.memory.sequence_length
self.load = load
self.grads = None
self.update_batch: Optional[tf.Operation] = None
self.trainable_variables: List[tf.Variable] = []
if create_tf_graph:
self.create_tf_graph()
- @abc.abstractmethod
- pass
+ return self.trainable_variables
- @abc.abstractmethod
- def create_tf_graph(self):
+ def create_tf_graph(self) -> None:
- pass
with self.graph.as_default():
tf.set_random_seed(self.seed)
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if len(_vars) > 0:
# We assume the first thing created in the graph is the Policy. If
# already populated, don't create more tensors.
return
self.create_input_placeholders()
encoded = self._create_encoder(
self.visual_in,
self.processed_vector_in,
self.h_size,
self.num_layers,
self.vis_encode_type,
)
if self.use_continuous_act:
self._create_cc_actor(
encoded,
self.tanh_squash,
self.reparameterize,
self.condition_sigma_on_obs,
)
else:
self._create_dc_actor(encoded)
self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
)
self.trainable_variables += tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm"
) # LSTMs need to be root scope for Barracuda export
self.inference_dict = {
"action": self.output,
"log_probs": self.all_log_probs,
"entropy": self.entropy,
}
if self.use_continuous_act:
self.inference_dict["pre_action"] = self.output_pre
if self.use_recurrent:
self.inference_dict["memory_out"] = self.memory_out
# We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
# it will re-load the full graph
self._initialize_graph()
def _create_encoder(
self,
visual_in: List[tf.Tensor],
vector_in: tf.Tensor,
h_size: int,
num_layers: int,
vis_encode_type: EncoderType,
) -> tf.Tensor:
"""
Creates an encoder for visual and vector observations.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
:return: The hidden layer (tf.Tensor) after the encoder.
"""
with tf.variable_scope("policy"):
encoded = ModelUtils.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
return encoded
@staticmethod
def _convert_version_string(version_string: str) -> Tuple[int, ...]:
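The body of _convert_version_string is collapsed in this view. Judging from MODEL_FORMAT_VERSION above, its role is to turn a dotted version string into an integer tuple for comparison. A generic sketch of such a conversion follows (an illustration only, not necessarily the method's actual body):

from typing import Tuple


def convert_version_string(version_string: str) -> Tuple[int, ...]:
    # "0.4.0" -> (0, 4, 0); integer tuples compare element-wise,
    # which makes "is this format new enough" checks one-liners.
    return tuple(int(part) for part in version_string.split("."))


assert convert_version_string("0.4.0") == (0, 4, 0)
assert convert_version_string("1.2.0") >= (1, 0, 0)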

def _initialize_graph(self):
with self.graph.as_default():
- self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
+ self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
- self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
+ self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:

feed_dict[assign_ph] = value
self.sess.run(self.assign_ops, feed_dict=feed_dict)
@timed
- :param decision_requests: DecisionSteps input to network.
- :return: Output from policy based on self.inference_dict.
+ :param decision_requests: DecisionSteps object containing inputs.
+ :param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result.
+ :return: Outputs from network as defined by self.inference_dict.
- raise UnityPolicyException("The evaluate function was not implemented.")
feed_dict = {
self.batch_size_ph: len(decision_requests),
self.sequence_length_ph: 1,
}
if self.use_recurrent:
if not self.use_continuous_act:
feed_dict[self.prev_action] = self.retrieve_previous_action(
global_agent_ids
)
feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids)
feed_dict = self.fill_eval_dict(feed_dict, decision_requests)
run_out = self._execute_model(feed_dict, self.inference_dict)
return run_out
def get_action(
self, decision_requests: DecisionSteps, worker_id: int = 0

trainable=False,
dtype=tf.int32,
)
def _create_cc_actor(
self,
encoded: tf.Tensor,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
:param tanh_squash: Whether to use a tanh function, or a clipped output.
:param reparameterize: Whether we are using the resampling trick to update the policy.
"""
if self.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
)
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
else:
hidden_policy = encoded
with tf.variable_scope("policy"):
distribution = GaussianDistribution(
hidden_policy,
self.act_size,
reparameterize=reparameterize,
tanh_squash=tanh_squash,
condition_sigma=condition_sigma_on_obs,
)
if tanh_squash:
self.output_pre = distribution.sample
self.output = tf.identity(self.output_pre, name="action")
else:
self.output_pre = distribution.sample
# Clip and scale output to ensure actions are always within [-1, 1] range.
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")
self.selected_actions = tf.stop_gradient(self.output)
self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
self.entropy = distribution.entropy
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
self.total_log_probs = distribution.total_log_probs
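The clip-and-scale trick in the non-tanh branch above keeps continuous actions inside [-1, 1]: raw pre-activation outputs are clipped to [-3, 3] and divided by 3. A quick numpy check of that mapping (a standalone illustration, not the TF graph code):

import numpy as np

pre_action = np.array([-5.0, -3.0, -0.9, 0.0, 2.4, 7.1])
action = np.clip(pre_action, -3, 3) / 3
print(action)  # approximately [-1.  -1.  -0.3  0.   0.8  1. ]
assert np.all(np.abs(action) <= 1.0)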
def _create_dc_actor(self, encoded: tf.Tensor) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
"""
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
)
prev_action_oh = tf.concat(
[
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
hidden_policy,
self.memory_in,
self.sequence_length_ph,
name="lstm_policy",
)
self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
else:
hidden_policy = encoded
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
with tf.variable_scope("policy"):
distribution = MultiCategoricalDistribution(
hidden_policy, self.act_size, self.action_masks
)
# It's important that we are able to feed_dict a value into this tensor to get the
# right one-hot encoding, so we can't do identity on it.
self.output = distribution.sample
self.all_log_probs = tf.identity(distribution.log_probs, name="action")
self.selected_actions = tf.stop_gradient(
distribution.sample_onehot
) # In discrete, these are onehot
self.entropy = distribution.entropy
self.total_log_probs = distribution.total_log_probs
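In the recurrent discrete branch above, the previous action for each branch is one-hot encoded and the encodings are concatenated onto the observation encoding before the LSTM. A standalone numpy sketch of that encoding step (the act_size values here are made up):

import numpy as np

act_size = [3, 2]                  # two discrete branches with 3 and 2 choices
prev_action = np.array([[2, 0],    # agent 0: branch-0 action 2, branch-1 action 0
                        [1, 1]])   # agent 1: branch-0 action 1, branch-1 action 1

# One-hot each branch separately, then concatenate along the feature axis,
# mirroring the tf.one_hot / tf.concat calls above.
one_hots = [np.eye(size)[prev_action[:, i]] for i, size in enumerate(act_size)]
prev_action_oh = np.concatenate(one_hots, axis=1)
print(prev_action_oh)
# [[0. 0. 1. 1. 0.]
#  [0. 1. 0. 0. 1.]]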

ml-agents/mlagents/trainers/ppo/trainer.py (12 changes)


from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
- from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer

)
self.load = load
self.seed = seed
- self.policy: NNPolicy = None # type: ignore
+ self.policy: TFPolicy = None # type: ignore
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""

:param behavior_spec: specifications for policy construction
:return policy
"""
- policy = NNPolicy(
+ policy = TFPolicy(
- self.is_training,
- self.artifact_path,
- self.load,
+ model_path=self.artifact_path,
+ load=self.load,
condition_sigma_on_obs=False, # Faster training for PPO
create_tf_graph=False, # We will create the TF graph in the Optimizer
)

self.__class__.__name__
)
)
- if not isinstance(policy, NNPolicy):
- raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
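After this change the PPO trainer builds a TFPolicy directly instead of an NNPolicy subclass, deferring graph creation to the optimizer. Below is a hedged sketch of that construction pattern, based on the TFPolicy constructor signature shown in tf_policy.py above; build_ppo_policy is a hypothetical helper, and behavior_spec/trainer_settings are assumed to come from the usual ml-agents setup.

from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer


def build_ppo_policy(seed, behavior_spec, trainer_settings, artifact_path, load):
    # Mirrors the trainer call above: keyword arguments for model_path/load,
    # sigma not conditioned on observations, and create_tf_graph=False so
    # that PPOOptimizer can build the full TF graph itself.
    policy = TFPolicy(
        seed,
        behavior_spec,
        trainer_settings,
        model_path=artifact_path,
        load=load,
        condition_sigma_on_obs=False,
        create_tf_graph=False,
    )
    optimizer = PPOOptimizer(policy, trainer_settings)
    return policy, optimizer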

ml-agents/mlagents/trainers/sac/optimizer.py (4 changes)


# Non-exposed SAC parameters
self.discrete_target_entropy_scale = (
0.2
) # Roughly equal to e-greedy 0.05
0.2 # Roughly equal to e-greedy 0.05
)
self.continuous_target_entropy_scale = 1.0
stream_names = list(self.reward_signals.keys())

ml-agents/mlagents/trainers/sac/trainer.py (8 changes)


from mlagents_envs.timers import timed
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy.tf_policy import TFPolicy
- from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import Trajectory, SplitObservations

self.load = load
self.seed = seed
- self.policy: NNPolicy = None # type: ignore
+ self.policy: TFPolicy = None # type: ignore
self.optimizer: SACOptimizer = None # type: ignore
self.hyperparameters: SACSettings = cast(
SACSettings, trainer_settings.hyperparameters

def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
) -> TFPolicy:
- policy = NNPolicy(
+ policy = TFPolicy(
- self.is_training,
- self.artifact_path,
- self.load,
tanh_squash=True,

self.__class__.__name__
)
)
- if not isinstance(policy, NNPolicy):
- raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
self.optimizer = SACOptimizer(self.policy, self.trainer_settings)

ml-agents/mlagents/trainers/tests/test_bcmodule.py (7 changes)


import numpy as np
import os
- from mlagents.trainers.policy.nn_policy import NNPolicy
+ from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.components.bc.module import BCModule
from mlagents.trainers.settings import (
TrainerSettings,

trainer_config.network_settings.memory = (
NetworkSettings.MemorySettings() if use_rnn else None
)
- policy = NNPolicy(
+ policy = TFPolicy(
False,
"test",
False,
tanhresample,

assert isinstance(item, np.float32)
old_learning_rate = bc_module.current_lr
- stats = bc_module.update()
+ _ = bc_module.update()
assert old_learning_rate == bc_module.current_lr

ml-agents/mlagents/trainers/tests/test_nn_policy.py (13 changes)


from mlagents.tf_utils import tf
- from mlagents.trainers.policy.nn_policy import NNPolicy
+ from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.models import EncoderType, ModelUtils, Tensor3DShape
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.tests import mock_brain as mb

model_path: str = "",
load: bool = False,
seed: int = 0,
- ) -> NNPolicy:
+ ) -> TFPolicy:
mock_spec = mb.setup_test_behavior_specs(
use_discrete,
use_visual,

trainer_settings.network_settings.memory = (
NetworkSettings.MemorySettings() if use_rnn else None
)
- policy = NNPolicy(seed, mock_spec, trainer_settings, False, model_path, load)
+ policy = TFPolicy(
+ seed, mock_spec, trainer_settings, model_path=model_path, load=load
+ )
return policy

assert len(cm.output) == 1
- def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None:
+ def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None:
"""
Make sure two policies have the same output for the same input.
"""

# Change half of the obs to 0
for i in range(3):
trajectory.steps[i].obs[0] = np.zeros(1, dtype=np.float32)
- policy = NNPolicy(
+ policy = TFPolicy(
False,
"testdir",
False,
)

ml-agents/mlagents/trainers/tests/test_ppo.py (15 changes)


from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
from mlagents.trainers.ppo.optimizer import PPOOptimizer
- from mlagents.trainers.policy.nn_policy import NNPolicy
+ from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory

if use_rnn
else None
)
- policy = NNPolicy(
- 0, mock_specs, trainer_settings, False, "test", False, create_tf_graph=False
+ policy = TFPolicy(
+ 0, mock_specs, trainer_settings, "test", False, create_tf_graph=False
)
optimizer = PPOOptimizer(policy, trainer_settings)
return optimizer

ppo_optimizer.return_value = mock_optimizer
trainer = PPOTrainer("test_brain", 0, trainer_params, True, False, 0, "0")
- policy_mock = mock.Mock(spec=NNPolicy)
+ policy_mock = mock.Mock(spec=TFPolicy)
policy_mock.get_current_step.return_value = 0
step_count = (
5 # 10 hacked because this function is no longer called through trainer

ppo_optimizer.return_value = mock_optimizer
trainer = PPOTrainer("test_policy", 0, dummy_config, True, False, 0, "0")
- policy = mock.Mock(spec=NNPolicy)
+ policy = mock.Mock(spec=TFPolicy)
policy.get_current_step.return_value = 2000
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)

# Make sure the summary steps were loaded properly
assert trainer.get_step == 2000
# Test incorrect class of policy
policy = mock.Mock()
with pytest.raises(RuntimeError):
trainer.add_policy(behavior_id, policy)
if __name__ == "__main__":

ml-agents/mlagents/trainers/tests/test_reward_signals.py (6 changes)


import copy
import os
import mlagents.trainers.tests.mock_brain as mb
- from mlagents.trainers.policy.nn_policy import NNPolicy
+ from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG

if use_rnn
else None
)
- policy = NNPolicy(
- 0, mock_specs, trainer_settings, False, "test", False, create_tf_graph=False
+ policy = TFPolicy(
+ 0, mock_specs, trainer_settings, "test", False, create_tf_graph=False
)
if trainer_settings.trainer_type == TrainerType.SAC:
optimizer = SACOptimizer(policy, trainer_settings)

ml-agents/mlagents/trainers/tests/test_sac.py (13 changes)


from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.sac.optimizer import SACOptimizer
- from mlagents.trainers.policy.nn_policy import NNPolicy
+ from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import setup_test_behavior_specs

if use_rnn
else None
)
- policy = NNPolicy(
- 0, mock_brain, trainer_settings, False, "test", False, create_tf_graph=False
+ policy = TFPolicy(
+ 0, mock_brain, trainer_settings, "test", False, create_tf_graph=False
)
optimizer = SACOptimizer(policy, trainer_settings)
return optimizer

sac_optimizer.return_value = mock_optimizer
trainer = SACTrainer("test", 0, dummy_config, True, False, 0, "0")
- policy = mock.Mock(spec=NNPolicy)
+ policy = mock.Mock(spec=TFPolicy)
policy.get_current_step.return_value = 2000
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy)

assert trainer.get_step == 2000
# Test incorrect class of policy
policy = mock.Mock()
with pytest.raises(RuntimeError):
trainer.add_policy(behavior_id, policy)
def test_advance(dummy_config):

ml-agents/mlagents/trainers/policy/nn_policy.py (275 changes)


from typing import Any, Dict, Optional, List
from mlagents.tf_utils import tf
from mlagents_envs.timers import timed
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents.trainers.models import EncoderType
from mlagents.trainers.models import ModelUtils
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
)
EPSILON = 1e-6 # Small value to avoid divide by zero
class NNPolicy(TFPolicy):
def __init__(
self,
seed: int,
behavior_spec: BehaviorSpec,
trainer_params: TrainerSettings,
is_training: bool,
model_path: str,
load: bool,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
create_tf_graph: bool = True,
):
"""
Policy that uses a multilayer perceptron to map the observations to actions. Could
also use a CNN to encode visual input prior to the MLP. Supports discrete and
continuous action spaces, as well as recurrent networks.
:param seed: Random seed.
:param brain: Assigned BrainParameters object.
:param trainer_params: Defined training parameters.
:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
:param model_path: Path where the model should be saved and loaded.
:param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
:param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
"""
super().__init__(seed, behavior_spec, trainer_params, model_path, load)
self.grads = None
self.update_batch: Optional[tf.Operation] = None
num_layers = self.network_settings.num_layers
self.h_size = self.network_settings.hidden_units
if num_layers < 1:
num_layers = 1
self.num_layers = num_layers
self.vis_encode_type = self.network_settings.vis_encode_type
self.tanh_squash = tanh_squash
self.reparameterize = reparameterize
self.condition_sigma_on_obs = condition_sigma_on_obs
self.trainable_variables: List[tf.Variable] = []
# Non-exposed parameters; these aren't exposed because they don't have a
# good explanation and usually shouldn't be touched.
self.log_std_min = -20
self.log_std_max = 2
if create_tf_graph:
self.create_tf_graph()
def get_trainable_variables(self) -> List[tf.Variable]:
"""
Returns a List of the trainable variables in this policy. if create_tf_graph hasn't been called,
returns empty list.
"""
return self.trainable_variables
def create_tf_graph(self) -> None:
"""
Builds the tensorflow graph needed for this policy.
"""
with self.graph.as_default():
tf.set_random_seed(self.seed)
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if len(_vars) > 0:
# We assume the first thing created in the graph is the Policy. If
# already populated, don't create more tensors.
return
self.create_input_placeholders()
encoded = self._create_encoder(
self.visual_in,
self.processed_vector_in,
self.h_size,
self.num_layers,
self.vis_encode_type,
)
if self.use_continuous_act:
self._create_cc_actor(
encoded,
self.tanh_squash,
self.reparameterize,
self.condition_sigma_on_obs,
)
else:
self._create_dc_actor(encoded)
self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
)
self.trainable_variables += tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm"
) # LSTMs need to be root scope for Barracuda export
self.inference_dict: Dict[str, tf.Tensor] = {
"action": self.output,
"log_probs": self.all_log_probs,
"entropy": self.entropy,
}
if self.use_continuous_act:
self.inference_dict["pre_action"] = self.output_pre
if self.use_recurrent:
self.inference_dict["memory_out"] = self.memory_out
# We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
# it will re-load the full graph
self._initialize_graph()
@timed
def evaluate(
self, decision_requests: DecisionSteps, global_agent_ids: List[str]
) -> Dict[str, Any]:
"""
Evaluates policy for the agent experiences provided.
:param decision_requests: DecisionSteps object containing inputs.
:param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result.
:return: Outputs from network as defined by self.inference_dict.
"""
feed_dict = {
self.batch_size_ph: len(decision_requests),
self.sequence_length_ph: 1,
}
if self.use_recurrent:
if not self.use_continuous_act:
feed_dict[self.prev_action] = self.retrieve_previous_action(
global_agent_ids
)
feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids)
feed_dict = self.fill_eval_dict(feed_dict, decision_requests)
run_out = self._execute_model(feed_dict, self.inference_dict)
return run_out
def _create_encoder(
self,
visual_in: List[tf.Tensor],
vector_in: tf.Tensor,
h_size: int,
num_layers: int,
vis_encode_type: EncoderType,
) -> tf.Tensor:
"""
Creates an encoder for visual and vector observations.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
:return: The hidden layer (tf.Tensor) after the encoder.
"""
with tf.variable_scope("policy"):
encoded = ModelUtils.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
return encoded
def _create_cc_actor(
self,
encoded: tf.Tensor,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
:param tanh_squash: Whether to use a tanh function, or a clipped output.
:param reparameterize: Whether we are using the resampling trick to update the policy.
"""
if self.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
)
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
else:
hidden_policy = encoded
with tf.variable_scope("policy"):
distribution = GaussianDistribution(
hidden_policy,
self.act_size,
reparameterize=reparameterize,
tanh_squash=tanh_squash,
condition_sigma=condition_sigma_on_obs,
)
if tanh_squash:
self.output_pre = distribution.sample
self.output = tf.identity(self.output_pre, name="action")
else:
self.output_pre = distribution.sample
# Clip and scale output to ensure actions are always within [-1, 1] range.
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")
self.selected_actions = tf.stop_gradient(self.output)
self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
self.entropy = distribution.entropy
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
self.total_log_probs = distribution.total_log_probs
def _create_dc_actor(self, encoded: tf.Tensor) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
"""
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
)
prev_action_oh = tf.concat(
[
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
hidden_policy,
self.memory_in,
self.sequence_length_ph,
name="lstm_policy",
)
self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
else:
hidden_policy = encoded
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
with tf.variable_scope("policy"):
distribution = MultiCategoricalDistribution(
hidden_policy, self.act_size, self.action_masks
)
# It's important that we are able to feed_dict a value into this tensor to get the
# right one-hot encoding, so we can't do identity on it.
self.output = distribution.sample
self.all_log_probs = tf.identity(distribution.log_probs, name="action")
self.selected_actions = tf.stop_gradient(
distribution.sample_onehot
) # In discrete, these are onehot
self.entropy = distribution.entropy
self.total_log_probs = distribution.total_log_probs