Action Model (#4580)

Co-authored-by: Ervin T <ervin@unity3d.com>
Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>

Committed via GitHub 4 years ago. Current commit: 3c96a3a2

43 files changed, with 1,337 insertions and 1,004 deletions. Changed files (total lines changed in parentheses):
- .github/workflows/pytest.yml (4)
- ml-agents-envs/mlagents_envs/base_env.py (93)
- ml-agents-envs/mlagents_envs/rpc_utils.py (2)
- ml-agents/mlagents/trainers/agent_processor.py (26)
- ml-agents/mlagents/trainers/demo_loader.py (5)
- ml-agents/mlagents/trainers/env_manager.py (21)
- ml-agents/mlagents/trainers/policy/policy.py (17)
- ml-agents/mlagents/trainers/policy/tf_policy.py (24)
- ml-agents/mlagents/trainers/policy/torch_policy.py (54)
- ml-agents/mlagents/trainers/ppo/optimizer_tf.py (6)
- ml-agents/mlagents/trainers/ppo/optimizer_torch.py (5)
- ml-agents/mlagents/trainers/ppo/trainer.py (2)
- ml-agents/mlagents/trainers/sac/optimizer_torch.py (319)
- ml-agents/mlagents/trainers/simple_env_manager.py (3)
- ml-agents/mlagents/trainers/subprocess_env_manager.py (7)
- ml-agents/mlagents/trainers/tests/mock_brain.py (20)
- ml-agents/mlagents/trainers/tests/simple_test_envs.py (49)
- ml-agents/mlagents/trainers/tests/tensorflow/test_ppo.py (66)
- ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (114)
- ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py (2)
- ml-agents/mlagents/trainers/tests/test_agent_processor.py (27)
- ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (2)
- ml-agents/mlagents/trainers/tests/test_trajectory.py (4)
- ml-agents/mlagents/trainers/tests/torch/test_distributions.py (2)
- ml-agents/mlagents/trainers/tests/torch/test_networks.py (78)
- ml-agents/mlagents/trainers/tests/torch/test_policy.py (13)
- ml-agents/mlagents/trainers/tests/torch/test_ppo.py (28)
- ml-agents/mlagents/trainers/tests/torch/test_sac.py (3)
- ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (118)
- ml-agents/mlagents/trainers/tests/torch/test_utils.py (47)
- ml-agents/mlagents/trainers/torch/components/bc/module.py (35)
- ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (75)
- ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (6)
- ml-agents/mlagents/trainers/torch/distributions.py (19)
- ml-agents/mlagents/trainers/torch/networks.py (183)
- ml-agents/mlagents/trainers/torch/utils.py (240)
- ml-agents/mlagents/trainers/trajectory.py (25)
- ml-agents/mlagents/trainers/tests/torch/test_action_model.py (81)
- ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (122)
- ml-agents/mlagents/trainers/torch/action_flattener.py (44)
- ml-agents/mlagents/trainers/torch/action_log_probs.py (108)
- ml-agents/mlagents/trainers/torch/action_model.py (184)
- ml-agents/mlagents/trainers/torch/agent_action.py (58)

# ml-agents/mlagents/trainers/tests/torch/test_action_model.py
import pytest

from mlagents.torch_utils import torch
from mlagents.trainers.torch.action_model import ActionModel, DistInstances
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.distributions import (
    GaussianDistInstance,
    CategoricalDistInstance,
)

from mlagents_envs.base_env import ActionSpec


def create_action_model(inp_size, act_size):
    mask = torch.ones([1, act_size * 2])
    action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
    action_model = ActionModel(inp_size, action_spec)
    return action_model, mask


def test_get_dists():
    inp_size = 4
    act_size = 2
    action_model, masks = create_action_model(inp_size, act_size)
    sample_inp = torch.ones((1, inp_size))
    dists = action_model._get_dists(sample_inp, masks=masks)
    assert isinstance(dists.continuous, GaussianDistInstance)
    assert len(dists.discrete) == 2
    for _dist in dists.discrete:
        assert isinstance(_dist, CategoricalDistInstance)


def test_sample_action():
    inp_size = 4
    act_size = 2
    action_model, masks = create_action_model(inp_size, act_size)
    sample_inp = torch.ones((1, inp_size))
    dists = action_model._get_dists(sample_inp, masks=masks)
    agent_action = action_model._sample_action(dists)
    assert agent_action.continuous_tensor.shape == (1, 2)
    assert len(agent_action.discrete_list) == 2
    for _disc in agent_action.discrete_list:
        assert _disc.shape == (1, 1)


def test_get_probs_and_entropy():
    inp_size = 4
    act_size = 2
    action_model, masks = create_action_model(inp_size, act_size)

    _continuous_dist = GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2)))
    act_size = 2
    test_prob = torch.tensor([[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)])
    _discrete_dist_list = [
        CategoricalDistInstance(test_prob),
        CategoricalDistInstance(test_prob),
    ]
    dist_tuple = DistInstances(_continuous_dist, _discrete_dist_list)

    agent_action = AgentAction(
        torch.zeros((1, 2)), [torch.tensor([0]), torch.tensor([1])]
    )

    log_probs, entropies = action_model._get_probs_and_entropy(agent_action, dist_tuple)

    assert log_probs.continuous_tensor.shape == (1, 2)
    assert len(log_probs.discrete_list) == 2
    for _disc in log_probs.discrete_list:
        assert _disc.shape == (1,)
    assert len(log_probs.all_discrete_list) == 2
    for _disc in log_probs.all_discrete_list:
        assert _disc.shape == (1, 2)

    for clp in log_probs.continuous_tensor[0]:
        # Log prob of standard normal at 0
        assert clp == pytest.approx(-0.919, abs=0.01)

    assert log_probs.discrete_list[0] > log_probs.discrete_list[1]

    for ent, val in zip(entropies[0], [1.4189, 1.4189, 0.6191, 0.6191]):
        assert ent == pytest.approx(val, abs=0.01)
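
The hard-coded constants in test_get_probs_and_entropy follow from standard results for a unit Gaussian and a two-way categorical. The snippet below is an illustrative sanity check, not part of the commit; note the 0.6191 value matches the entropy of softmax([0.9, 0.1]), which suggests CategoricalDistInstance treats its input as logits.

# Illustrative check of the expected values used above (not part of the commit).
import math

# Log density of a standard normal evaluated at its mean (the sampled action is 0):
log_prob_at_mean = -0.5 * math.log(2 * math.pi)           # ~ -0.9189
# Differential entropy of a unit Gaussian:
gaussian_entropy = 0.5 * math.log(2 * math.pi * math.e)   # ~ 1.4189
# Entropy of softmax([0.9, 0.1]), matching the expected categorical value:
z = math.exp(0.9) + math.exp(0.1)
p = [math.exp(0.9) / z, math.exp(0.1) / z]
categorical_entropy = -sum(x * math.log(x) for x in p)    # ~ 0.6191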

# ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
import attr
import pytest

from mlagents.trainers.tests.simple_test_envs import (
    SimpleEnvironment,
    MemoryEnvironment,
)

from mlagents.trainers.settings import NetworkSettings, FrameworkType

from mlagents.trainers.tests.dummy_config import ppo_dummy_config, sac_dummy_config
from mlagents.trainers.tests.check_env_trains import check_environment_trains

BRAIN_NAME = "1D"

PPO_TORCH_CONFIG = attr.evolve(ppo_dummy_config(), framework=FrameworkType.PYTORCH)
SAC_TORCH_CONFIG = attr.evolve(sac_dummy_config(), framework=FrameworkType.PYTORCH)


@pytest.mark.parametrize("action_size", [(1, 1), (2, 2), (1, 2), (2, 1)])
def test_hybrid_ppo(action_size):
    env = SimpleEnvironment([BRAIN_NAME], action_sizes=action_size, step_size=0.8)
    new_network_settings = attr.evolve(PPO_TORCH_CONFIG.network_settings)
    new_hyperparams = attr.evolve(
        PPO_TORCH_CONFIG.hyperparameters, batch_size=64, buffer_size=1024
    )
    config = attr.evolve(
        PPO_TORCH_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_network_settings,
        max_steps=10000,
    )
    check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


@pytest.mark.parametrize("num_visual", [1, 2])
def test_hybrid_visual_ppo(num_visual):
    env = SimpleEnvironment(
        [BRAIN_NAME], num_visual=num_visual, num_vector=0, action_sizes=(1, 1)
    )
    new_hyperparams = attr.evolve(
        PPO_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
    )
    config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams)
    check_environment_trains(env, {BRAIN_NAME: config})


def test_hybrid_recurrent_ppo():
    env = MemoryEnvironment([BRAIN_NAME], action_sizes=(1, 1), step_size=0.5)
    new_network_settings = attr.evolve(
        PPO_TORCH_CONFIG.network_settings,
        memory=NetworkSettings.MemorySettings(memory_size=16),
    )
    new_hyperparams = attr.evolve(
        PPO_TORCH_CONFIG.hyperparameters,
        learning_rate=1.0e-3,
        batch_size=64,
        buffer_size=512,
    )
    config = attr.evolve(
        PPO_TORCH_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_network_settings,
        max_steps=3000,
    )
    check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


@pytest.mark.parametrize("action_size", [(1, 1), (2, 2), (1, 2), (2, 1)])
def test_hybrid_sac(action_size):
    env = SimpleEnvironment([BRAIN_NAME], action_sizes=action_size, step_size=0.8)

    new_hyperparams = attr.evolve(
        SAC_TORCH_CONFIG.hyperparameters,
        buffer_size=50000,
        batch_size=256,
        buffer_init_steps=2000,
    )
    config = attr.evolve(
        SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=5000
    )
    check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)


@pytest.mark.parametrize("num_visual", [1, 2])
def test_hybrid_visual_sac(num_visual):
    env = SimpleEnvironment(
        [BRAIN_NAME], num_visual=num_visual, num_vector=0, action_sizes=(1, 1)
    )
    new_hyperparams = attr.evolve(
        SAC_TORCH_CONFIG.hyperparameters,
        buffer_size=50000,
        batch_size=128,
        learning_rate=3.0e-4,
    )
    config = attr.evolve(
        SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=3000
    )
    check_environment_trains(env, {BRAIN_NAME: config})


def test_hybrid_recurrent_sac():
    env = MemoryEnvironment([BRAIN_NAME], action_sizes=(1, 1), step_size=0.5)
    new_networksettings = attr.evolve(
        SAC_TORCH_CONFIG.network_settings,
        memory=NetworkSettings.MemorySettings(memory_size=16, sequence_length=16),
    )
    new_hyperparams = attr.evolve(
        SAC_TORCH_CONFIG.hyperparameters,
        batch_size=256,
        learning_rate=1e-3,
        buffer_init_steps=1000,
        steps_per_update=2,
    )
    config = attr.evolve(
        SAC_TORCH_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_networksettings,
        max_steps=2000,
    )
    check_environment_trains(env, {BRAIN_NAME: config})
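
Every test above derives its trainer configuration by copying the shared dummy config with attr.evolve, which returns a new attrs instance with only the named fields replaced. A minimal, self-contained sketch of that pattern follows; the Hyperparameters class here is a hypothetical stand-in, not the real TrainerSettings.

import attr


@attr.s(auto_attribs=True)
class Hyperparameters:  # hypothetical stand-in for the real hyperparameter settings
    batch_size: int = 1024
    buffer_size: int = 10240
    learning_rate: float = 3.0e-4


base = Hyperparameters()
tuned = attr.evolve(base, batch_size=64, buffer_size=1024)
assert base.batch_size == 1024                     # the original config is untouched
assert tuned.batch_size == 64                      # only the overridden fields differ
assert tuned.learning_rate == base.learning_rate   # unspecified fields are copied over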

# ml-agents/mlagents/trainers/torch/action_flattener.py
from typing import List
from mlagents.torch_utils import torch

from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.utils import ModelUtils


class ActionFlattener:
    def __init__(self, action_spec: ActionSpec):
        """
        A torch module that creates the flattened form of an AgentAction object.
        The flattened form is the continuous action concatenated with the
        concatenated one-hot encodings of the discrete actions.
        :param action_spec: An ActionSpec that describes the action space dimensions
        """
        self._specs = action_spec

    @property
    def flattened_size(self) -> int:
        """
        The flattened size is the continuous size plus the sum of the branch sizes
        since discrete actions are encoded as one-hots.
        """
        return self._specs.continuous_size + sum(self._specs.discrete_branches)

    def forward(self, action: AgentAction) -> torch.Tensor:
        """
        Returns a tensor corresponding to the flattened action.
        :param action: An AgentAction object
        """
        action_list: List[torch.Tensor] = []
        if self._specs.continuous_size > 0:
            action_list.append(action.continuous_tensor)
        if self._specs.discrete_size > 0:
            flat_discrete = torch.cat(
                ModelUtils.actions_to_onehot(
                    torch.as_tensor(action.discrete_tensor, dtype=torch.long),
                    self._specs.discrete_branches,
                ),
                dim=1,
            )
            action_list.append(flat_discrete)
        return torch.cat(action_list, dim=1)
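
As a worked example of the flattening described in the docstrings: for a hybrid spec with 2 continuous dimensions and discrete branches of size 3 and 2, the flattened vector has 2 + 3 + 2 = 7 entries. The sketch below is illustrative only and assumes the modules added in this commit are importable.

from mlagents.torch_utils import torch
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_flattener import ActionFlattener

spec = ActionSpec(2, (3, 2))                  # 2 continuous dims, branches of size 3 and 2
flattener = ActionFlattener(spec)
assert flattener.flattened_size == 2 + 3 + 2  # continuous size + one-hot widths

action = AgentAction(
    torch.zeros((1, 2)),                      # batch of 1 agent
    [torch.tensor([1]), torch.tensor([0])],   # chosen index per discrete branch
)
flat = flattener.forward(action)
assert flat.shape == (1, 7)                   # [c0, c1, one-hot(3), one-hot(2)]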

# ml-agents/mlagents/trainers/torch/action_log_probs.py
from typing import List, Optional, NamedTuple, Dict
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import _ActionTupleBase


class LogProbsTuple(_ActionTupleBase):
    """
    An object whose fields correspond to the log probs of actions of different types.
    Continuous and discrete are numpy arrays.
    Dimensions are (n_agents, continuous_size) and (n_agents, discrete_size),
    respectively. Note that this also holds when continuous or discrete size is
    zero.
    """

    @property
    def discrete_dtype(self) -> np.dtype:
        """
        The dtype of a discrete log probability.
        """
        return np.float32


class ActionLogProbs(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous log probs and list of tensors for
    discrete log probs of individual actions as well as all the log probs for an entire branch.
    Utility functions provide numpy <=> tensor conversions to be used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to log probs of continuous actions
    :param discrete_list: List of Torch tensors each corresponding to log probs of the discrete actions that were
    sampled.
    :param all_discrete_list: List of Torch tensors, one per discrete branch, each containing the log probs of
    every action in that branch, including the discrete actions that were not sampled.
    """

    continuous_tensor: torch.Tensor
    discrete_list: Optional[List[torch.Tensor]]
    all_discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self):
        """
        Returns the discrete log probs list as a stacked tensor
        """
        return torch.stack(self.discrete_list, dim=-1)

    @property
    def all_discrete_tensor(self):
        """
        Returns the discrete log probs of each branch as a tensor
        """
        return torch.cat(self.all_discrete_list, dim=1)

    def to_log_probs_tuple(self) -> LogProbsTuple:
        """
        Returns a LogProbsTuple. Only adds if tensor is not None. Otherwise,
        LogProbsTuple uses a default.
        """
        log_probs_tuple = LogProbsTuple()
        if self.continuous_tensor is not None:
            continuous = ModelUtils.to_numpy(self.continuous_tensor)
            log_probs_tuple.add_continuous(continuous)
        if self.discrete_list is not None:
            discrete = ModelUtils.to_numpy(self.discrete_tensor)
            log_probs_tuple.add_discrete(discrete)
        return log_probs_tuple

    def _to_tensor_list(self) -> List[torch.Tensor]:
        """
        Returns the tensors in the ActionLogProbs as a flat List of torch Tensors. This
        is private and serves as a utility for self.flatten()
        """
        tensor_list: List[torch.Tensor] = []
        if self.continuous_tensor is not None:
            tensor_list.append(self.continuous_tensor)
        if self.discrete_list is not None:
            tensor_list.append(self.discrete_tensor)
        return tensor_list

    def flatten(self) -> torch.Tensor:
        """
        A utility method that returns all log probs in ActionLogProbs as a flattened tensor.
        This is useful for algorithms like PPO which can treat all log probs in the same way.
        """
        return torch.cat(self._to_tensor_list(), dim=1)

    @staticmethod
    def from_dict(buff: Dict[str, np.ndarray]) -> "ActionLogProbs":
        """
        A static method that accesses continuous and discrete log probs fields in an AgentBuffer
        and constructs the corresponding ActionLogProbs from the retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore

        if "continuous_log_probs" in buff:
            continuous = ModelUtils.list_to_tensor(buff["continuous_log_probs"])
        if "discrete_log_probs" in buff:
            discrete_tensor = ModelUtils.list_to_tensor(buff["discrete_log_probs"])
            # This will keep discrete_list = None which enables flatten()
            if discrete_tensor.shape[1] > 0:
                discrete = [
                    discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
                ]
        return ActionLogProbs(continuous, discrete, None)
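
To make the buffer round-trip concrete: from_dict rebuilds an ActionLogProbs from the numpy arrays stored under the continuous_log_probs and discrete_log_probs keys, and flatten() then concatenates everything into a single (n_agents, continuous_size + num_branches) tensor for PPO-style losses. The sketch below is illustrative only and assumes this commit's modules are importable.

import numpy as np
from mlagents.trainers.torch.action_log_probs import ActionLogProbs

buff = {
    "continuous_log_probs": np.zeros((5, 2), dtype=np.float32),  # 5 agents, 2 continuous dims
    "discrete_log_probs": np.zeros((5, 3), dtype=np.float32),    # 5 agents, 3 discrete branches
}
log_probs = ActionLogProbs.from_dict(buff)
assert len(log_probs.discrete_list) == 3        # one tensor per branch, each of shape (5,)
assert log_probs.flatten().shape == (5, 2 + 3)  # continuous dims + one log prob per branch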

# ml-agents/mlagents/trainers/torch/action_model.py
from typing import List, Tuple, NamedTuple, Optional
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents_envs.base_env import ActionSpec

EPSILON = 1e-7  # Small value to avoid divide by zero


class DistInstances(NamedTuple):
    """
    A NamedTuple with fields corresponding to the DistInstance objects
    output by continuous and discrete distributions, respectively. Discrete distributions
    output a list of DistInstance objects whereas continuous distributions output a single
    DistInstance object.
    """

    continuous: Optional[DistInstance]
    discrete: Optional[List[DiscreteDistInstance]]


class ActionModel(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """
        A torch module that represents the action space of a policy. The ActionModel may contain
        a continuous distribution, a discrete distribution or both where construction depends on
        the action_spec. The ActionModel uses the encoded input of the network body to parameterize
        these distributions. The forward method of this module outputs the action, log probs,
        and entropies given the encoding from the network body.
        :params hidden_size: Size of the input to the ActionModel.
        :params action_spec: The ActionSpec defining the action space dimensions and distributions.
        :params conditional_sigma: Whether or not the std of a Gaussian is conditioned on state.
        :params tanh_squash: Whether to squash the output of a Gaussian with the tanh function.
        """
        super().__init__()
        self.encoding_size = hidden_size
        self.action_spec = action_spec
        self._continuous_distribution = None
        self._discrete_distribution = None

        if self.action_spec.continuous_size > 0:
            self._continuous_distribution = GaussianDistribution(
                self.encoding_size,
                self.action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )

        if self.action_spec.discrete_size > 0:
            self._discrete_distribution = MultiCategoricalDistribution(
                self.encoding_size, self.action_spec.discrete_branches
            )

    def _sample_action(self, dists: DistInstances) -> AgentAction:
        """
        Samples actions from a DistInstances tuple
        :params dists: The DistInstances tuple
        :return: An AgentAction corresponding to the actions sampled from the DistInstances
        """
        continuous_action: Optional[torch.Tensor] = None
        discrete_action: Optional[List[torch.Tensor]] = None
        # This checks None because mypy complains otherwise
        if dists.continuous is not None:
            continuous_action = dists.continuous.sample()
        if dists.discrete is not None:
            discrete_action = []
            for discrete_dist in dists.discrete:
                discrete_action.append(discrete_dist.sample())
        return AgentAction(continuous_action, discrete_action)

    def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> DistInstances:
        """
        Creates a DistInstances tuple using the continuous and discrete distributions
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :return: A DistInstances tuple
        """
        continuous_dist: Optional[DistInstance] = None
        discrete_dist: Optional[List[DiscreteDistInstance]] = None
        # This checks None because mypy complains otherwise
        if self._continuous_distribution is not None:
            continuous_dist = self._continuous_distribution(inputs)
        if self._discrete_distribution is not None:
            discrete_dist = self._discrete_distribution(inputs, masks)
        return DistInstances(continuous_dist, discrete_dist)

    def _get_probs_and_entropy(
        self, actions: AgentAction, dists: DistInstances
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Computes the log probabilities of the actions given the distributions and the entropies of
        the given distributions.
        :params actions: The AgentAction
        :params dists: The DistInstances tuple
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
        """
        entropies_list: List[torch.Tensor] = []
        continuous_log_prob: Optional[torch.Tensor] = None
        discrete_log_probs: Optional[List[torch.Tensor]] = None
        all_discrete_log_probs: Optional[List[torch.Tensor]] = None
        # This checks None because mypy complains otherwise
        if dists.continuous is not None:
            continuous_log_prob = dists.continuous.log_prob(actions.continuous_tensor)
            entropies_list.append(dists.continuous.entropy())
        if dists.discrete is not None:
            discrete_log_probs = []
            all_discrete_log_probs = []
            for discrete_action, discrete_dist in zip(
                actions.discrete_list, dists.discrete  # type: ignore
            ):
                discrete_log_prob = discrete_dist.log_prob(discrete_action)
                entropies_list.append(discrete_dist.entropy())
                discrete_log_probs.append(discrete_log_prob)
                all_discrete_log_probs.append(discrete_dist.all_log_prob())
        action_log_probs = ActionLogProbs(
            continuous_log_prob, discrete_log_probs, all_discrete_log_probs
        )
        entropies = torch.cat(entropies_list, dim=1)
        return action_log_probs, entropies

    def evaluate(
        self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Given actions and the encoding from the network body, gets the distributions and
        computes the log probabilities and entropies.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :params actions: The AgentAction
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
        """
        dists = self._get_dists(inputs, masks)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return log_probs, entropy_sum

    def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Gets the tensors corresponding to the output of the policy network to be used for
        inference. Called by the Actor's forward call.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :return: A torch tensor of the concatenated inference outputs
        """
        dists = self._get_dists(inputs, masks)
        out_list: List[torch.Tensor] = []
        # This checks None because mypy complains otherwise
        if dists.continuous is not None:
            out_list.append(dists.continuous.exported_model_output())
        if dists.discrete is not None:
            for discrete_dist in dists.discrete:
                out_list.append(discrete_dist.exported_model_output())
        return torch.cat(out_list, dim=1)

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
        """
        The forward method of this module. Outputs the action, log probs,
        and entropies given the encoding from the network body.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :return: Given the input, an AgentAction of the actions generated by the policy and the corresponding
        ActionLogProbs and entropies.
        """
        dists = self._get_dists(inputs, masks)
        actions = self._sample_action(dists)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return (actions, log_probs, entropy_sum)
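
A minimal end-to-end sketch of the forward pass for a hybrid action space, tying the pieces above together. It is illustrative only; the expected shapes follow the tests earlier in this commit.

from mlagents.torch_utils import torch
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.action_model import ActionModel

spec = ActionSpec(2, (3, 2))            # 2 continuous dims, discrete branches of size 3 and 2
model = ActionModel(hidden_size=16, action_spec=spec)

encoding = torch.ones((1, 16))          # stand-in for the network body's output
masks = torch.ones((1, 3 + 2))          # one flag per discrete logit; all actions allowed
actions, log_probs, entropy_sum = model.forward(encoding, masks)

assert actions.continuous_tensor.shape == (1, 2)
assert len(actions.discrete_list) == 2  # one sampled index tensor per branch
assert entropy_sum.shape == (1,)        # entropies are summed across action dimensions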

# ml-agents/mlagents/trainers/torch/agent_action.py
from typing import List, Optional, NamedTuple, Dict
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import ActionTuple


class AgentAction(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous actions and list of tensors for
    discrete actions. Utility functions provide numpy <=> tensor conversions to be
    sent as actions to the environment manager as well as used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to continuous actions
    :param discrete_list: List of Torch tensors each corresponding to discrete actions
    """

    continuous_tensor: torch.Tensor
    discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self):
        """
        Returns the discrete action list as a stacked tensor
        """
        return torch.stack(self.discrete_list, dim=-1)

    def to_action_tuple(self) -> ActionTuple:
        """
        Returns an ActionTuple
        """
        action_tuple = ActionTuple()
        if self.continuous_tensor is not None:
            continuous = ModelUtils.to_numpy(self.continuous_tensor)
            action_tuple.add_continuous(continuous)
        if self.discrete_list is not None:
            discrete = ModelUtils.to_numpy(self.discrete_tensor[:, 0, :])
            action_tuple.add_discrete(discrete)
        return action_tuple

    @staticmethod
    def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction":
        """
        A static method that accesses continuous and discrete action fields in an AgentBuffer
        and constructs the corresponding AgentAction from the retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore
        if "continuous_action" in buff:
            continuous = ModelUtils.list_to_tensor(buff["continuous_action"])
        if "discrete_action" in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff["discrete_action"], dtype=torch.long
            )
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return AgentAction(continuous, discrete)
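
Finally, to_action_tuple converts the sampled torch tensors into the numpy ActionTuple handed to the environment. Each entry of discrete_list is expected to have shape (n_agents, 1), as produced by the sampled distributions, which is why discrete_tensor is indexed with [:, 0, :]. The sketch below is illustrative only and assumes the ActionTuple continuous/discrete properties from mlagents_envs.base_env.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.agent_action import AgentAction

action = AgentAction(
    torch.zeros((3, 2)),                        # 3 agents, 2 continuous dims
    [
        torch.zeros((3, 1), dtype=torch.long),  # branch 0: sampled index per agent
        torch.ones((3, 1), dtype=torch.long),   # branch 1: sampled index per agent
    ],
)
action_tuple = action.to_action_tuple()
assert action_tuple.continuous.shape == (3, 2)
assert action_tuple.discrete.shape == (3, 2)    # one chosen index per agent per branch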