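# Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents.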
from typing import Callable, List, Dict, Tuple, Optional, Union
import abc
from mlagents.torch_utils import torch, nn
from mlagents_envs.base_env import ActionSpec, SensorSpec
from mlagents.trainers.torch.action_model import ActionModel
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbeddings
# Signature of an element-wise activation: tensor in, tensor out.
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
# Signature of an encoder factory-style callable that takes a tensor plus
# five configuration values and returns a tensor.
EncoderFunction = Callable[
    [torch.Tensor, int, ActivationFunction, int, str, bool], torch.Tensor
]

# Small positive constant (not referenced elsewhere in this part of the file).
EPSILON = 1e-7
class NetworkBody(nn.Module):
    """
    Encodes a list of observation tensors (and optionally actions) into a single
    hidden representation, optionally followed by an LSTM.
    """

    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        """
        :param sensor_specs: One spec per observation; forwarded to
            ModelUtils.create_input_processors.
        :param network_settings: Supplies normalize, hidden_units, num_layers and the
            optional memory settings.
        :param encoded_act_size: Extra input width reserved for actions that are
            concatenated to the encoded observations in forward().
        """
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )

        # NOTE(review): the argument list of this call was reconstructed; confirm it
        # against the signature of ModelUtils.create_input_processors.
        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
            sensor_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )

        total_enc_size = sum(self.embedding_sizes) + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Feeds the observations stored in the buffer to every VectorInput processor so
        it can update its normalization statistics.
        :param buffer: Buffer the observations are read from via ObsUtil.from_buffer.
        """
        obs = ObsUtil.from_buffer(buffer, len(self.processors))
        for vec_input, enc in zip(obs, self.processors):
            if isinstance(enc, VectorInput):
                enc.update_normalization(torch.as_tensor(vec_input))

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        """
        Copies normalization state from another network, processor by processor.
        Does nothing when normalization is disabled for this network.
        :param other_network: Network whose VectorInput processors are the source.
        """
        if self.normalize:
            for n1, n2 in zip(self.processors, other_network.processors):
                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                    n1.copy_normalization(n2)

    @property
    def memory_size(self) -> int:
        """Memory size reported by the LSTM, or 0 when no LSTM is used."""
        return self.lstm.memory_size if self.use_lstm else 0

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes the observations (plus actions, when given).
        :param inputs: One tensor per processor, in processor order.
        :param actions: Optional tensor concatenated to the encoded observations
            along the last dimension.
        :param memories: Optional LSTM memories; passed through untouched when no
            LSTM is used.
        :param sequence_length: Sequence length used to reshape the encoding before
            the LSTM.
        :return: The encoding and the (possibly updated) memories.
        :raises Exception: If the network has no processors.
        """
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)

        if not encodes:
            raise Exception("No valid inputs to network.")

        # Constants don't work in Barracuda
        if actions is not None:
            inputs = torch.cat(encodes + [actions], dim=-1)
        else:
            inputs = torch.cat(encodes, dim=-1)
        encoding = self.linear_encoder(inputs)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
# NOTE: this class will be replaced with a multi-head attention when the time comes
class MultiInputNetworkBody(nn.Module):
    """
    Encodes the observations of several agents into one hidden representation.

    Every agent's observations go through the same processors; the per-agent
    encodings are embedded, combined with residual self-attention, passed through
    a linear encoder and, optionally, an LSTM.
    """

    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        num_obs_heads: int = 1,
    ):
        """
        :param sensor_specs: One spec per observation; forwarded to
            ModelUtils.create_input_processors.
        :param network_settings: Supplies normalize, hidden_units, num_layers and the
            optional memory settings.
        :param encoded_act_size: Extra input width reserved for actions that are
            concatenated to the attention output in forward().
        :param num_obs_heads: Accepted for interface compatibility; not used by the
            code in this class.
        """
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        # Scale network depending on num agents
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        # NOTE(review): the argument list of this call was reconstructed; confirm it
        # against the signature of ModelUtils.create_input_processors.
        self.processors, _input_size = ModelUtils.create_input_processors(
            sensor_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )

        # Modules for self-attention
        self.entity_encoder = EntityEmbeddings(
            sum(_input_size), [sum(_input_size)], self.h_size
        )
        self.self_attn = ResidualSelfAttention(self.h_size)

        encoder_input_size = self.h_size

        total_enc_size = encoder_input_size + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    @property
    def memory_size(self) -> int:
        """Memory size reported by the LSTM, or 0 when no LSTM is used."""
        return self.lstm.memory_size if self.use_lstm else 0

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Feeds the observations stored in the buffer to every VectorInput processor so
        it can update its normalization statistics.
        :param buffer: Buffer the observations are read from via ObsUtil.from_buffer.
        """
        obs = ObsUtil.from_buffer(buffer, len(self.processors))
        for vec_input, enc in zip(obs, self.processors):
            if isinstance(enc, VectorInput):
                enc.update_normalization(torch.as_tensor(vec_input))

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        """
        Copies normalization state from another network, processor by processor.
        Does nothing when normalization is disabled for this network.
        :param other_network: Network whose VectorInput processors are the source.
        """
        if self.normalize:
            for n1, n2 in zip(self.processors, other_network.processors):
                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                    n1.copy_normalization(n2)

    def forward(
        self,
        all_net_inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes the observations of all agents.
        :param all_net_inputs: One list of observation tensors per agent. The first
            entry is treated as "self"; padded entries of the remaining agents are
            expected to be NaN.
        :param actions: Optional tensor concatenated to the attention output along
            the last dimension.
        :param memories: Optional LSTM memories; passed through untouched when no
            LSTM is used.
        :param sequence_length: Sequence length used to reshape the encoding before
            the LSTM.
        :return: The encoding and the (possibly updated) memories.
        :raises Exception: If the network has no processors.
        """
        self_inputs = all_net_inputs[0]
        self_encodes = [
            processor(self_inputs[idx])
            for idx, processor in enumerate(self.processors)
        ]
        # Validate before concatenating: with no processors there is nothing to
        # encode, and torch.cat would otherwise fail with a less helpful error.
        if not self_encodes:
            raise Exception("No valid inputs to network.")
        x_self = torch.cat(self_encodes, dim=-1)

        # Get attention masks by grabbing an arbitrary obs across all the agents
        # Since these are raw obs, the padded values are still NaN
        only_first_obs = [_all_obs[0] for _all_obs in all_net_inputs]
        obs_for_mask = torch.stack(only_first_obs, dim=1)
        # Get the mask from nans. `.float()` keeps the mask on the device of the
        # observations, whereas `.type(torch.FloatTensor)` would force a CPU tensor.
        attn_mask = torch.any(obs_for_mask.isnan(), dim=2).float()

        # Get the self encoding separately, but keep it in the entities
        concat_encoded_obs = [x_self]
        for inputs in all_net_inputs[1:]:
            encodes = []
            for idx, processor in enumerate(self.processors):
                obs_input = inputs[idx]
                # Remove NaNs without modifying the caller's tensor, so that the
                # padding stays detectable if the same tensors are passed in again.
                obs_input = obs_input.masked_fill(obs_input.isnan(), 0.0)
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            concat_encoded_obs.append(torch.cat(encodes, dim=-1))

        concat_entites = torch.stack(concat_encoded_obs, dim=1)

        encoded_entity = self.entity_encoder(x_self, [concat_entites])
        encoded_state = self.self_attn(encoded_entity, [attn_mask])

        # Constants don't work in Barracuda
        if actions is not None:
            inputs = torch.cat([encoded_state, actions], dim=-1)
        else:
            inputs = encoded_state
        encoding = self.linear_encoder(inputs)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
class ValueNetwork(nn.Module):
    """A NetworkBody followed by one value head per reward stream."""

    def __init__(
        self,
        stream_names: List[str],
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
    ):
        """
        :param stream_names: Names of the reward streams; one value head each.
        :param sensor_specs: Observation specs forwarded to NetworkBody.
        :param network_settings: Settings forwarded to NetworkBody; also determine
            the input size of the value heads.
        :param encoded_act_size: Extra input width forwarded to NetworkBody.
        :param outputs_per_stream: Number of outputs of each value head.
        """
        # This is not a typo, we want to call __init__ of nn.Module
        nn.Module.__init__(self)
        self.network_body = NetworkBody(
            sensor_specs, network_settings, encoded_act_size=encoded_act_size
        )
        # With memory enabled the body emits memory_size // 2 features, otherwise
        # hidden_units features.
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    @property
    def memory_size(self) -> int:
        """Memory size of the underlying network body."""
        return self.network_body.memory_size

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        :param inputs: Observation tensors passed to the network body.
        :param actions: Optional actions passed to the network body.
        :param memories: Optional memories passed to the network body.
        :param sequence_length: Sequence length passed to the network body.
        :return: The output of the value heads and the memories returned by the
            network body.
        """
        encoding, memories = self.network_body(
            inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories
class CentralizedValueNetwork(ValueNetwork):
    """
    Value network whose body is a MultiInputNetworkBody, i.e. it estimates values
    from the observations of several agents at once.
    """

    def __init__(
        self,
        stream_names: List[str],
        observation_shapes: List[SensorSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
        outputs_per_stream: int = 1,
    ):
        """
        :param stream_names: Names of the reward streams; one value head each.
        :param observation_shapes: Forwarded unchanged as the sensor specs of
            MultiInputNetworkBody (the parameter name is kept for compatibility).
        :param network_settings: Settings forwarded to the body; also determine the
            input size of the value heads.
        :param encoded_act_size: Extra input width forwarded to the body.
        :param outputs_per_stream: Number of outputs of each value head.
        """
        # This is not a typo, we want to call __init__ of nn.Module
        # (ValueNetwork.__init__ would build a single-agent NetworkBody instead).
        nn.Module.__init__(self)
        self.network_body = MultiInputNetworkBody(
            observation_shapes, network_settings, encoded_act_size=encoded_act_size
        )
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    def forward(
        self,
        inputs: List[List[torch.Tensor]],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        :param inputs: One list of observation tensors per agent.
        :param actions: Optional actions passed to the network body.
        :param memories: Optional memories passed to the network body.
        :param sequence_length: Sequence length passed to the network body.
        :return: The output of the value heads and the memories returned by the
            network body.
        """
        encoding, memories = self.network_body(
            inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories
class Actor(abc.ABC):
    """Abstract interface of a policy network that produces actions."""

    @abc.abstractmethod
    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Updates normalization of Actor based on the observations in the provided buffer.
        :param buffer: An AgentBuffer the observations are read from.
        """
        pass

    @abc.abstractmethod
    def get_action_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        """
        Returns sampled actions.
        If memory is enabled, return the memories as well.
        :param inputs: A List of inputs as tensors.
        :param masks: If using discrete actions, a Tensor of action masks.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :return: A Tuple of AgentAction, ActionLogProbs, entropies, and memories.
            Memories will be None if not using memory.
        """
        pass

    @abc.abstractmethod
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[int, torch.Tensor], ...]:
        """
        Forward pass of the Actor for inference. This is required for export to ONNX, and
        the inputs and outputs of this method should not be changed without a respective change
        in the ONNX export code.
        :param vec_inputs: A List of vector inputs as tensors.
        :param vis_inputs: A List of visual inputs as tensors.
        :param masks: If using discrete actions, a Tensor of action masks.
        :param memories: If using memory, a Tensor of initial memories.
        """
        pass
class ActorCritic(Actor):
    """Abstract interface of an Actor that can also produce value estimates."""

    @abc.abstractmethod
    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Get value outputs for the given obs.
        :param inputs: List of inputs as tensors.
        :param memories: Tensor of memories, if using memory. Otherwise, None.
        :param sequence_length: If using memory, the sequence length.
        :returns: Dict of reward stream to output tensor for values, and memories.
        """
        pass

    @abc.abstractmethod
    def get_action_stats_and_value(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        """
        Returns sampled actions and value estimates.
        If memory is enabled, return the memories as well.
        :param inputs: A List of inputs as tensors.
        :param masks: If using discrete actions, a Tensor of action masks.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :param critic_obs: Optional additional lists of observation tensors made
            available to the critic.
        :return: A Tuple of AgentAction, ActionLogProbs, entropies, Dict of reward signal
            name to value estimate, and memories. Memories will be None if not using memory.
        """
        pass

    @property
    @abc.abstractmethod
    def memory_size(self):
        """
        Returns the size of the memory (same size used as input and output in the other
        methods) used by this Actor.
        """
        pass
class SimpleActor(nn.Module, Actor):
    """
    Actor made of a NetworkBody followed by an ActionModel. It also stores the
    constant tensors that are emitted next to the actions by the export-oriented
    forward() method.
    """

    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """
        :param sensor_specs: Observation specs forwarded to NetworkBody.
        :param network_settings: Settings forwarded to NetworkBody; also determine
            the encoding size fed to the action model.
        :param action_spec: Description of the continuous/discrete action space.
        :param conditional_sigma: Forwarded to the action model.
        :param tanh_squash: Forwarded to the action model.
        """
        super().__init__()
        self.action_spec = action_spec
        # Constants stored as frozen parameters so they are part of the module.
        self.version_number = torch.nn.Parameter(
            torch.Tensor([2.0]), requires_grad=False
        )
        self.is_continuous_int_deprecated = torch.nn.Parameter(
            torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False
        )
        self.continuous_act_size_vector = torch.nn.Parameter(
            torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False
        )
        # TODO: export list of branch sizes instead of sum
        self.discrete_act_size_vector = torch.nn.Parameter(
            torch.Tensor([sum(self.action_spec.discrete_branches)]), requires_grad=False
        )
        # NOTE(review): the left operand of this sum was reconstructed as the
        # continuous action size — confirm.
        self.act_size_vector_deprecated = torch.nn.Parameter(
            torch.Tensor(
                [
                    self.action_spec.continuous_size
                    + sum(self.action_spec.discrete_branches)
                ]
            ),
            requires_grad=False,
        )
        self.network_body = NetworkBody(sensor_specs, network_settings)
        if network_settings.memory is not None:
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units
        self.memory_size_vector = torch.nn.Parameter(
            torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False
        )

        # NOTE(review): the argument list of this call was reconstructed; confirm it
        # against the ActionModel constructor.
        self.action_model = ActionModel(
            self.encoding_size,
            action_spec,
            conditional_sigma=conditional_sigma,
            tanh_squash=tanh_squash,
        )

    @property
    def memory_size(self) -> int:
        """Memory size of the underlying network body."""
        return self.network_body.memory_size

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Updates the normalization of the network body from the given buffer.
        :param buffer: Buffer the observations are read from.
        """
        # NOTE(review): body reconstructed as a delegation to the network body — confirm.
        self.network_body.update_normalization(buffer)

    def get_action_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        """
        Encodes the inputs and runs the action model on the encoding.
        :param inputs: Observation tensors passed to the network body.
        :param masks: Optional action masks passed to the action model.
        :param memories: Optional memories passed to the network body.
        :param sequence_length: Sequence length passed to the network body.
        :return: Action, log probabilities, entropies and memories.
        """
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        return action, log_probs, entropies, memories

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[int, torch.Tensor], ...]:
        """
        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

        At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
        so the size of return tuple varies with action spec.
        """
        # This code will convert the vec and vis obs into a list of inputs for the network
        concatenated_vec_obs = vec_inputs[0]
        inputs = []
        start = 0
        end = 0
        vis_index = 0
        for i, enc in enumerate(self.network_body.processors):
            if isinstance(enc, VectorInput):
                # This is a vec_obs
                vec_size = self.network_body.embedding_sizes[i]
                end = start + vec_size
                inputs.append(concatenated_vec_obs[:, start:end])
                start = end
            else:
                # NOTE(review): branch reconstructed — any non-vector processor is
                # assumed to consume the next visual observation; confirm.
                inputs.append(vis_inputs[vis_index])
                vis_index += 1
        # End of code to convert the vec and vis obs into a list of inputs for the network
        encoding, memories_out = self.network_body(
            inputs, memories=memories, sequence_length=1
        )

        # NOTE(review): the third unpacked name was reconstructed — confirm the number
        # of values returned by ActionModel.get_action_out.
        (
            cont_action_out,
            disc_action_out,
            action_out_deprecated,
        ) = self.action_model.get_action_out(encoding, masks)
        export_out = [self.version_number, self.memory_size_vector]
        if self.action_spec.continuous_size > 0:
            export_out += [cont_action_out, self.continuous_act_size_vector]
        if self.action_spec.discrete_size > 0:
            export_out += [disc_action_out, self.discrete_act_size_vector]
        # Only export deprecated nodes with non-hybrid action spec
        if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0:
            export_out += [
                action_out_deprecated,
                self.is_continuous_int_deprecated,
                self.act_size_vector_deprecated,
            ]
        return tuple(export_out)
class SharedActorCritic(SimpleActor, ActorCritic):
    """
    Actor-critic in which the value heads read the same network-body encoding as
    the action model.
    """

    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """
        :param sensor_specs: Observation specs forwarded to SimpleActor.
        :param network_settings: Settings forwarded to SimpleActor.
        :param action_spec: Action space description forwarded to SimpleActor.
        :param stream_names: Names of the reward streams; one value head each.
        :param conditional_sigma: Forwarded to SimpleActor.
        :param tanh_squash: Forwarded to SimpleActor.
        """
        self.use_lstm = network_settings.memory is not None
        super().__init__(
            sensor_specs, network_settings, action_spec, conditional_sigma, tanh_squash
        )
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, self.encoding_size)

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        :param inputs: Observation tensors passed to the network body.
        :param memories: Optional memories passed to the network body.
        :param sequence_length: Sequence length passed to the network body.
        :return: Value head outputs and the memories returned by the network body.
        """
        encoding, memories_out = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories_out

    def get_stats_and_value(
        self,
        inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Evaluates given actions and computes value estimates from one shared encoding.
        :param inputs: Observation tensors passed to the network body.
        :param actions: Actions to evaluate with the action model.
        :param masks: Optional action masks passed to the action model.
        :param memories: Optional memories passed to the network body.
        :param sequence_length: Sequence length passed to the network body.
        :param critic_obs: Accepted for interface compatibility; not used here.
        :return: Log probabilities, entropies and value head outputs.
        """
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        value_outputs = self.value_heads(encoding)
        return log_probs, entropies, value_outputs

    def get_action_stats_and_value(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        """
        Samples actions and computes value estimates from one shared encoding.
        :param inputs: Observation tensors passed to the network body.
        :param masks: Optional action masks passed to the action model.
        :param memories: Optional memories passed to the network body.
        :param sequence_length: Sequence length passed to the network body.
        :param critic_obs: Accepted so the signature matches ActorCritic; not used
            here.
        :return: Action, log probabilities, entropies, value head outputs and
            memories.
        """
        encoding, memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        value_outputs = self.value_heads(encoding)
        return action, log_probs, entropies, value_outputs, memories
class SeparateActorCritic(SimpleActor, ActorCritic):
    """
    Actor-critic whose critic is a separate CentralizedValueNetwork. When memory is
    used, the memory tensor holds the actor part followed by the critic part along
    the last dimension.
    """

    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """
        :param sensor_specs: Observation specs used by both actor and critic.
        :param network_settings: Settings used by both actor and critic.
        :param action_spec: Action space description forwarded to SimpleActor.
        :param stream_names: Names of the reward streams of the critic.
        :param conditional_sigma: Forwarded to SimpleActor.
        :param tanh_squash: Forwarded to SimpleActor.
        """
        self.use_lstm = network_settings.memory is not None
        super().__init__(
            sensor_specs, network_settings, action_spec, conditional_sigma, tanh_squash
        )
        self.stream_names = stream_names
        self.critic = CentralizedValueNetwork(
            stream_names, sensor_specs, network_settings
        )

    @property
    def memory_size(self) -> int:
        """Sum of the actor body's and the critic's memory sizes."""
        return self.network_body.memory_size + self.critic.memory_size

    def _get_actor_critic_mem(
        self, memories: Optional[torch.Tensor] = None
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Splits the combined memories into the actor part and the critic part.
        :param memories: Combined memories, or None.
        :return: (actor_mem, critic_mem); both None when memory is not used or no
            memories were given.
        """
        if self.use_lstm and memories is not None:
            # Split in halves along the last dimension: the front half belongs to the
            # actor, the back half to the critic. Both are built from the same
            # network settings, so the two parts have equal size.
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
        return actor_mem, critic_mem

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Runs only the critic.
        :param inputs: This agent's observation tensors.
        :param memories: Optional combined actor/critic memories.
        :param sequence_length: Sequence length passed to the critic.
        :param critic_obs: Optional additional lists of observation tensors for the
            critic.
        :return: Value outputs and the combined memories (None when the input had no
            actor memories).
        """
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        all_net_inputs = [inputs]
        if critic_obs is not None:
            # NOTE(review): reconstructed — the extra observations are assumed to be
            # appended after this agent's own inputs; confirm.
            all_net_inputs.extend(critic_obs)
        value_outputs, critic_mem_out = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Make memories with the actor mem unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out

    def get_stats_and_value(
        self,
        inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Evaluates given actions with the actor and computes values with the critic.
        :param inputs: This agent's observation tensors.
        :param actions: Actions to evaluate with the action model.
        :param masks: Optional action masks passed to the action model.
        :param memories: Optional combined actor/critic memories.
        :param sequence_length: Sequence length passed to actor and critic.
        :param critic_obs: Optional additional lists of observation tensors for the
            critic.
        :return: Log probabilities, entropies and value outputs.
        """
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        encoding, actor_mem_outs = self.network_body(
            inputs, memories=actor_mem, sequence_length=sequence_length
        )
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        all_net_inputs = [inputs]
        if critic_obs is not None:
            # NOTE(review): reconstructed, see critic_pass.
            all_net_inputs.extend(critic_obs)
        value_outputs, critic_mem_outs = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        return log_probs, entropies, value_outputs

    def get_action_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        """
        Runs only the actor.
        :param inputs: This agent's observation tensors.
        :param masks: Optional action masks passed to the action model.
        :param memories: Optional combined actor/critic memories.
        :param sequence_length: Sequence length passed to the actor.
        :return: Action, log probabilities, entropies and the combined memories
            (None when the input had no critic memories).
        """
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        action, log_probs, entropies, actor_mem_out = super().get_action_stats(
            inputs, masks=masks, memories=actor_mem, sequence_length=sequence_length
        )
        if critic_mem is not None:
            # Make memories with the critic mem unchanged
            memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1)
        else:
            memories_out = None
        return action, log_probs, entropies, memories_out

    def get_action_stats_and_value(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        """
        Samples actions with the actor and computes values with the critic.
        :param inputs: This agent's observation tensors.
        :param masks: Optional action masks passed to the action model.
        :param memories: Optional combined actor/critic memories.
        :param sequence_length: Sequence length passed to actor and critic.
        :param critic_obs: Optional additional lists of observation tensors for the
            critic.
        :return: Action, log probabilities, entropies, value outputs and the combined
            memories (None when memory is not used).
        """
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        encoding, actor_mem_outs = self.network_body(
            inputs, memories=actor_mem, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        all_net_inputs = [inputs]
        if critic_obs is not None:
            # NOTE(review): reconstructed, see critic_pass.
            all_net_inputs.extend(critic_obs)
        value_outputs, critic_mem_outs = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return action, log_probs, entropies, value_outputs, mem_out

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Updates the normalization of both the actor and the critic body.
        :param buffer: Buffer the observations are read from.
        """
        # NOTE(review): body reconstructed — confirm that both the actor and the
        # critic body are meant to be updated from the same buffer.
        super().update_normalization(buffer)
        self.critic.network_body.update_normalization(buffer)
class GlobalSteps(nn.Module):
    """
    Step counter kept in a frozen int64 parameter, so it is part of the module's
    state.
    """

    def __init__(self):
        # Required before a Parameter can be assigned to the module.
        super().__init__()
        self.__global_step = nn.Parameter(
            torch.Tensor([0]).to(torch.int64), requires_grad=False
        )

    @property
    def current_step(self):
        """The current step count as a Python int."""
        return int(self.__global_step.item())

    @current_step.setter
    def current_step(self, value):
        # Write in place so the registered parameter object stays the same.
        self.__global_step[:] = value

    def increment(self, value):
        """
        Adds value to the step counter.
        :param value: Number of steps to add.
        """
        self.__global_step += value
class LearningRate(nn.Module):
    """Module wrapper around a learning rate stored as a one-element tensor."""

    def __init__(self, lr):
        """
        :param lr: Learning rate value to store.
        """
        # Initialise nn.Module so the usual module API (state_dict, to, ...) works.
        super().__init__()
        # Todo: add learning rate decay
        # Kept as a plain tensor attribute (neither a parameter nor a buffer).
        self.learning_rate = torch.Tensor([lr])