您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
166 行
6.6 KiB
166 行
6.6 KiB
from typing import List, Optional, Tuple
|
|
import torch
|
|
import numpy as np
|
|
from torch import nn
|
|
|
|
from mlagents.trainers.torch.encoders import (
|
|
SimpleVisualEncoder,
|
|
ResNetVisualEncoder,
|
|
NatureVisualEncoder,
|
|
VectorEncoder,
|
|
VectorAndUnnormalizedInputEncoder,
|
|
)
|
|
from mlagents.trainers.settings import EncoderType
|
|
from mlagents.trainers.exception import UnityTrainerException
|
|
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance
|
|
|
|
|
|
class ModelUtils:
    """Static helpers for building observation encoders and handling action tensors."""

    # Minimum supported side for each encoder type. If refactoring an encoder, please
    # adjust these also.
    MIN_RESOLUTION_FOR_ENCODER = {
        EncoderType.SIMPLE: 20,
        EncoderType.NATURE_CNN: 36,
        EncoderType.RESNET: 15,
    }

    @staticmethod
    def swish(input_activation: torch.Tensor) -> torch.Tensor:
        """Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
        return torch.mul(input_activation, torch.sigmoid(input_activation))

    @staticmethod
    def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
        """
        Looks up the visual encoder class registered for an encoder type.
        :param encoder_type: The EncoderType to look up.
        :return: The (uninstantiated) encoder class mapped to encoder_type, or None
            when the type has no entry in the mapping.
        """
        ENCODER_FUNCTION_BY_TYPE = {
            EncoderType.SIMPLE: SimpleVisualEncoder,
            EncoderType.NATURE_CNN: NatureVisualEncoder,
            EncoderType.RESNET: ResNetVisualEncoder,
        }
        return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

    @staticmethod
    def _check_resolution_for_encoder(
        vis_in: torch.Tensor, vis_encoder_type: EncoderType
    ) -> None:
        """
        Raises if a visual input is smaller than the encoder's minimum resolution.
        :param vis_in: Visual observation tensor. Dimension 1 is read as the height and
            dimension 2 as the width (assumes a batch-first layout -- TODO confirm).
        :param vis_encoder_type: Encoder type used to look up the minimum side length in
            MIN_RESOLUTION_FOR_ENCODER.
        :raises UnityTrainerException: If the height or the width is below the minimum.
        """
        min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
        height = vis_in.shape[1]
        width = vis_in.shape[2]
        if height < min_res or width < min_res:
            # Adjacent f-strings are concatenated verbatim, so the separating space
            # must be part of one of the literals.
            raise UnityTrainerException(
                f"Visual observation resolution ({width}x{height}) is too small for "
                f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
            )

    @staticmethod
    def create_encoders(
        observation_shapes: List[Tuple[int, ...]],
        h_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
        unnormalized_inputs: int = 0,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, nn.ModuleList]:
        """
        Creates visual and vector encoders, along with their normalizers.
        :param observation_shapes: List of Tuples that represent the observation shapes.
            Each 3-element shape gets its own visual encoder; the sizes of all 1-element
            shapes are summed and served by a single vector encoder.
        :param h_size: Number of hidden units per layer.
        :param num_layers: Depth of MLP per encoder.
        :param vis_encode_type: Type of visual encoder to use.
        :param unnormalized_inputs: Number of additional vector inputs that should not be
            normalized, and are added to the vector obs. Used for conditioning the network
            on other values (e.g. actions for a Q function).
        :param normalize: Normalize all vector inputs.
        :return: Tuple of visual encoders and vector encoders, each as an nn.ModuleList.
        :raises UnityTrainerException: If an observation shape has neither 1 nor 3
            dimensions.
        """
        visual_encoders: List[nn.Module] = []
        vector_encoders: List[nn.Module] = []

        visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
        vector_size = 0
        for i, dimension in enumerate(observation_shapes):
            if len(dimension) == 3:
                visual_encoders.append(
                    visual_encoder_class(
                        dimension[0], dimension[1], dimension[2], h_size
                    )
                )
            elif len(dimension) == 1:
                vector_size += dimension[0]
            else:
                raise UnityTrainerException(
                    f"Unsupported shape of {dimension} for observation {i}"
                )
        # Only build a vector encoder when there is something for it to consume.
        if vector_size + unnormalized_inputs > 0:
            if unnormalized_inputs > 0:
                vector_encoders.append(
                    VectorAndUnnormalizedInputEncoder(
                        vector_size, h_size, unnormalized_inputs, num_layers, normalize
                    )
                )
            else:
                vector_encoders.append(
                    VectorEncoder(vector_size, h_size, num_layers, normalize)
                )
        return nn.ModuleList(visual_encoders), nn.ModuleList(vector_encoders)

    @staticmethod
    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        """
        Converts a list of numpy arrays into a tensor. MUCH faster than
        calling as_tensor on the list directly.
        """
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
        """
        # Column offsets of every branch boundary: [0, a0, a0 + a1, ...].
        action_idx = [0] + list(np.cumsum(action_size))
        return [
            concatenated_logits[:, start:end]
            for start, end in zip(action_idx[:-1], action_idx[1:])
        ]

    @staticmethod
    def actions_to_onehot(
        discrete_actions: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        One-hot encodes each discrete action branch separately.
        :param discrete_actions: Tensor of action indices; it is transposed and iterated,
            and entry i is encoded with action_size[i] classes (assumes one column per
            branch -- TODO confirm against callers).
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List with one one-hot encoded Tensor per branch.
        """
        onehot_branches = [
            torch.nn.functional.one_hot(_act.T, action_size[i])
            for i, _act in enumerate(discrete_actions.T)
        ]
        return onehot_branches

    @staticmethod
    def get_probs_and_entropy(
        action_list: List[torch.Tensor], dists: List[DistInstance]
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Evaluates log probabilities and entropies of actions under their distributions.
        :param action_list: Actions, paired positionally with dists.
        :param dists: Distribution instances used to score the matching action.
        :return: Tuple of (log_probs, entropies, all_probs). log_probs and entropies are
            the per-distribution results stacked along a new last dimension. all_probs is
            the concatenation, along the last dimension, of all_log_prob() of every
            DiscreteDistInstance; when there is none it is None and squeeze(-1) is
            applied to log_probs and entropies.
        """
        log_probs_list = []
        all_probs_list = []
        entropies_list = []
        for action, action_dist in zip(action_list, dists):
            log_prob = action_dist.log_prob(action)
            log_probs_list.append(log_prob)
            entropies_list.append(action_dist.entropy())
            if isinstance(action_dist, DiscreteDistInstance):
                all_probs_list.append(action_dist.all_log_prob())
        log_probs = torch.stack(log_probs_list, dim=-1)
        entropies = torch.stack(entropies_list, dim=-1)
        if not all_probs_list:
            log_probs = log_probs.squeeze(-1)
            entropies = entropies.squeeze(-1)
            all_probs = None
        else:
            # Concatenate the collected list; `all_probs` itself is not bound yet here.
            all_probs = torch.cat(all_probs_list, dim=-1)
        return log_probs, entropies, all_probs
|