
Merge branch 'develop-add-fire' into develop-add-fire-checkpoint

/develop/add-fire/ckpt-2
Ruo-Ping Dong, 4 years ago
Current commit: 79d89158
15 files changed, 1100 insertions(+), 168 deletions(-)
1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6)
2. ml-agents/mlagents/trainers/policy/torch_policy.py (31)
3. ml-agents/mlagents/trainers/ppo/trainer.py (1)
4. ml-agents/mlagents/trainers/tests/test_reward_signals.py (4)
5. ml-agents/mlagents/trainers/torch/decoders.py (11)
6. ml-agents/mlagents/trainers/torch/distributions.py (69)
7. ml-agents/mlagents/trainers/torch/encoders.py (41)
8. ml-agents/mlagents/trainers/torch/networks.py (389)
9. ml-agents/mlagents/trainers/torch/utils.py (58)
10. ml-agents/mlagents/trainers/tests/torch/test_networks.py (210)
11. ml-agents/mlagents/trainers/tests/torch/test_decoders.py (31)
12. ml-agents/mlagents/trainers/tests/torch/test_distributions.py (141)
13. ml-agents/mlagents/trainers/tests/torch/test_encoders.py (110)
14. ml-agents/mlagents/trainers/tests/torch/test_utils.py (166)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changed lines)

     """
     vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
-    value_estimates, mean_value = self.policy.actor_critic.critic_pass(
+    value_estimates = self.policy.actor_critic.critic_pass(
         np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
         np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
     )

     next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
     next_memory = torch.zeros([1, 1, self.policy.m_size])
-    value_estimates, mean_value = self.policy.actor_critic.critic_pass(
+    value_estimates = self.policy.actor_critic.critic_pass(

-    next_value_estimate, next_value = self.policy.actor_critic.critic_pass(
+    next_value_estimate = self.policy.actor_critic.critic_pass(
         next_obs, next_obs, next_memory
     )
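Net effect for callers on the optimizer path: critic_pass now returns only the dict of per-reward-stream value estimates, with no auxiliary mean value. A minimal sketch of the new convention, constructed the same way as the tests added later in this diff (the observation shape and stream name are illustrative):

import torch
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.networks import SeparateActorCritic

# critic_pass now returns a single dict mapping reward-stream names to
# value tensors, rather than a (dict, mean value) tuple.
actor_critic = SeparateActorCritic(
    observation_shapes=[(4,)],       # illustrative shape
    network_settings=NetworkSettings(),
    act_type=ActionType.CONTINUOUS,
    act_size=[2],
    stream_names=["extrinsic"],      # illustrative stream name
)
value_estimates = actor_critic.critic_pass([torch.ones((1, 4))], [])
assert set(value_estimates.keys()) == {"extrinsic"}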

ml-agents/mlagents/trainers/policy/torch_policy.py (31 changed lines)

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 import numpy as np
 import torch

 from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
 from mlagents.trainers.trajectory import SplitObservations
-from mlagents.trainers.torch.networks import ActorCritic, GlobalSteps
+from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic
 from mlagents.trainers.torch.utils import ModelUtils

 EPSILON = 1e-7  # Small value to avoid divide by zero

         trainer_settings: TrainerSettings,
         tanh_squash: bool = False,
         reparameterize: bool = False,
-        separate_critic: Optional[bool] = None,
+        separate_critic: bool = True,
     ):
         """
         Policy that uses a multilayer perceptron to map the observations to actions. Could

             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",
         }
-        self.actor_critic = ActorCritic(
-            separate_critic=separate_critic
-            if separate_critic is not None
-            else self.use_continuous_act,
+        if separate_critic:
+            ac_class = SeparateActorCritic
+        else:
+            ac_class = SharedActorCritic
+        self.actor_critic = ac_class(
             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )

         """
         :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
         """
-        (
-            dists,
-            (value_heads, mean_value),
-            memories,
-        ) = self.actor_critic.get_dist_and_value(
+        dists, value_heads, memories = self.actor_critic.get_dist_and_value(
-        log_probs, entropies, all_logs = self.actor_critic.get_probs_and_entropy(
+        log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
             action_list, dists
         )
         actions = torch.stack(action_list, dim=-1)

     def evaluate_actions(
         self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
     ):
-        dists, (value_heads, mean_value), _ = self.actor_critic.get_dist_and_value(
+        dists, value_heads, _ = self.actor_critic.get_dist_and_value(
-        log_probs, entropies, _ = self.actor_critic.get_probs_and_entropy(
-            action_list, dists
-        )
+        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
         return log_probs, entropies, value_heads
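TorchPolicy now picks the concrete actor-critic class itself instead of passing a separate_critic flag down into one monolithic ActorCritic. A sketch of the selection logic, with constructor arguments mirrored from the tests in this diff (shapes and stream name are illustrative):

import torch
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic

# PPO now passes behavior_spec.is_action_continuous() for separate_critic;
# SAC keeps the default of True.
separate_critic = True
ac_class = SeparateActorCritic if separate_critic else SharedActorCritic
actor_critic = ac_class(
    observation_shapes=[(4,)],       # illustrative
    network_settings=NetworkSettings(),
    act_type=ActionType.CONTINUOUS,
    act_size=[2],
    stream_names=["extrinsic"],      # illustrative
)
# get_dist_and_value now returns three values; the mean value is gone.
dists, value_heads, memories = actor_critic.get_dist_and_value(
    [torch.ones((1, 4))], []
)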

ml-agents/mlagents/trainers/ppo/trainer.py (1 changed line)

         behavior_spec,
         self.trainer_settings,
         condition_sigma_on_obs=False,  # Faster training for PPO
+        separate_critic=behavior_spec.is_action_continuous(),
     )
     return policy

ml-agents/mlagents/trainers/tests/test_reward_signals.py (4 changed lines)

 import mlagents.trainers.tests.mock_brain as mb
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.sac.optimizer import SACOptimizer
-from mlagents.trainers.ppo.optimizer import PPOOptimizer
+from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
 from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
 from mlagents.trainers.settings import (
     GAILSettings,

     if trainer_settings.trainer_type == TrainerType.SAC:
         optimizer = SACOptimizer(policy, trainer_settings)
     else:
-        optimizer = PPOOptimizer(policy, trainer_settings)
+        optimizer = TFPPOOptimizer(policy, trainer_settings)
     return optimizer

ml-agents/mlagents/trainers/torch/decoders.py (11 changed lines)

+from typing import List, Dict

-    def __init__(self, stream_names, input_size, output_size=1):
+    def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1):
         super().__init__()
         self.stream_names = stream_names
         _value_heads = {}

             _value_heads[name] = value
         self.value_heads = nn.ModuleDict(_value_heads)

-    def forward(self, hidden):
+    def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
-        return (
-            value_outputs,
-            torch.mean(torch.stack(list(value_outputs.values())), dim=0),
-        )
+        return value_outputs
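A short sketch of the new ValueHeads contract, mirroring test_decoders.py below (sizes are illustrative):

import torch
from mlagents.trainers.torch.decoders import ValueHeads

# ValueHeads.forward now returns only the dict of per-stream values;
# with the default output_size of 1 the trailing dimension is squeezed.
stream_names = [f"reward_signal_{num}" for num in range(2)]
value_heads = ValueHeads(stream_names, input_size=5)
value_out = value_heads(torch.ones((4, 5)))
for stream_name in stream_names:
    assert value_out[stream_name].shape == (4,)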

ml-agents/mlagents/trainers/torch/distributions.py (69 changed lines)

+import abc
+from typing import List
 import torch
 from torch import nn
 import numpy as np

-class GaussianDistInstance(nn.Module):
+class DistInstance(nn.Module, abc.ABC):
+    @abc.abstractmethod
+    def sample(self) -> torch.Tensor:
+        """
+        Return a sample from this distribution.
+        """
+        pass
+
+    @abc.abstractmethod
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        """
+        Returns the log probabilities of a particular value.
+        :param value: A value sampled from the distribution.
+        :returns: Log probabilities of the given value.
+        """
+        pass
+
+    @abc.abstractmethod
+    def entropy(self) -> torch.Tensor:
+        """
+        Returns the entropy of this distribution.
+        """
+        pass
+
+
+class DiscreteDistInstance(DistInstance):
+    @abc.abstractmethod
+    def all_log_prob(self) -> torch.Tensor:
+        """
+        Returns the log probabilities of all actions represented by this distribution.
+        """
+        pass
+
+
+class GaussianDistInstance(DistInstance):
     def __init__(self, mean, std):
         super().__init__()
         self.mean = mean

     )

-class CategoricalDistInstance(nn.Module):
+class CategoricalDistInstance(DiscreteDistInstance):
     def __init__(self, logits):
         super().__init__()
         self.logits = logits

 class GaussianDistribution(nn.Module):
     def __init__(
         self,
-        hidden_size,
-        num_outputs,
-        conditional_sigma=False,
-        tanh_squash=False,
-        **kwargs
+        hidden_size: int,
+        num_outputs: int,
+        conditional_sigma: bool = False,
+        tanh_squash: bool = False,
     ):
-        super().__init__(**kwargs)
+        super().__init__()
         self.conditional_sigma = conditional_sigma
         self.mu = nn.Linear(hidden_size, num_outputs)
         self.tanh_squash = tanh_squash

             torch.zeros(1, num_outputs, requires_grad=True)
         )

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
         mu = self.mu(inputs)
         if self.conditional_sigma:
             log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)

 class MultiCategoricalDistribution(nn.Module):
-    def __init__(self, hidden_size, act_sizes):
+    def __init__(self, hidden_size: int, act_sizes: List[int]):
-        self.branches = self.create_policy_branches(hidden_size)
+        self.branches = self._create_policy_branches(hidden_size)

-    def create_policy_branches(self, hidden_size):
+    def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
         branches = []
         for size in self.act_sizes:
             branch_output_layer = nn.Linear(hidden_size, size)

-    def mask_branch(self, logits, mask):
+    def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

-    def split_masks(self, masks):
+    def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
         split_masks = []
         for idx, _ in enumerate(self.act_sizes):
             start = int(np.sum(self.act_sizes[:idx]))

-    def forward(self, inputs, masks):
+    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
-        masks = self.split_masks(masks)
+        masks = self._split_masks(masks)
-            norm_logits = self.mask_branch(logits, masks[idx])
+            norm_logits = self._mask_branch(logits, masks[idx])
             distribution = CategoricalDistInstance(norm_logits)
             branch_distributions.append(distribution)
         return branch_distributions
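The point of the new abstract bases is that callers can treat Gaussian and categorical outputs uniformly through the sample/log_prob/entropy interface. A sketch, assuming only the interfaces shown above (sizes are illustrative):

import torch
from mlagents.trainers.torch.distributions import GaussianDistribution, DistInstance

# Every concrete distribution module now yields DistInstance objects.
dist_module = GaussianDistribution(hidden_size=16, num_outputs=4)
dists = dist_module(torch.ones((1, 16)))  # -> List[DistInstance]
for dist in dists:
    assert isinstance(dist, DistInstance)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    entropy = dist.entropy()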

ml-agents/mlagents/trainers/torch/encoders.py (41 changed lines)

 class Normalizer(nn.Module):
     def __init__(self, vec_obs_size: int):
         super().__init__()
-        self.normalization_steps = nn.Parameter(torch.tensor(1), requires_grad=False)
-        self.running_mean = nn.Parameter(torch.zeros(vec_obs_size), requires_grad=False)
-        self.running_variance = nn.Parameter(
-            torch.ones(vec_obs_size), requires_grad=False
-        )
+        self.register_buffer("normalization_steps", torch.tensor(1))
+        self.register_buffer("running_mean", torch.zeros(vec_obs_size))
+        self.register_buffer("running_variance", torch.ones(vec_obs_size))

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         normalized_state = torch.clamp(

         new_variance = self.running_variance + (
             input_to_new_mean * input_to_old_mean
         ).sum(0)
-        self.running_mean.data = new_mean.data
-        self.running_variance.data = new_variance.data
-        self.normalization_steps.data = total_new_steps.data
+        # Update in-place
+        self.running_mean.data.copy_(new_mean.data)
+        self.running_variance.data.copy_(new_variance.data)
+        self.normalization_steps.data.copy_(total_new_steps.data)

     def copy_from(self, other_normalizer: "Normalizer") -> None:
         self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)

         for _ in range(num_layers - 1):
             self.layers.append(nn.Linear(hidden_size, hidden_size))
-            self.layers.append(nn.ReLU())
+            self.layers.append(nn.LeakyReLU())
         self.seq_layers = nn.Sequential(*self.layers)

     def forward(self, inputs: torch.Tensor) -> None:

         self.dense = nn.Linear(self.final_flat, self.h_size)

     def forward(self, visual_obs: torch.Tensor) -> None:
-        conv_1 = torch.relu(self.conv1(visual_obs))
-        conv_2 = torch.relu(self.conv2(conv_1))
-        hidden = torch.relu(self.dense(torch.reshape(conv_2, (-1, self.final_flat))))
+        conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs))
+        conv_2 = nn.functional.leaky_relu(self.conv2(conv_1))
+        hidden = nn.functional.leaky_relu(
+            self.dense(torch.reshape(conv_2, (-1, self.final_flat)))
+        )
         return hidden

         self.dense = nn.Linear(self.final_flat, self.h_size)

     def forward(self, visual_obs):
-        conv_1 = torch.relu(self.conv1(visual_obs))
-        conv_2 = torch.relu(self.conv2(conv_1))
-        conv_3 = torch.relu(self.conv3(conv_2))
-        hidden = torch.relu(self.dense(conv_3.view([-1, self.final_flat])))
+        conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs))
+        conv_2 = nn.functional.leaky_relu(self.conv2(conv_1))
+        conv_3 = nn.functional.leaky_relu(self.conv3(conv_2))
+        hidden = nn.functional.leaky_relu(
+            self.dense(conv_3.view([-1, self.final_flat]))
+        )
         return hidden

         for _ in range(n_blocks):
             self.layers.append(self.make_block(channel))
             last_channel = channel
-        self.layers.append(nn.ReLU())
+        self.layers.append(nn.LeakyReLU())
-            nn.ReLU(),
+            nn.LeakyReLU(),
-            nn.ReLU(),
+            nn.LeakyReLU(),
             nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
         ]
         return block_layers
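The switch from frozen nn.Parameter fields to register_buffer keeps the running statistics in state_dict() and moves them with .to(device), while hiding them from parameters() so an optimizer can never update them. A self-contained sketch of the pattern (the class and field names here are illustrative, not from the source):

import torch
from torch import nn

class RunningStats(nn.Module):
    """Illustrative sketch of the register_buffer pattern used by Normalizer."""
    def __init__(self, size: int):
        super().__init__()
        # Buffers are saved in state_dict() and moved by .to()/.cuda(),
        # but excluded from parameters(), so optimizers ignore them.
        self.register_buffer("steps", torch.tensor(1))
        self.register_buffer("mean", torch.zeros(size))

stats = RunningStats(4)
assert "mean" in stats.state_dict()
assert len(list(stats.parameters())) == 0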

ml-agents/mlagents/trainers/torch/networks.py (389 changed lines)

 from typing import Callable, List, Dict, Tuple, Optional
+import attr
+import abc
 import torch
 from torch import nn

     GaussianDistribution,
     MultiCategoricalDistribution,
+    DistInstance,
 )
 from mlagents.trainers.settings import NetworkSettings
 from mlagents.trainers.torch.utils import ModelUtils

         else:
             self.lstm = None

-    def update_normalization(self, vec_inputs):
+    def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
         for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
             vec_enc.update_normalization(vec_input)

     def forward(
         self,
-        vec_inputs: torch.Tensor,
-        vis_inputs: torch.Tensor,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],

-        vec_embeds = []
+        vec_encodes = []
         for idx, encoder in enumerate(self.vector_encoders):
             vec_input = vec_inputs[idx]
             if actions is not None:

-                vec_embeds.append(hidden)
+                vec_encodes.append(hidden)

-        vis_embeds = []
+        vis_encodes = []

-            vis_embeds.append(hidden)
+            vis_encodes.append(hidden)

-        # embedding = vec_embeds[0]
-        if len(vec_embeds) > 0 and len(vis_embeds) > 0:
-            vec_embeds_tensor = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
-            vis_embeds_tensor = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
-            embedding = torch.stack([vec_embeds_tensor, vis_embeds_tensor], dim=-1).sum(
-                dim=-1
-            )
-        elif len(vec_embeds) > 0:
-            embedding = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
-        elif len(vis_embeds) > 0:
-            embedding = torch.stack(vis_embeds, dim=-1).sum(dim=-1)
+        if len(vec_encodes) > 0 and len(vis_encodes) > 0:
+            vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
+            vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
+            encoding = torch.stack(
+                [vec_encodes_tensor, vis_encodes_tensor], dim=-1
+            ).sum(dim=-1)
+        elif len(vec_encodes) > 0:
+            encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
+        elif len(vis_encodes) > 0:
+            encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1)

-            embedding = embedding.view([sequence_length, -1, self.h_size])
+            encoding = encoding.view([sequence_length, -1, self.h_size])
-            embedding, memories = self.lstm(
-                embedding.contiguous(),
+            encoding, memories = self.lstm(
+                encoding.contiguous(),
-            embedding = embedding.view([-1, self.m_size // 2])
+            encoding = encoding.view([-1, self.m_size // 2])
-        return embedding, memories
+        return encoding, memories


 class ValueNetwork(nn.Module):

         self.network_body = NetworkBody(
             observation_shapes, network_settings, encoded_act_size=encoded_act_size
         )
-        self.value_heads = ValueHeads(
-            stream_names, network_settings.hidden_units, outputs_per_stream
-        )
+        if network_settings.memory is not None:
+            encoding_size = network_settings.memory.memory_size // 2
+        else:
+            encoding_size = network_settings.hidden_units
+        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

     def forward(
         self,

         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-        embedding, memories = self.network_body(
+        encoding, memories = self.network_body(
-        output, _ = self.value_heads(embedding)
+        output = self.value_heads(encoding)


-class ActorCritic(nn.Module):
+class Actor(abc.ABC):
+    @abc.abstractmethod
+    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
+        """
+        Updates normalization of Actor based on the provided List of vector obs.
+        :param vector_obs: A List of vector obs as tensors.
+        """
+        pass
+
+    @abc.abstractmethod
+    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
+        """
+        Takes a List of Distribution instances and samples an action from each.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_dists(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
+        """
+        Returns distributions from this Actor, from which actions can be sampled.
+        If memory is enabled, return the memories as well.
+        :param vec_inputs: A List of vector inputs as tensors.
+        :param vis_inputs: A List of visual inputs as tensors.
+        :param masks: If using discrete actions, a Tensor of action masks.
+        :param memories: If using memory, a Tensor of initial memories.
+        :param sequence_length: If using memory, the sequence length.
+        :return: A Tuple of a List of action distribution instances, and memories.
+            Memories will be None if not using memory.
+        """
+        pass
+
+    @abc.abstractmethod
+    def forward(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+        """
+        Forward pass of the Actor for inference. This is required for export to ONNX, and
+        the inputs and outputs of this method should not be changed without a respective change
+        in the ONNX export code.
+        """
+        pass
+
+
+class ActorCritic(Actor):
+    @abc.abstractmethod
+    def critic_pass(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Get value outputs for the given obs.
+        :param vec_inputs: List of vector inputs as tensors.
+        :param vis_inputs: List of visual inputs as tensors.
+        :param memories: Tensor of memories, if using memory. Otherwise, None.
+        :returns: Dict of reward stream to output tensor for values.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
+        """
+        Returns distributions, from which actions can be sampled, and value estimates.
+        If memory is enabled, return the memories as well.
+        :param vec_inputs: A List of vector inputs as tensors.
+        :param vis_inputs: A List of visual inputs as tensors.
+        :param masks: If using discrete actions, a Tensor of action masks.
+        :param memories: If using memory, a Tensor of initial memories.
+        :param sequence_length: If using memory, the sequence length.
+        :return: A Tuple of a List of action distribution instances, a Dict of reward signal
+            name to value estimate, and memories. Memories will be None if not using memory.
+        """
+        pass
+
+
+class SimpleActor(nn.Module, Actor):
     def __init__(
         self,
         observation_shapes: List[Tuple[int, ...]],

-        stream_names: List[str],
-        separate_critic: bool,
         conditional_sigma: bool = False,
         tanh_squash: bool = False,
     ):

         self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
         self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
-        self.is_continuous_int = torch.nn.Parameter(torch.Tensor([1]))
-        self.separate_critic = separate_critic
+        self.is_continuous_int = torch.nn.Parameter(
+            torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
+        )

-            embedding_size = network_settings.memory.memory_size // 2
+            self.encoding_size = network_settings.memory.memory_size // 2
-            embedding_size = network_settings.hidden_units
+            self.encoding_size = network_settings.hidden_units

-                embedding_size,
+                self.encoding_size,

-            self.distribution = MultiCategoricalDistribution(embedding_size, act_size)
-        if separate_critic:
-            self.critic = ValueNetwork(
-                stream_names, observation_shapes, network_settings
-            )
-        else:
-            self.stream_names = stream_names
-            self.value_heads = ValueHeads(stream_names, embedding_size)
+            self.distribution = MultiCategoricalDistribution(
+                self.encoding_size, act_size
+            )

-    def update_normalization(self, vector_obs):
+    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
-        if self.separate_critic:
-            self.critic.network_body.update_normalization(vector_obs)
-
-    def critic_pass(self, vec_inputs, vis_inputs, memories=None):
-        if self.separate_critic:
-            return self.critic(vec_inputs, vis_inputs)
-        else:
-            embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
-            return self.value_heads(embedding)

-    def sample_action(self, dists):
+    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
         actions = []
         for action_dist in dists:
             action = action_dist.sample()

-    def get_probs_and_entropy(self, action_list, dists):
-        log_probs = []
-        all_probs = []
-        entropies = []
-        for action, action_dist in zip(action_list, dists):
-            log_prob = action_dist.log_prob(action)
-            log_probs.append(log_prob)
-            entropies.append(action_dist.entropy())
-            if self.act_type == ActionType.DISCRETE:
-                all_probs.append(action_dist.all_log_prob())
-        log_probs = torch.stack(log_probs, dim=-1)
-        entropies = torch.stack(entropies, dim=-1)
-        if self.act_type == ActionType.CONTINUOUS:
-            log_probs = log_probs.squeeze(-1)
-            entropies = entropies.squeeze(-1)
-            all_probs = None
-        else:
-            all_probs = torch.cat(all_probs, dim=-1)
-        return log_probs, entropies, all_probs
-
-    def get_dist_and_value(
-        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
-    ):
-        embedding, memories = self.network_body(
+    def get_dists(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
+        encoding, memories = self.network_body(

-            dists = self.distribution(embedding)
-        else:
-            dists = self.distribution(embedding, masks=masks)
-        if self.separate_critic:
-            value_outputs = self.critic(vec_inputs, vis_inputs)
-        else:
-            value_outputs = self.value_heads(embedding)
-        return dists, value_outputs, memories
+            dists = self.distribution(encoding)
+        else:
+            dists = self.distribution(encoding, masks)
+        return dists, memories

-        self, vec_inputs, vis_inputs=None, masks=None, memories=None, sequence_length=1
-    ):
-        embedding, memories = self.network_body(
-            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
-        )
-        dists, value_outputs, memories = self.get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+        """
+        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
+        """
+        dists, _ = self.get_dists(
+            vec_inputs, vis_inputs, masks, memories, sequence_length
+        )
+        action_list = self.sample_action(dists)

             self.is_continuous_int,
             self.act_size_vector,
         )

+class SharedActorCritic(SimpleActor, ActorCritic):
+    def __init__(
+        self,
+        observation_shapes: List[Tuple[int, ...]],
+        network_settings: NetworkSettings,
+        act_type: ActionType,
+        act_size: List[int],
+        stream_names: List[str],
+        conditional_sigma: bool = False,
+        tanh_squash: bool = False,
+    ):
+        super().__init__(
+            observation_shapes,
+            network_settings,
+            act_type,
+            act_size,
+            conditional_sigma,
+            tanh_squash,
+        )
+        self.stream_names = stream_names
+        self.value_heads = ValueHeads(stream_names, self.encoding_size)
+
+    def critic_pass(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        encoding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
+        return self.value_heads(encoding)
+
+    def get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
+        encoding, memories = self.network_body(
+            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
+        )
+        if self.act_type == ActionType.CONTINUOUS:
+            dists = self.distribution(encoding)
+        else:
+            dists = self.distribution(encoding, masks=masks)
+        value_outputs = self.value_heads(encoding)
+        return dists, value_outputs, memories
+
+
+class SeparateActorCritic(SimpleActor, ActorCritic):
+    def __init__(
+        self,
+        observation_shapes: List[Tuple[int, ...]],
+        network_settings: NetworkSettings,
+        act_type: ActionType,
+        act_size: List[int],
+        stream_names: List[str],
+        conditional_sigma: bool = False,
+        tanh_squash: bool = False,
+    ):
+        # Give the Actor only half the memories. Note that we previously
+        # validated that memory_size is a multiple of 4.
+        self.use_lstm = network_settings.memory is not None
+        if network_settings.memory is not None:
+            self.half_mem_size = network_settings.memory.memory_size // 2
+            new_memory_settings = attr.evolve(
+                network_settings.memory, memory_size=self.half_mem_size
+            )
+            use_network_settings = attr.evolve(
+                network_settings, memory=new_memory_settings
+            )
+        else:
+            use_network_settings = network_settings
+            self.half_mem_size = 0
+        super().__init__(
+            observation_shapes,
+            use_network_settings,
+            act_type,
+            act_size,
+            conditional_sigma,
+            tanh_squash,
+        )
+        self.stream_names = stream_names
+        self.critic = ValueNetwork(
+            stream_names, observation_shapes, use_network_settings
+        )
+
+    def critic_pass(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        if self.use_lstm:
+            # Use only the back half of memories for critic
+            _, critic_mem = torch.split(memories, self.half_mem_size, -1)
+        else:
+            critic_mem = None
+        value_outputs, _memories = self.critic(
+            vec_inputs, vis_inputs, memories=critic_mem
+        )
+        return value_outputs
+
+    def get_dist_and_value(
+        self,
+        vec_inputs: List[torch.Tensor],
+        vis_inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
+        if self.use_lstm:
+            # Front half of memories goes to the actor, back half to the critic
+            actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
+        else:
+            critic_mem = None
+            actor_mem = None
+        dists, actor_mem_outs = self.get_dists(
+            vec_inputs,
+            vis_inputs,
+            memories=actor_mem,
+            sequence_length=sequence_length,
+            masks=masks,
+        )
+        value_outputs, critic_mem_outs = self.critic(
+            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
+        )
+        if self.use_lstm:
+            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1)
+        else:
+            mem_out = None
+        return dists, value_outputs, mem_out


 class GlobalSteps(nn.Module):
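SeparateActorCritic threads a single memory tensor through two recurrent networks by splitting it in half. A sketch of just the split/re-join arithmetic (sizes illustrative; the source validates elsewhere that memory_size is a multiple of 4):

import torch

memory_size = 16
half_mem_size = memory_size // 2
memories = torch.zeros((1, 1, memory_size))
# Front half feeds the actor's LSTM, back half feeds the critic's.
actor_mem, critic_mem = torch.split(memories, half_mem_size, dim=-1)
assert actor_mem.shape[-1] == critic_mem.shape[-1] == half_mem_size
# After each pass the two halves are re-joined for the caller.
mem_out = torch.cat([actor_mem, critic_mem], dim=-1)
assert mem_out.shape[-1] == memory_size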

ml-agents/mlagents/trainers/torch/utils.py (58 changed lines)

 )
 from mlagents.trainers.settings import EncoderType
 from mlagents.trainers.exception import UnityTrainerException
+from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance


 class ModelUtils:

     @staticmethod
     def _check_resolution_for_encoder(
-        vis_in: torch.Tensor, vis_encoder_type: EncoderType
+        height: int, width: int, vis_encoder_type: EncoderType
-        height = vis_in.shape[1]
-        width = vis_in.shape[2]
         if height < min_res or width < min_res:
             raise UnityTrainerException(
                 f"Visual observation resolution ({width}x{height}) is too small for"

         vector_size = 0
         for i, dimension in enumerate(observation_shapes):
             if len(dimension) == 3:
+                ModelUtils._check_resolution_for_encoder(
+                    dimension[0], dimension[1], vis_encode_type
+                )
                 visual_encoders.append(
                     visual_encoder_class(
                         dimension[0], dimension[1], dimension[2], h_size

                 raise UnityTrainerException(
                     f"Unsupported shape of {dimension} for observation {i}"
                 )
-        if unnormalized_inputs > 0:
-            vector_encoders.append(
-                VectorAndUnnormalizedInputEncoder(
-                    vector_size, h_size, unnormalized_inputs, num_layers, normalize
-        else:
-            vector_encoders.append(
-                VectorEncoder(vector_size, h_size, num_layers, normalize)
-            )
+        if vector_size + unnormalized_inputs > 0:
+            if unnormalized_inputs > 0:
+                vector_encoders.append(
+                    VectorAndUnnormalizedInputEncoder(
+                        vector_size, h_size, unnormalized_inputs, num_layers, normalize
+                    )
+                )
+            else:
+                vector_encoders.append(
+                    VectorEncoder(vector_size, h_size, num_layers, normalize)
+                )
         return nn.ModuleList(visual_encoders), nn.ModuleList(vector_encoders)

     @staticmethod
     def actions_to_onehot(
         discrete_actions: torch.Tensor, action_size: List[int]
     ) -> List[torch.Tensor]:
         """
         Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
         action.
         :param discrete_actions: Actions in integer form.
         :param action_size: List of branch sizes. Should be of same size as discrete_actions'
             last dimension.
         :return: List of one-hot tensors, one representing each branch.
         """

+    @staticmethod
+    def get_probs_and_entropy(
+        action_list: List[torch.Tensor], dists: List[DistInstance]
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        log_probs_list = []
+        all_probs_list = []
+        entropies_list = []
+        for action, action_dist in zip(action_list, dists):
+            log_prob = action_dist.log_prob(action)
+            log_probs_list.append(log_prob)
+            entropies_list.append(action_dist.entropy())
+            if isinstance(action_dist, DiscreteDistInstance):
+                all_probs_list.append(action_dist.all_log_prob())
+        log_probs = torch.stack(log_probs_list, dim=-1)
+        entropies = torch.stack(entropies_list, dim=-1)
+        if not all_probs_list:
+            log_probs = log_probs.squeeze(-1)
+            entropies = entropies.squeeze(-1)
+            all_probs = None
+        else:
+            all_probs = torch.cat(all_probs_list, dim=-1)
+        return log_probs, entropies, all_probs
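get_probs_and_entropy now lives on ModelUtils and keys off isinstance(dist, DiscreteDistInstance) instead of the actor's act_type field. A usage sketch mirroring test_utils.py below (sizes are illustrative):

import torch
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.distributions import GaussianDistInstance

dists = [GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2)))]
actions = [torch.zeros((1, 2))]
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(actions, dists)
# Only discrete distributions report all_log_prob; for Gaussians it is None.
assert all_probs is None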

ml-agents/mlagents/trainers/tests/torch/test_networks.py (210 changed lines)

import pytest
import torch

from mlagents.trainers.torch.networks import (
    NetworkBody,
    ValueNetwork,
    SimpleActor,
    SharedActorCritic,
    SeparateActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
    GaussianDistInstance,
    CategoricalDistInstance,
)


def test_networkbody_vector():
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, obs_size))
    sample_act = torch.ones((1, 2))

    for _ in range(100):
        encoded, _ = networkbody([sample_obs], [], sample_act)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_lstm():
    obs_size = 4
    seq_len = 16
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
    )
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, seq_len, obs_size))

    for _ in range(100):
        encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4))
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_visual():
    vec_obs_size = 4
    obs_size = (84, 84, 3)
    network_settings = NetworkSettings()
    obs_shapes = [(vec_obs_size,), obs_size]
    torch.random.manual_seed(0)

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, 84, 84, 3))
    sample_vec_obs = torch.ones((1, vec_obs_size))

    for _ in range(100):
        encoded, _ = networkbody([sample_vec_obs], [sample_obs])
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)


def test_valuenetwork():
    obs_size = 4
    num_outputs = 2
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    stream_names = [f"stream_name{n}" for n in range(4)]
    value_net = ValueNetwork(
        stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
    )
    optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)

    for _ in range(50):
        sample_obs = torch.ones((1, obs_size))
        values, _ = value_net([sample_obs], [])
        loss = 0
        for s_name in stream_names:
            assert values[s_name].shape == (1, num_outputs)
            # Try to force output to 1
            loss += torch.nn.functional.mse_loss(
                values[s_name], torch.ones((1, num_outputs))
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # In the last step, values should be close to 1
    for value in values.values():
        for _out in value:
            assert _out[0] == pytest.approx(1.0, abs=0.1)


@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]
    act_size = [2]
    masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
    actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)

    # Test get_dist
    sample_obs = torch.ones((1, obs_size))
    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
    for dist in dists:
        if action_type == ActionType.CONTINUOUS:
            assert isinstance(dist, GaussianDistInstance)
        else:
            assert isinstance(dist, CategoricalDistInstance)

    # Test sample_actions
    actions = actor.sample_action(dists)
    for act in actions:
        if action_type == ActionType.CONTINUOUS:
            assert act.shape == (1, act_size[0])
        else:
            assert act.shape == (1, 1)

    # Test forward
    actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
        [sample_obs], [], masks=masks
    )
    for act in actions:
        if action_type == ActionType.CONTINUOUS:
            assert act.shape == (
                act_size[0],
                1,
            )  # This is different from above for ONNX export
        else:
            assert act.shape == (1, 1)

    # TODO: Once export works properly, fix the shapes here.
    assert mem_size == 0
    assert is_cont == int(action_type == ActionType.CONTINUOUS)
    assert act_size_vec == torch.tensor(act_size)


@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
    obs_size = 4
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings() if lstm else None
    )
    obs_shapes = [(obs_size,)]
    act_size = [2]
    stream_names = [f"stream_name{n}" for n in range(4)]
    actor = ac_type(
        obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
    )
    if lstm:
        sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
        memories = torch.ones(
            (
                1,
                network_settings.memory.sequence_length,
                network_settings.memory.memory_size,
            )
        )
    else:
        sample_obs = torch.ones((1, obs_size))
        memories = torch.tensor([])
        # memories isn't always set to None; the network should be able to
        # deal with that.

    # Test critic pass
    value_out = actor.critic_pass([sample_obs], [], memories=memories)
    for stream in stream_names:
        if lstm:
            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
        else:
            assert value_out[stream].shape == (1,)

    # Test get_dist_and_value
    dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
    for dist in dists:
        assert isinstance(dist, GaussianDistInstance)
    for stream in stream_names:
        if lstm:
            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
        else:
            assert value_out[stream].shape == (1,)

ml-agents/mlagents/trainers/tests/torch/test_decoders.py (31 changed lines)

import pytest
import torch

from mlagents.trainers.torch.decoders import ValueHeads


def test_valueheads():
    stream_names = [f"reward_signal_{num}" for num in range(5)]
    input_size = 5
    batch_size = 4

    # Test default 1 value per head
    value_heads = ValueHeads(stream_names, input_size)
    input_data = torch.ones((batch_size, input_size))
    value_out = value_heads(input_data)
    for stream_name in stream_names:
        assert value_out[stream_name].shape == (batch_size,)

    # Test that inputting the wrong size input will throw an error
    with pytest.raises(Exception):
        value_out = value_heads(torch.ones((batch_size, input_size + 2)))

    # Test multiple values per head (e.g. discrete Q function)
    output_size = 4
    value_heads = ValueHeads(stream_names, input_size, output_size)
    input_data = torch.ones((batch_size, input_size))
    value_out = value_heads(input_data)
    for stream_name in stream_names:
        assert value_out[stream_name].shape == (batch_size, output_size)

ml-agents/mlagents/trainers/tests/torch/test_distributions.py (141 changed lines)

import pytest
import torch

from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    GaussianDistInstance,
    TanhGaussianDistInstance,
    CategoricalDistInstance,
)


@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
    torch.manual_seed(0)
    hidden_size = 16
    act_size = 4
    sample_embedding = torch.ones((1, 16))
    gauss_dist = GaussianDistribution(
        hidden_size,
        act_size,
        conditional_sigma=conditional_sigma,
        tanh_squash=tanh_squash,
    )

    # Make sure backprop works
    force_action = torch.zeros((1, act_size))
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)

    for _ in range(50):
        dist_inst = gauss_dist(sample_embedding)[0]
        if tanh_squash:
            assert isinstance(dist_inst, TanhGaussianDistInstance)
        else:
            assert isinstance(dist_inst, GaussianDistInstance)
        log_prob = dist_inst.log_prob(force_action)
        loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for prob in log_prob.flatten():
        assert prob == pytest.approx(-2, abs=0.1)


def test_multi_categorical_distribution():
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    sample_embedding = torch.ones((1, 16))
    gauss_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    for _ in range(100):
        dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size))))
        loss = 0
        for i, dist_inst in enumerate(dist_insts):
            assert isinstance(dist_inst, CategoricalDistInstance)
            log_prob = dist_inst.all_log_prob()
            test_log_prob = create_test_prob(act_size[i])
            # Force log_probs to match the high probability for the first action
            # generated by create_test_prob
            loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        test_log_probs = create_test_prob(size)
        for _prob, _test_prob in zip(
            dist_inst.all_log_prob().flatten().tolist(),
            test_log_probs.flatten().tolist(),
        ):
            assert _prob == pytest.approx(_test_prob, abs=0.1)

    # Test masks
    masks = []
    for branch in act_size:
        masks += [0] * (branch - 1) + [1]
    masks = torch.tensor([masks])
    dist_insts = gauss_dist(sample_embedding, masks=masks)
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)


def test_gaussian_dist_instance():
    torch.manual_seed(0)
    act_size = 4
    dist_instance = GaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    action = dist_instance.sample()
    assert action.shape == (1, act_size)

    for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten():
        # Log prob of standard normal at 0
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in dist_instance.entropy().flatten():
        # Entropy of standard normal, as computed by this implementation
        assert ent == pytest.approx(2.83, abs=0.01)


def test_tanh_gaussian_dist_instance():
    torch.manual_seed(0)
    act_size = 4
    dist_instance = TanhGaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, act_size)
        assert torch.max(action) < 1.0 and torch.min(action) > -1.0


def test_categorical_dist_instance():
    torch.manual_seed(0)
    act_size = 4
    test_prob = torch.tensor(
        [1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
    )  # High prob for first action
    dist_instance = CategoricalDistInstance(test_prob)

    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1,)
        assert action < act_size

    # Make sure the first action has a higher probability than the others.
    prob_first_action = dist_instance.log_prob(torch.tensor([0]))
    for i in range(1, act_size):
        assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action

ml-agents/mlagents/trainers/tests/torch/test_encoders.py (110 changed lines)

import torch
from unittest import mock
import pytest

from mlagents.trainers.torch.encoders import (
    VectorEncoder,
    VectorAndUnnormalizedInputEncoder,
    Normalizer,
    SimpleVisualEncoder,
    ResNetVisualEncoder,
    NatureVisualEncoder,
)


# This test will also reveal issues with states not being saved in the state_dict.
def compare_models(module_1, module_2):
    is_same = True
    for key_item_1, key_item_2 in zip(
        module_1.state_dict().items(), module_2.state_dict().items()
    ):
        # Compare tensors in state_dict and not the keys.
        is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same
    return is_same


def test_normalizer():
    input_size = 2
    norm = Normalizer(input_size)

    # These three inputs should bring the mean to 0.5 and the variance to 2,
    # with the steps starting at 1
    vec_input1 = torch.tensor([[1, 1]])
    vec_input2 = torch.tensor([[1, 1]])
    vec_input3 = torch.tensor([[0, 0]])

    norm.update(vec_input1)
    norm.update(vec_input2)
    norm.update(vec_input3)

    # Test normalization
    for val in norm(vec_input1)[0]:
        assert val == pytest.approx(0.707, abs=0.001)

    # Test copy normalization
    norm2 = Normalizer(input_size)
    assert not compare_models(norm, norm2)
    norm2.copy_from(norm)
    assert compare_models(norm, norm2)
    for val in norm2(vec_input1)[0]:
        assert val == pytest.approx(0.707, abs=0.001)


@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_encoder(mock_normalizer):
    mock_normalizer_inst = mock.Mock()
    mock_normalizer.return_value = mock_normalizer_inst
    input_size = 64
    hidden_size = 128
    num_layers = 3
    normalize = False
    vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
    output = vector_encoder(torch.ones((1, input_size)))
    assert output.shape == (1, hidden_size)

    normalize = True
    vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
    new_vec = torch.ones((1, input_size))
    vector_encoder.update_normalization(new_vec)
    mock_normalizer.assert_called_with(input_size)
    mock_normalizer_inst.update.assert_called_with(new_vec)

    vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize)
    vector_encoder.copy_normalization(vector_encoder2)
    mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst)


@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
def test_vector_and_unnormalized_encoder(mock_normalizer):
    mock_normalizer_inst = mock.Mock()
    mock_normalizer.return_value = mock_normalizer_inst
    input_size = 64
    unnormalized_size = 32
    hidden_size = 128
    num_layers = 3
    normalize = True
    mock_normalizer_inst.return_value = torch.ones((1, input_size))
    vector_encoder = VectorAndUnnormalizedInputEncoder(
        input_size, hidden_size, unnormalized_size, num_layers, normalize
    )
    # Make sure normalizer is only called on input_size
    mock_normalizer.assert_called_with(input_size)
    normal_input = torch.ones((1, input_size))

    unnormalized_input = torch.ones((1, 32))
    output = vector_encoder(normal_input, unnormalized_input)
    mock_normalizer_inst.assert_called_with(normal_input)
    assert output.shape == (1, hidden_size)


@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
    "vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
)
def test_visual_encoder(vis_class, image_size):
    num_outputs = 128
    enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs)
    # Note: NCHW not NHWC
    sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1]))
    encoding = enc(sample_input)
    assert encoding.shape == (1, num_outputs)

ml-agents/mlagents/trainers/tests/torch/test_utils.py (166 changed lines)

import pytest
import torch
import numpy as np

from mlagents.trainers.settings import EncoderType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.encoders import (
    VectorEncoder,
    VectorAndUnnormalizedInputEncoder,
)
from mlagents.trainers.torch.distributions import (
    CategoricalDistInstance,
    GaussianDistInstance,
)


def test_min_visual_size():
    # Make sure each EncoderType has an entry in MIN_RESOLUTION_FOR_ENCODER
    assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType)

    for encoder_type in EncoderType:
        good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
        vis_input = torch.ones((1, 3, good_size, good_size))
        ModelUtils._check_resolution_for_encoder(good_size, good_size, encoder_type)
        enc_func = ModelUtils.get_encoder_for_type(encoder_type)
        enc = enc_func(good_size, good_size, 3, 1)
        enc.forward(vis_input)

        # Anything under the min size should raise an exception. If not, decrease the min size!
        with pytest.raises(Exception):
            bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1
            vis_input = torch.ones((1, 3, bad_size, bad_size))

            with pytest.raises(UnityTrainerException):
                # Make sure we'd hit a friendly error during model setup time.
                ModelUtils._check_resolution_for_encoder(
                    bad_size, bad_size, encoder_type
                )

            enc = enc_func(bad_size, bad_size, 3, 1)
            enc.forward(vis_input)


@pytest.mark.parametrize("unnormalized_inputs", [0, 1])
@pytest.mark.parametrize("num_visual", [0, 1, 2])
@pytest.mark.parametrize("num_vector", [0, 1, 2])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("encoder_type", [EncoderType.SIMPLE, EncoderType.NATURE_CNN])
def test_create_encoders(
    encoder_type, normalize, num_vector, num_visual, unnormalized_inputs
):
    vec_obs_shape = (5,)
    vis_obs_shape = (84, 84, 3)
    obs_shapes = []
    for _ in range(num_vector):
        obs_shapes.append(vec_obs_shape)
    for _ in range(num_visual):
        obs_shapes.append(vis_obs_shape)
    h_size = 128
    num_layers = 3
    unnormalized_inputs = 1
    vis_enc, vec_enc = ModelUtils.create_encoders(
        obs_shapes, h_size, num_layers, encoder_type, unnormalized_inputs, normalize
    )
    vec_enc = list(vec_enc)
    vis_enc = list(vis_enc)
    assert len(vec_enc) == (
        1 if unnormalized_inputs + num_vector > 0 else 0
    )  # There's always at most one vector encoder.
    assert len(vis_enc) == num_visual

    if unnormalized_inputs > 0:
        assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder)
    elif num_vector > 0:
        assert isinstance(vec_enc[0], VectorEncoder)

    for enc in vis_enc:
        assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type))


def test_list_to_tensor():
    # Test converting pure list
    unconverted_list = [[1, 2], [1, 3], [1, 4]]
    tensor = ModelUtils.list_to_tensor(unconverted_list)
    # Should be equivalent to torch.tensor conversion
    assert torch.equal(tensor, torch.tensor(unconverted_list))

    # Test converting pure numpy array
    np_list = np.asarray(unconverted_list)
    tensor = ModelUtils.list_to_tensor(np_list)
    # Should be equivalent to torch.tensor conversion
    assert torch.equal(tensor, torch.tensor(unconverted_list))

    # Test converting list of numpy arrays
    list_of_np = [np.asarray(_el) for _el in unconverted_list]
    tensor = ModelUtils.list_to_tensor(list_of_np)
    # Should be equivalent to torch.tensor conversion
    assert torch.equal(tensor, torch.tensor(unconverted_list))


def test_break_into_branches():
    # Test normal multi-branch case
    all_actions = torch.tensor([[1, 2, 3, 4, 5, 6]])
    action_size = [2, 1, 3]
    broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
    assert len(action_size) == len(broken_actions)
    for i, _action in enumerate(broken_actions):
        assert _action.shape == (1, action_size[i])

    # Test 1-branch case
    action_size = [6]
    broken_actions = ModelUtils.break_into_branches(all_actions, action_size)
    assert len(broken_actions) == 1
    assert broken_actions[0].shape == (1, 6)


def test_actions_to_onehot():
    all_actions = torch.tensor([[1, 0, 2], [1, 0, 2]])
    action_size = [2, 1, 3]
    oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size)
    expected_result = [
        torch.tensor([[0, 1], [0, 1]]),
        torch.tensor([[1], [1]]),
        torch.tensor([[0, 0, 1], [0, 0, 1]]),
    ]
    for res, exp in zip(oh_actions, expected_result):
        assert torch.equal(res, exp)


def test_get_probs_and_entropy():
    # Test continuous
    # Add two dists to the list. This isn't done in the code but we'd like to support it.
    dist_list = [
        GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
        GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))),
    ]
    action_list = [torch.zeros((1, 2)), torch.zeros((1, 2))]
    log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
        action_list, dist_list
    )
    assert log_probs.shape == (1, 2, 2)
    assert entropies.shape == (1, 2, 2)
    assert all_probs is None

    for log_prob in log_probs.flatten():
        # Log prob of standard normal at 0
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in entropies.flatten():
        # Entropy of standard normal, as computed by this implementation
        assert ent == pytest.approx(2.83, abs=0.01)

    # Test discrete
    # Add two dists to the list.
    act_size = 2
    test_prob = torch.tensor(
        [1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
    )  # High prob for first action
    dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)]
    action_list = [torch.tensor([0]), torch.tensor([1])]
    log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
        action_list, dist_list
    )
    assert all_probs.shape == (len(dist_list * act_size),)
    assert entropies.shape == (len(dist_list),)
    # Make sure the first action has a higher probability than the others.
    assert log_probs.flatten()[0] > log_probs.flatten()[1]