
[refactor] Refactor Actor and Critic classes (#4287)

/develop/add-fire
GitHub · 4 years ago
Current commit: 69579611
11 files changed, with 616 insertions and 142 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (31 changes)
  3. ml-agents/mlagents/trainers/ppo/trainer.py (1 change)
  4. ml-agents/mlagents/trainers/tests/test_reward_signals.py (4 changes)
  5. ml-agents/mlagents/trainers/torch/decoders.py (11 changes)
  6. ml-agents/mlagents/trainers/torch/distributions.py (39 changes)
  7. ml-agents/mlagents/trainers/torch/encoders.py (26 changes)
  8. ml-agents/mlagents/trainers/torch/networks.py (389 changes)
  9. ml-agents/mlagents/trainers/torch/utils.py (43 changes)
  10. ml-agents/mlagents/trainers/tests/torch/test_networks.py (208 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changes)


"""
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
value_estimates, mean_value = self.policy.actor_critic.critic_pass(
value_estimates = self.policy.actor_critic.critic_pass(
np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
)

next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
next_memory = torch.zeros([1, 1, self.policy.m_size])
value_estimates, mean_value = self.policy.actor_critic.critic_pass(
value_estimates = self.policy.actor_critic.critic_pass(
next_value_estimate, next_value = self.policy.actor_critic.critic_pass(
next_value_estimate = self.policy.actor_critic.critic_pass(
next_obs, next_obs, next_memory
)
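Both call sites track the same API change: critic_pass now returns only the dict of per-stream value estimates. A minimal sketch of the new pattern (stream names and values here are made up, not from the diff):

import torch

# Old: value_estimates, mean_value = self.policy.actor_critic.critic_pass(...)
# New: only the dict comes back; a mean across reward streams, if needed,
# is computed by the caller.
value_estimates = {"extrinsic": torch.tensor([0.8]), "curiosity": torch.tensor([0.4])}
mean_value = torch.mean(torch.stack(list(value_estimates.values())), dim=0)
print(mean_value)  # tensor([0.6000])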

ml-agents/mlagents/trainers/policy/torch_policy.py (31 changes)


-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 import numpy as np
 import torch

 from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
 from mlagents.trainers.trajectory import SplitObservations
-from mlagents.trainers.torch.networks import ActorCritic
+from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic
 from mlagents.trainers.torch.utils import ModelUtils

 EPSILON = 1e-7  # Small value to avoid divide by zero

     load: bool = False,
     tanh_squash: bool = False,
     reparameterize: bool = False,
-    separate_critic: Optional[bool] = None,
+    separate_critic: bool = True,
 ):
     """
     Policy that uses a multilayer perceptron to map the observations to actions. Could

     "Losses/Value Loss": "value_loss",
     "Losses/Policy Loss": "policy_loss",
 }
-        self.actor_critic = ActorCritic(
+        if separate_critic:
+            ac_class = SeparateActorCritic
+        else:
+            ac_class = SharedActorCritic
+        self.actor_critic = ac_class(
-            separate_critic=separate_critic
-            if separate_critic is not None
-            else self.use_continuous_act,
             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )

"""
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
"""
(
dists,
(value_heads, mean_value),
memories,
) = self.actor_critic.get_dist_and_value(
dists, value_heads, memories = self.actor_critic.get_dist_and_value(
log_probs, entropies, all_logs = self.actor_critic.get_probs_and_entropy(
log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
action_list, dists
)
actions = torch.stack(action_list, dim=-1)

     def evaluate_actions(
         self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
     ):
-        dists, (value_heads, mean_value), _ = self.actor_critic.get_dist_and_value(
+        dists, value_heads, _ = self.actor_critic.get_dist_and_value(
-        log_probs, entropies, _ = self.actor_critic.get_probs_and_entropy(
-            action_list, dists
-        )
+        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
         return log_probs, entropies, value_heads
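For reference, a hedged sketch of constructing the new classes directly (the constructor parameter order matches the networks.py hunk below; the shapes, settings, and stream names are placeholders):

def build_actor_critic(
    separate_critic, observation_shapes, network_settings, act_type, act_size, stream_names
):
    # Mirrors the policy's new selection logic: two network bodies when the
    # critic is separate, one shared body plus value heads otherwise.
    ac_class = SeparateActorCritic if separate_critic else SharedActorCritic
    return ac_class(
        observation_shapes,
        network_settings,
        act_type,
        act_size,
        stream_names,
        conditional_sigma=False,
        tanh_squash=False,
    )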

ml-agents/mlagents/trainers/ppo/trainer.py (1 change)


             self.artifact_path,
             self.load,
             condition_sigma_on_obs=False,  # Faster training for PPO
+            separate_critic=behavior_spec.is_action_continuous(),
         )
         return policy
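A toy illustration of what the added flag selects (the class names come from networks.py below; FakeSpec is a stand-in, not the real BehaviorSpec):

class FakeSpec:
    def __init__(self, continuous):
        self._continuous = continuous

    def is_action_continuous(self):
        return self._continuous

for spec in (FakeSpec(True), FakeSpec(False)):
    ac_name = "SeparateActorCritic" if spec.is_action_continuous() else "SharedActorCritic"
    print(ac_name)  # continuous -> separate critic, discrete -> shared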

ml-agents/mlagents/trainers/tests/test_reward_signals.py (4 changes)


 import mlagents.trainers.tests.mock_brain as mb
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.sac.optimizer import SACOptimizer
-from mlagents.trainers.ppo.optimizer import PPOOptimizer
+from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
 from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
 from mlagents.trainers.settings import (
     GAILSettings,

     if trainer_settings.trainer_type == TrainerType.SAC:
         optimizer = SACOptimizer(policy, trainer_settings)
     else:
-        optimizer = PPOOptimizer(policy, trainer_settings)
+        optimizer = TFPPOOptimizer(policy, trainer_settings)
     return optimizer

ml-agents/mlagents/trainers/torch/decoders.py (11 changes)


+from typing import List, Dict

-    def __init__(self, stream_names, input_size, output_size=1):
+    def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1):
         super().__init__()
         self.stream_names = stream_names
         _value_heads = {}

             _value_heads[name] = value
         self.value_heads = nn.ModuleDict(_value_heads)

-    def forward(self, hidden):
+    def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
-        return (
-            value_outputs,
-            torch.mean(torch.stack(list(value_outputs.values())), dim=0),
-        )
+        return value_outputs
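A minimal sketch of the refactored contract (TinyValueHeads is an illustrative stand-in, not the real class): one linear head per reward stream, and forward returns only the dict; the old second return value, the mean over streams, is gone.

import torch
from torch import nn

class TinyValueHeads(nn.Module):  # illustrative stand-in
    def __init__(self, stream_names, input_size, output_size=1):
        super().__init__()
        self.value_heads = nn.ModuleDict(
            {name: nn.Linear(input_size, output_size) for name in stream_names}
        )

    def forward(self, hidden):
        return {name: head(hidden) for name, head in self.value_heads.items()}

heads = TinyValueHeads(["extrinsic", "gail"], input_size=8)
outputs = heads(torch.ones(1, 8))
print({name: v.shape for name, v in outputs.items()})  # both torch.Size([1, 1])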

ml-agents/mlagents/trainers/torch/distributions.py (39 changes)


+import abc
 import torch
 from torch import nn
 import numpy as np

-class GaussianDistInstance(nn.Module):
+class DistInstance(nn.Module, abc.ABC):
+    @abc.abstractmethod
+    def sample(self) -> torch.Tensor:
+        """
+        Return a sample from this distribution.
+        """
+        pass
+
+    @abc.abstractmethod
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        """
+        Returns the log probabilities of a particular value.
+        :param value: A value sampled from the distribution.
+        :returns: Log probabilities of the given value.
+        """
+        pass
+
+    @abc.abstractmethod
+    def entropy(self) -> torch.Tensor:
+        """
+        Returns the entropy of this distribution.
+        """
+        pass
+
+
+class DiscreteDistInstance(DistInstance):
+    @abc.abstractmethod
+    def all_log_prob(self) -> torch.Tensor:
+        """
+        Returns the log probabilities of all actions represented by this distribution.
+        """
+        pass
+
+
+class GaussianDistInstance(DistInstance):
     def __init__(self, mean, std):
         super().__init__()
         self.mean = mean

 )
-class CategoricalDistInstance(nn.Module):
+class CategoricalDistInstance(DiscreteDistInstance):
     def __init__(self, logits):
         super().__init__()
         self.logits = logits
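The new base classes make the distribution contract explicit: anything that implements sample/log_prob/entropy can flow through the trainers. A self-contained sketch under that assumption (BernoulliDistInstance is purely illustrative, not part of ml-agents):

import abc
import torch
from torch import nn

class DistInstance(nn.Module, abc.ABC):  # same contract as above
    @abc.abstractmethod
    def sample(self) -> torch.Tensor: ...
    @abc.abstractmethod
    def log_prob(self, value: torch.Tensor) -> torch.Tensor: ...
    @abc.abstractmethod
    def entropy(self) -> torch.Tensor: ...

class BernoulliDistInstance(DistInstance):  # hypothetical example subclass
    def __init__(self, probs: torch.Tensor):
        super().__init__()
        self.dist = torch.distributions.Bernoulli(probs=probs)

    def sample(self) -> torch.Tensor:
        return self.dist.sample()

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return self.dist.log_prob(value)

    def entropy(self) -> torch.Tensor:
        return self.dist.entropy()

d = BernoulliDistInstance(torch.tensor([0.3, 0.7]))
action = d.sample()
print(d.log_prob(action), d.entropy())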

ml-agents/mlagents/trainers/torch/encoders.py (26 changes)


         for _ in range(num_layers - 1):
             self.layers.append(nn.Linear(hidden_size, hidden_size))
-            self.layers.append(nn.ReLU())
+            self.layers.append(nn.LeakyReLU())
         self.seq_layers = nn.Sequential(*self.layers)

     def forward(self, inputs: torch.Tensor) -> None:

         self.dense = nn.Linear(self.final_flat, self.h_size)

     def forward(self, visual_obs: torch.Tensor) -> None:
-        conv_1 = torch.relu(self.conv1(visual_obs))
-        conv_2 = torch.relu(self.conv2(conv_1))
+        conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs))
+        conv_2 = nn.functional.leaky_relu(self.conv2(conv_1))
-        hidden = torch.relu(self.dense(torch.reshape(conv_2, (-1, self.final_flat))))
+        hidden = nn.functional.leaky_relu(
+            self.dense(torch.reshape(conv_2, (-1, self.final_flat)))
+        )
         return hidden

         self.dense = nn.Linear(self.final_flat, self.h_size)

     def forward(self, visual_obs):
-        conv_1 = torch.relu(self.conv1(visual_obs))
-        conv_2 = torch.relu(self.conv2(conv_1))
-        conv_3 = torch.relu(self.conv3(conv_2))
-        hidden = torch.relu(self.dense(conv_3.view([-1, self.final_flat])))
+        conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs))
+        conv_2 = nn.functional.leaky_relu(self.conv2(conv_1))
+        conv_3 = nn.functional.leaky_relu(self.conv3(conv_2))
+        hidden = nn.functional.leaky_relu(
+            self.dense(conv_3.view([-1, self.final_flat]))
+        )
         return hidden

         for _ in range(n_blocks):
             self.layers.append(self.make_block(channel))
             last_channel = channel
-        self.layers.append(nn.ReLU())
+        self.layers.append(nn.LeakyReLU())

-            nn.ReLU(),
+            nn.LeakyReLU(),
-            nn.ReLU(),
+            nn.LeakyReLU(),
             nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
         ]
         return block_layers
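All of these swaps are the same one-line change. LeakyReLU keeps a small gradient (slope 0.01 by default) for negative inputs instead of zeroing them, which helps avoid dead units; a quick check:

import torch
from torch import nn

x = torch.tensor([-2.0, 0.5])
print(torch.relu(x))                # tensor([0.0000, 0.5000])
print(nn.functional.leaky_relu(x))  # tensor([-0.0200, 0.5000])
print(nn.LeakyReLU()(x))            # module form, same values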

ml-agents/mlagents/trainers/torch/networks.py (389 changes)


 from typing import Callable, List, Dict, Tuple, Optional
+import attr
+import abc
 import torch
 from torch import nn

     GaussianDistribution,
     MultiCategoricalDistribution,
+    DistInstance,
 )
 from mlagents.trainers.settings import NetworkSettings
 from mlagents.trainers.torch.utils import ModelUtils

         else:
             self.lstm = None

-    def update_normalization(self, vec_inputs):
+    def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
         for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
             vec_enc.update_normalization(vec_input)

     def forward(
         self,
-        vec_inputs: torch.Tensor,
-        vis_inputs: torch.Tensor,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],

-        vec_embeds = []
+        vec_encodes = []
         for idx, encoder in enumerate(self.vector_encoders):
             vec_input = vec_inputs[idx]
             if actions is not None:

-                vec_embeds.append(hidden)
+                vec_encodes.append(hidden)

-        vis_embeds = []
+        vis_encodes = []
-            vis_embeds.append(hidden)
+            vis_encodes.append(hidden)

-        # embedding = vec_embeds[0]
-        if len(vec_embeds) > 0 and len(vis_embeds) > 0:
-            vec_embeds_tensor = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
-            vis_embeds_tensor = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
-            embedding = torch.stack([vec_embeds_tensor, vis_embeds_tensor], dim=-1).sum(
-                dim=-1
-            )
-        elif len(vec_embeds) > 0:
-            embedding = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
-        elif len(vis_embeds) > 0:
-            embedding = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
+        if len(vec_encodes) > 0 and len(vis_encodes) > 0:
+            vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
+            vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
+            encoding = torch.stack(
+                [vec_encodes_tensor, vis_encodes_tensor], dim=-1
+            ).sum(dim=-1)
+        elif len(vec_encodes) > 0:
+            encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
+        elif len(vis_encodes) > 0:
+            encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1)

-            encoding = embedding.view([sequence_length, -1, self.h_size])
-            embedding, memories = self.lstm(
-                embedding.contiguous(),
+            encoding = encoding.view([sequence_length, -1, self.h_size])
+            encoding, memories = self.lstm(
+                encoding.contiguous(),

-        embedding = embedding.view([-1, self.m_size // 2])
-        return embedding, memories
+        encoding = encoding.view([-1, self.m_size // 2])
+        return encoding, memories
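The rename from "embed" to "encode" is cosmetic, but the combination logic is worth seeing in isolation. A toy sketch (sizes made up) of how NetworkBody merges per-encoder outputs:

import torch

hidden = 4
vec_encodes = [torch.ones(2, hidden), torch.ones(2, hidden)]  # two vector encoders
vis_encodes = [torch.ones(2, hidden)]                         # one visual encoder

vec_sum = torch.stack(vec_encodes, dim=-1).sum(dim=-1)        # shape (2, 4)
vis_sum = torch.stack(vis_encodes, dim=-1).sum(dim=-1)        # shape (2, 4)
encoding = torch.stack([vec_sum, vis_sum], dim=-1).sum(dim=-1)
print(encoding)  # every entry is 3.0: encoder outputs are summed, not concatenated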
 class ValueNetwork(nn.Module):

         self.network_body = NetworkBody(
             observation_shapes, network_settings, encoded_act_size=encoded_act_size
         )
-        self.value_heads = ValueHeads(
-            stream_names, network_settings.hidden_units, outputs_per_stream
-        )
+        if network_settings.memory is not None:
+            encoding_size = network_settings.memory.memory_size // 2
+        else:
+            encoding_size = network_settings.hidden_units
+        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

     def forward(
         self,

         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-        embedding, memories = self.network_body(
+        encoding, memories = self.network_body(
-        output, _ = self.value_heads(embedding)
+        output = self.value_heads(encoding)
-class ActorCritic(nn.Module):
+class Actor(abc.ABC):
+    @abc.abstractmethod
+    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
+        """
+        Updates normalization of Actor based on the provided List of vector obs.
+        :param vector_obs: A List of vector obs as tensors.
+        """
+        pass
+
+    @abc.abstractmethod
+    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
+        """
+        Takes a List of Distribution instances and samples an action from each.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_dists(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
+        """
+        Returns distributions from this Actor, from which actions can be sampled.
+        If memory is enabled, return the memories as well.
+        :param vec_inputs: A List of vector inputs as tensors.
+        :param vis_inputs: A List of visual inputs as tensors.
+        :param masks: If using discrete actions, a Tensor of action masks.
+        :param memories: If using memory, a Tensor of initial memories.
+        :param sequence_length: If using memory, the sequence length.
+        :return: A Tuple of a List of action distribution instances, and memories.
+            Memories will be None if not using memory.
+        """
+        pass
+
+    @abc.abstractmethod
+    def forward(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+        """
+        Forward pass of the Actor for inference. This is required for export to ONNX, and
+        the inputs and outputs of this method should not be changed without a respective change
+        in the ONNX export code.
+        """
+        pass
+
+
+class ActorCritic(Actor):
+    @abc.abstractmethod
+    def critic_pass(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Get value outputs for the given obs.
+        :param vec_inputs: List of vector inputs as tensors.
+        :param vis_inputs: List of visual inputs as tensors.
+        :param memories: Tensor of memories, if using memory. Otherwise, None.
+        :returns: Dict of reward stream to output tensor for values.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
+        """
+        Returns distributions, from which actions can be sampled, and value estimates.
+        If memory is enabled, return the memories as well.
+        :param vec_inputs: A List of vector inputs as tensors.
+        :param vis_inputs: A List of visual inputs as tensors.
+        :param masks: If using discrete actions, a Tensor of action masks.
+        :param memories: If using memory, a Tensor of initial memories.
+        :param sequence_length: If using memory, the sequence length.
+        :return: A Tuple of a List of action distribution instances, a Dict of reward signal
+            name to value estimate, and memories. Memories will be None if not using memory.
+        """
+        pass
+class SimpleActor(nn.Module, Actor):
     def __init__(
         self,
         observation_shapes: List[Tuple[int, ...]],

-        stream_names: List[str],
-        separate_critic: bool,
         conditional_sigma: bool = False,
         tanh_squash: bool = False,
     ):

         self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
         self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
-        self.is_continuous_int = torch.nn.Parameter(torch.Tensor([1]))
+        self.is_continuous_int = torch.nn.Parameter(
+            torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
+        )
-        self.separate_critic = separate_critic
         if network_settings.memory is not None:
-            embedding_size = network_settings.memory.memory_size // 2
+            self.encoding_size = network_settings.memory.memory_size // 2
         else:
-            embedding_size = network_settings.hidden_units
+            self.encoding_size = network_settings.hidden_units

             self.distribution = GaussianDistribution(
-                embedding_size,
+                self.encoding_size,
         else:
-            self.distribution = MultiCategoricalDistribution(embedding_size, act_size)
-        if separate_critic:
-            self.critic = ValueNetwork(
-                stream_names, observation_shapes, network_settings
-            )
-        else:
-            self.stream_names = stream_names
-            self.value_heads = ValueHeads(stream_names, embedding_size)
+            self.distribution = MultiCategoricalDistribution(
+                self.encoding_size, act_size
+            )

-    def update_normalization(self, vector_obs):
+    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
-        if self.separate_critic:
-            self.critic.network_body.update_normalization(vector_obs)

-    def critic_pass(self, vec_inputs, vis_inputs, memories=None):
-        if self.separate_critic:
-            return self.critic(vec_inputs, vis_inputs)
-        else:
-            embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
-            return self.value_heads(embedding)

-    def sample_action(self, dists):
+    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
         actions = []
         for action_dist in dists:
             action = action_dist.sample()

-    def get_probs_and_entropy(self, action_list, dists):
-        log_probs = []
-        all_probs = []
-        entropies = []
-        for action, action_dist in zip(action_list, dists):
-            log_prob = action_dist.log_prob(action)
-            log_probs.append(log_prob)
-            entropies.append(action_dist.entropy())
-            if self.act_type == ActionType.DISCRETE:
-                all_probs.append(action_dist.all_log_prob())
-        log_probs = torch.stack(log_probs, dim=-1)
-        entropies = torch.stack(entropies, dim=-1)
-        if self.act_type == ActionType.CONTINUOUS:
-            log_probs = log_probs.squeeze(-1)
-            entropies = entropies.squeeze(-1)
-            all_probs = None
-        else:
-            all_probs = torch.cat(all_probs, dim=-1)
-        return log_probs, entropies, all_probs
-    def get_dist_and_value(
-        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
-    ):
-        embedding, memories = self.network_body(
+    def get_dists(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
+        encoding, memories = self.network_body(
             vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
         )
         if self.act_type == ActionType.CONTINUOUS:
-            dists = self.distribution(embedding)
+            dists = self.distribution(encoding)
         else:
-            dists = self.distribution(embedding, masks=masks)
-        if self.separate_critic:
-            value_outputs = self.critic(vec_inputs, vis_inputs)
-        else:
-            value_outputs = self.value_heads(embedding)
-        return dists, value_outputs, memories
+            dists = self.distribution(encoding, masks)
+        return dists, memories

     def forward(
-        self, vec_inputs, vis_inputs=None, masks=None, memories=None, sequence_length=1
-    ):
-        embedding, memories = self.network_body(
-            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
-        )
-        dists, value_outputs, memories = self.get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+        """
+        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
+        """
+        dists, _ = self.get_dists(
+            vec_inputs, vis_inputs, masks, memories, sequence_length
+        )
         action_list = self.sample_action(dists)

             self.is_continuous_int,
             self.act_size_vector,
         )
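The fixed forward() signature exists so the actor can be traced for export. A hedged sketch of what that enables (dummy shapes assumed; this is not the project's actual serialization code):

import torch

# `actor` is a SimpleActor built for a single vector obs of size 8.
dummy_vec_obs = [torch.zeros(1, 8)]
torch.onnx.export(
    actor,
    (dummy_vec_obs, []),  # vec_inputs, vis_inputs; masks/memories keep their defaults
    "actor.onnx",
    opset_version=9,
)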
+class SharedActorCritic(SimpleActor, ActorCritic):
+    def __init__(
+        self,
+        observation_shapes: List[Tuple[int, ...]],
+        network_settings: NetworkSettings,
+        act_type: ActionType,
+        act_size: List[int],
+        stream_names: List[str],
+        conditional_sigma: bool = False,
+        tanh_squash: bool = False,
+    ):
+        super().__init__(
+            observation_shapes,
+            network_settings,
+            act_type,
+            act_size,
+            conditional_sigma,
+            tanh_squash,
+        )
+        self.stream_names = stream_names
+        self.value_heads = ValueHeads(stream_names, self.encoding_size)
+
+    def critic_pass(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        encoding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
+        return self.value_heads(encoding)
+
+    def get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
+        encoding, memories = self.network_body(
+            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
+        )
+        if self.act_type == ActionType.CONTINUOUS:
+            dists = self.distribution(encoding)
+        else:
+            dists = self.distribution(encoding, masks=masks)
+        value_outputs = self.value_heads(encoding)
+        return dists, value_outputs, memories
+class SeparateActorCritic(SimpleActor, ActorCritic):
+    def __init__(
+        self,
+        observation_shapes: List[Tuple[int, ...]],
+        network_settings: NetworkSettings,
+        act_type: ActionType,
+        act_size: List[int],
+        stream_names: List[str],
+        conditional_sigma: bool = False,
+        tanh_squash: bool = False,
+    ):
+        # Give the Actor only half the memories. Note we previously validate
+        # that memory_size must be a multiple of 4.
+        self.use_lstm = network_settings.memory is not None
+        if network_settings.memory is not None:
+            self.half_mem_size = network_settings.memory.memory_size // 2
+            new_memory_settings = attr.evolve(
+                network_settings.memory, memory_size=self.half_mem_size
+            )
+            use_network_settings = attr.evolve(
+                network_settings, memory=new_memory_settings
+            )
+        else:
+            use_network_settings = network_settings
+            self.half_mem_size = 0
+        super().__init__(
+            observation_shapes,
+            use_network_settings,
+            act_type,
+            act_size,
+            conditional_sigma,
+            tanh_squash,
+        )
+        self.stream_names = stream_names
+        self.critic = ValueNetwork(
+            stream_names, observation_shapes, use_network_settings
+        )
+
+    def critic_pass(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        if memories is not None:
+            # Use only the back half of memories for critic
+            _, critic_mem = torch.split(memories, self.half_mem_size, -1)
+        else:
+            critic_mem = None
+        value_outputs, _memories = self.critic(
+            vec_inputs, vis_inputs, memories=critic_mem
+        )
+        return value_outputs
+
+    def get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
+        if memories is not None:
+            # Use the front half of memories for the actor, the back half for the critic
+            actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
+        else:
+            critic_mem = None
+            actor_mem = None
+        dists, actor_mem_outs = self.get_dists(
+            vec_inputs,
+            vis_inputs,
+            memories=actor_mem,
+            sequence_length=sequence_length,
+            masks=masks,
+        )
+        value_outputs, critic_mem_outs = self.critic(
+            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
+        )
+        if self.use_lstm:
+            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1)
+        else:
+            mem_out = None
+        return dists, value_outputs, mem_out
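The memory bookkeeping is the subtle part: one tensor arrives, each half drives one network, and the two output states are stitched back together. A toy demonstration (sizes assumed):

import torch

memory_size = 8  # validated elsewhere to be a multiple of 4
half_mem_size = memory_size // 2
memories = torch.arange(float(memory_size)).reshape(1, memory_size)

actor_mem, critic_mem = torch.split(memories, half_mem_size, dim=-1)
print(actor_mem.shape, critic_mem.shape)  # torch.Size([1, 4]) twice

mem_out = torch.cat([actor_mem, critic_mem], dim=1)
print(torch.equal(mem_out, memories))  # True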
class GlobalSteps(nn.Module):

ml-agents/mlagents/trainers/torch/utils.py (43 changes)


 )
 from mlagents.trainers.settings import EncoderType
 from mlagents.trainers.exception import UnityTrainerException
+from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

 class ModelUtils:

             raise UnityTrainerException(
                 f"Unsupported shape of {dimension} for observation {i}"
             )
-        if unnormalized_inputs > 0:
-            vector_encoders.append(
-                VectorAndUnnormalizedInputEncoder(
-                    vector_size, h_size, unnormalized_inputs, num_layers, normalize
-                )
-            )
-        else:
-            vector_encoders.append(
-                VectorEncoder(vector_size, h_size, num_layers, normalize)
-            )
+        if vector_size + unnormalized_inputs > 0:
+            if unnormalized_inputs > 0:
+                vector_encoders.append(
+                    VectorAndUnnormalizedInputEncoder(
+                        vector_size, h_size, unnormalized_inputs, num_layers, normalize
+                    )
+                )
+            else:
+                vector_encoders.append(
+                    VectorEncoder(vector_size, h_size, num_layers, normalize)
+                )
     return nn.ModuleList(visual_encoders), nn.ModuleList(vector_encoders)
 @staticmethod

             for i, _act in enumerate(discrete_actions.T)
         ]
         return onehot_branches

+    @staticmethod
+    def get_probs_and_entropy(
+        action_list: List[torch.Tensor], dists: List[DistInstance]
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        log_probs_list = []
+        all_probs_list = []
+        entropies_list = []
+        for action, action_dist in zip(action_list, dists):
+            log_prob = action_dist.log_prob(action)
+            log_probs_list.append(log_prob)
+            entropies_list.append(action_dist.entropy())
+            if isinstance(action_dist, DiscreteDistInstance):
+                all_probs_list.append(action_dist.all_log_prob())
+        log_probs = torch.stack(log_probs_list, dim=-1)
+        entropies = torch.stack(entropies_list, dim=-1)
+        if not all_probs_list:
+            log_probs = log_probs.squeeze(-1)
+            entropies = entropies.squeeze(-1)
+            all_probs = None
+        else:
+            all_probs = torch.cat(all_probs_list, dim=-1)
+        return log_probs, entropies, all_probs
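A hedged usage sketch of the relocated helper, built from classes this diff defines (shapes illustrative): for continuous distributions the third return value is None.

import torch
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.distributions import GaussianDistInstance

dists = [GaussianDistInstance(torch.zeros(1, 2), torch.ones(1, 2))]
actions = [d.sample() for d in dists]
log_probs, entropies, all_log_probs = ModelUtils.get_probs_and_entropy(actions, dists)
print(log_probs.shape, entropies.shape, all_log_probs)  # (1, 2), (1, 2), None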

ml-agents/mlagents/trainers/tests/torch/test_networks.py (208 changes)


import pytest
import torch

from mlagents.trainers.torch.networks import (
    NetworkBody,
    ValueNetwork,
    SimpleActor,
    SharedActorCritic,
    SeparateActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
    GaussianDistInstance,
    CategoricalDistInstance,
)


def test_networkbody_vector():
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, obs_size))
    sample_act = torch.ones((1, 2))

    for _ in range(100):
        encoded, _ = networkbody([sample_obs], [], sample_act)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_lstm():
    obs_size = 4
    seq_len = 16
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
    )
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, seq_len, obs_size))

    for _ in range(100):
        encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4))
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_visual():
    vec_obs_size = 4
    obs_size = (84, 84, 3)
    network_settings = NetworkSettings()
    obs_shapes = [(vec_obs_size,), obs_size]
    torch.random.manual_seed(0)

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, 84, 84, 3))
    sample_vec_obs = torch.ones((1, vec_obs_size))

    for _ in range(100):
        encoded, _ = networkbody([sample_vec_obs], [sample_obs])
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_valuenetwork():
    obs_size = 4
    num_outputs = 2
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    stream_names = [f"stream_name{n}" for n in range(4)]
    value_net = ValueNetwork(
        stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
    )
    optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)

    for _ in range(50):
        sample_obs = torch.ones((1, obs_size))
        values, _ = value_net([sample_obs], [])
        loss = 0
        for s_name in stream_names:
            assert values[s_name].shape == (1, num_outputs)
            # Try to force output to 1
            loss += torch.nn.functional.mse_loss(
                values[s_name], torch.ones((1, num_outputs))
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for value in values.values():
        for _out in value:
            assert _out[0] == pytest.approx(1.0, abs=0.1)


@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]
    act_size = [2]
    masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
    actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)

    # Test get_dists
    sample_obs = torch.ones((1, obs_size))
    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
    for dist in dists:
        if action_type == ActionType.CONTINUOUS:
            assert isinstance(dist, GaussianDistInstance)
        else:
            assert isinstance(dist, CategoricalDistInstance)

    # Test sample_actions
    actions = actor.sample_action(dists)
    for act in actions:
        if action_type == ActionType.CONTINUOUS:
            assert act.shape == (1, act_size[0])
        else:
            assert act.shape == (1, 1)

    # Test forward
    actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
        [sample_obs], [], masks=masks
    )
    for act in actions:
        if action_type == ActionType.CONTINUOUS:
            assert act.shape == (
                act_size[0],
                1,
            )  # This is different from above for ONNX export
        else:
            assert act.shape == (1, 1)

    # TODO: Once export works properly, fix the shapes here.
    assert mem_size == 0
    assert is_cont == int(action_type == ActionType.CONTINUOUS)
    assert act_size_vec == torch.tensor(act_size)


@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
    obs_size = 4
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings() if lstm else None
    )
    obs_shapes = [(obs_size,)]
    act_size = [2]
    stream_names = [f"stream_name{n}" for n in range(4)]
    actor = ac_type(
        obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
    )
    if lstm:
        sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
        memories = torch.ones(
            (
                1,
                network_settings.memory.sequence_length,
                network_settings.memory.memory_size,
            )
        )
    else:
        sample_obs = torch.ones((1, obs_size))
        memories = None

    # Test critic pass
    value_out = actor.critic_pass([sample_obs], [], memories=memories)
    for stream in stream_names:
        if lstm:
            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
        else:
            assert value_out[stream].shape == (1,)

    # Test get_dist_and_value
    dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
    for dist in dists:
        assert isinstance(dist, GaussianDistInstance)
    for stream in stream_names:
        if lstm:
            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
        else:
            assert value_out[stream].shape == (1,)
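These tests live under a new tests/torch/ package; one way to run just this module from the repo root (assuming pytest is installed):

import pytest

pytest.main(["-q", "ml-agents/mlagents/trainers/tests/torch/test_networks.py"])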