|
|
|
|
|
|
from enum import Enum |
|
|
|
from typing import Callable, List, Dict, Tuple, Optional, Union |
|
|
|
import abc |
|
|
|
|
|
|
|
|
|
|
from mlagents.trainers.settings import NetworkSettings |
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads |
|
|
|
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, HyperNetwork |
|
|
|
from mlagents.trainers.torch.layers import ( |
|
|
|
LSTM, |
|
|
|
LinearEncoder, |
|
|
|
HyperNetwork, |
|
|
|
ConditionalEncoder, |
|
|
|
) |
|
|
|
from mlagents.trainers.torch.encoders import VectorInput |
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents.trainers.trajectory import ObsUtil |
|
|
|
|
|
|
EPSILON = 1e-7 |
|
|
|
|
|
|
|
|
|
|
|
class ConditioningMode(Enum): |
|
|
|
DEFAULT = 0 |
|
|
|
HYPER = 1 |
|
|
|
SOFT = 3 |
|
|
|
|
|
|
|
|
|
|
|
class NetworkBody(nn.Module): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.conditioning_mode = ConditioningMode.HYPER |
|
|
|
self.normalize = network_settings.normalize |
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
self.h_size = network_settings.hidden_units |
|
|
|
|
|
|
|
|
|
|
total_enc_size, total_goal_size = 0, 0 |
|
|
|
for idx, embedding_size in enumerate(self.embedding_sizes): |
|
|
|
if self.obs_types[idx] == ObservationType.DEFAULT: |
|
|
|
if ( |
|
|
|
self.obs_types[idx] == ObservationType.DEFAULT |
|
|
|
or self.conditioning_mode == ConditioningMode.DEFAULT |
|
|
|
): |
|
|
|
if self.obs_types[idx] == ObservationType.GOAL: |
|
|
|
if ( |
|
|
|
self.obs_types[idx] == ObservationType.GOAL |
|
|
|
and self.conditioning_mode != ConditioningMode.DEFAULT |
|
|
|
): |
|
|
|
if ObservationType.GOAL in self.obs_types: |
|
|
|
if ( |
|
|
|
ObservationType.GOAL in self.obs_types |
|
|
|
and self.conditioning_mode == ConditioningMode.HYPER |
|
|
|
): |
|
|
|
total_goal_size, |
|
|
|
network_settings.num_layers, |
|
|
|
self.h_size, |
|
|
|
) |
|
|
|
elif ( |
|
|
|
ObservationType.GOAL in self.obs_types |
|
|
|
and self.conditioning_mode == ConditioningMode.SOFT |
|
|
|
): |
|
|
|
self.linear_encoder = ConditionalEncoder( |
|
|
|
total_enc_size, |
|
|
|
total_goal_size, |
|
|
|
network_settings.num_layers, |
|
|
|
self.h_size, |
|
|
|
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
if self.obs_types[idx] == ObservationType.DEFAULT: |
|
|
|
if ( |
|
|
|
self.obs_types[idx] == ObservationType.DEFAULT |
|
|
|
or self.conditioning_mode == ConditioningMode.DEFAULT |
|
|
|
): |
|
|
|
elif self.obs_types[idx] == ObservationType.GOAL: |
|
|
|
elif ( |
|
|
|
self.obs_types[idx] == ObservationType.GOAL |
|
|
|
and self.conditioning_mode != ConditioningMode.DEFAULT |
|
|
|
): |
|
|
|
goal_signal = processed_obs |
|
|
|
|
|
|
|
if len(encodes) == 0: |
|
|
|