from typing import Callable, List, Dict, Tuple, Optional
import abc
import typing

import torch
from torch import nn

# ActionType and ValueHeads are referenced below; these two import paths are an
# assumption based on the surrounding mlagents package layout.
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    DistInstance,
    GaussianDistInstance,
    CategoricalDistInstance,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
|
|
        else:
            self.lstm = None  # type: ignore

    def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
        for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
            vec_enc.update_normalization(vec_input)

    def copy_normalization(self, other_network) -> None:
        for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
            n1.copy_normalization(n2)

    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0
|
|
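    # forward() folds every vector and visual observation into a single hidden
    # embedding and, when an LSTM is configured, advances the recurrent state.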
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        encodes = []
        # Encode each vector observation stream, appending the provided actions
        # to the observation when they are given.
        for vec_input, encoder in zip(vec_inputs, self.vector_encoders):
            if actions is not None:
                hidden = encoder(vec_input, actions)
            else:
                hidden = encoder(vec_input)
            encodes.append(hidden)

        # Encode each visual observation stream.
        # NOTE: self.visual_encoders is assumed to mirror self.vector_encoders above.
        for vis_input, encoder in zip(vis_inputs, self.visual_encoders):
            if not torch.onnx.is_in_onnx_export():
                # Observations arrive as NHWC; permute to NCHW for the conv encoders.
                vis_input = vis_input.permute([0, 3, 1, 2])
            hidden = encoder(vis_input)
            encodes.append(hidden)

        # Sum the per-stream encodings into a single embedding.
        encoding = encodes[0]
        for _enc in encodes[1:]:
            encoding += _enc

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
|
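        # With no recurrent memory configured, the value heads read the network
        # body's hidden layer directly, so they are sized to hidden_units.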
        encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size
|
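    # forward() runs the shared network body and maps its embedding to one value
    # estimate per stream in stream_names.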
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, actions, memories, sequence_length
        )
        output = self.value_heads(encoding)
        return output, memories
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
    @abc.abstractproperty
    def memory_size(self):
        """
        Returns the size of the memory (same size used as input and output in the other
        methods) used by this Actor.
        """
        pass
|
|
class SimpleActor(nn.Module, Actor): |
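    """
    Actor that encodes observations with a network body and emits an action
    distribution: Gaussian for continuous action spaces, multi-categorical for
    discrete ones.
    """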
|
|
|
|
|
|
            self.distribution = MultiCategoricalDistribution(
                self.encoding_size, act_size
            )

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size
|
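    # Draw one action tensor from each distribution instance returned by get_dists().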
    @torch.jit.export
    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions
|
|
|
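    # get_dists() builds the action distributions for a batch of observations;
    # action masks are only applied to the discrete (multi-categorical) head.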
    def get_dists(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        if self.act_type == ActionType.CONTINUOUS:
            dists = self.distribution(encoding)
        else:
            dists = self.distribution(encoding, masks)
        return dists, memories
|
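    # @torch.jit.ignore keeps forward() out of TorchScript compilation; it is only
    # called in eager mode (e.g. when exporting the trained policy).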
    @torch.jit.ignore
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
|
|
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, self.encoding_size)
|
    @torch.jit.export
    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories
|
    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        if self.act_type == ActionType.CONTINUOUS:
            dists = self.distribution(encoding)
        else:
            dists = self.distribution(encoding, masks=masks)
        value_outputs = self.value_heads(encoding)
        return dists, value_outputs, memories
|
|
|
|
        )
        self.stream_names = stream_names
        self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)
        # self.critic = torch.jit.script(ValueNetwork(stream_names, observation_shapes, network_settings))

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size + self.critic.memory_size
|
|
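    # critic_pass() evaluates only the critic: when an LSTM is used, the memory
    # tensor is split in half and just the critic's half is fed through its network.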
    @torch.jit.export
    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        if self.use_lstm:
            # Use only the back half of memories for critic
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        else:
            actor_mem = None
            critic_mem = None
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            # Keep the actor's half unchanged and recombine it with the critic's
            # updated half so callers get a full memory tensor back.
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out
|
|
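    # get_dist_and_value() splits the recurrent memories between actor and critic,
    # runs both halves, and re-concatenates the updated halves before returning them.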
    @torch.jit.export
    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        if self.use_lstm:
            # Use only the back half of memories for critic and actor
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
        dists, actor_mem_outs = self.get_dists(
            vec_inputs,
            vis_inputs,
            masks=masks,
            memories=actor_mem,
            sequence_length=sequence_length,
        )
        value_outputs, critic_mem_outs = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return dists, value_outputs, mem_out