|
|
|
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
class HybridSimpleActor(nn.Module, Actor):
    """Actor for hybrid action spaces (simultaneous continuous and discrete branches).

    A shared NetworkBody produces an encoding; a GaussianDistribution head covers
    the continuous actions and a MultiCategoricalDistribution head covers the
    discrete branches.

    :param observation_shapes: Shapes of the observations fed to the network body.
    :param network_settings: Settings for the shared NetworkBody (hidden units, memory, ...).
    :param continuous_act_size: Sizes of the continuous action space; only
        continuous_act_size[0] is consumed by the Gaussian head.
    :param discrete_act_size: Branch sizes of the discrete action space.
    :param conditional_sigma: Whether the Gaussian sigma is conditioned on the encoding.
    :param tanh_squash: Whether continuous samples are squashed with tanh.
    """

    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        continuous_act_size: List[int],
        discrete_act_size: List[int],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.discrete_act_size = discrete_act_size
        self.continuous_act_size = continuous_act_size
        # Exported as model metadata (read by the inference side), not trained.
        self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
        self.continuous_act_size_vector = torch.nn.Parameter(
            torch.Tensor(continuous_act_size)
        )
        self.discrete_act_size_vector = torch.nn.Parameter(
            torch.Tensor(discrete_act_size)
        )
        self.network_body = NetworkBody(observation_shapes, network_settings)
        if network_settings.memory is not None:
            # With recurrent memory, half of the memory vector is the encoding.
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units
        self.continuous_distribution = GaussianDistribution(
            self.encoding_size,
            continuous_act_size[0],
            conditional_sigma=conditional_sigma,
            tanh_squash=tanh_squash,
        )
        self.discrete_distribution = MultiCategoricalDistribution(
            self.encoding_size, discrete_act_size
        )

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size

    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
        self.network_body.update_normalization(vector_obs)

    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
        """Sample one action tensor from each distribution instance."""
        return [action_dist.sample() for action_dist in dists]

    def get_dists(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
        """Return the flat list of distributions (discrete branches first,
        then continuous) and the updated memories."""
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        discrete_dists = self.discrete_distribution(encoding, masks)
        continuous_dists = self.continuous_distribution(encoding)
        # Order matters: forward() relies on discrete-first ordering.
        return discrete_dists + continuous_dists, memories

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int]:
        """
        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
        """
        # TODO: This is bad right now
        # Fixed: original read `dists _ =` (syntax error, missing comma).
        dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
        # get_dists returns a flat list: one dist per discrete branch, then the
        # continuous dist(s). The original indexed dists[0]/dists[1] as if it
        # were a pair of lists; slice by branch count instead.
        num_discrete = len(self.discrete_act_size)
        discrete_dists = dists[:num_discrete]
        continuous_dists = dists[num_discrete:]

        discrete_action_out = discrete_dists[0].all_log_prob()

        continuous_action_list = self.sample_action(continuous_dists)
        continuous_action_out = torch.stack(continuous_action_list, dim=-1)
        # Fixed: original was torch.cat(continuous, discrete_action_out, dim=-1)
        # — `continuous` was undefined and torch.cat takes a sequence of tensors.
        action_out = torch.cat([continuous_action_out, discrete_action_out], dim=-1)
        return (
            action_out,
            self.version_number,
            torch.Tensor([self.network_body.memory_size]),
            # Fixed: original returned self.is_continuous_int / self.act_size_vector,
            # neither of which is defined for the hybrid actor (the former is
            # commented out in __init__). Export the hybrid size vectors instead.
            self.continuous_act_size_vector,
            self.discrete_act_size_vector,
        )
|
|
|
|
|
|
|
class HybridSharedActorCritic(HybridSimpleActor, ActorCritic):
    """Hybrid actor-critic where the critic shares the actor's network body.

    :param stream_names: Names of the reward signal streams, one value head each.
    Other parameters are forwarded to HybridSimpleActor.
    """

    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        continuous_act_size: List[int],
        discrete_act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        # Fixed: original called super().__init__(self, ..., act_type, act_size, ...)
        # — passing self explicitly shifted every argument by one, and
        # act_type/act_size were undefined names.
        super().__init__(
            observation_shapes,
            network_settings,
            continuous_act_size,
            discrete_act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, self.encoding_size)

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Run only the value heads on the shared encoding."""
        encoding, memories_out = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories_out

    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        """Return action distributions, value estimates, and updated memories
        from a single shared forward pass."""
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        # Fixed: original branched on self.act_type / self.distribution, which do
        # not exist on the hybrid parent. A hybrid actor always produces both
        # branches, discrete first (same ordering as get_dists).
        dists = self.discrete_distribution(encoding, masks) + self.continuous_distribution(
            encoding
        )
        value_outputs = self.value_heads(encoding)
        return dists, value_outputs, memories
|
|
|
|
|
|
|
|
|
|
|
class HybridSeparateActorCritic(HybridSimpleActor, ActorCritic):
    """Hybrid actor-critic with a critic network separate from the actor body.

    :param stream_names: Names of the reward signal streams, one value head each.
    Other parameters are forwarded to HybridSimpleActor.
    """

    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        continuous_act_size: List[int],
        discrete_act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        # Fixed: original called super().__init__(self, ..., act_type, act_size, ...)
        # — passing self explicitly shifted every argument by one, and
        # act_type/act_size were undefined names.
        super().__init__(
            observation_shapes,
            network_settings,
            continuous_act_size,
            discrete_act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        # Fixed: critic_pass/get_dist_and_value/memory_size below read
        # self.use_lstm and self.critic, which the original never set (it built
        # ValueHeads instead of a separate critic). Build the separate critic
        # here. NOTE(review): assumes ValueNetwork is defined in this module, as
        # in the non-hybrid SeparateActorCritic — confirm.
        self.use_lstm = network_settings.memory is not None
        self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)

    @property
    def memory_size(self) -> int:
        # Actor and critic each carry their own recurrent memory.
        return self.network_body.memory_size + self.critic.memory_size

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Run only the critic, keeping the actor's half of the memories intact."""
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for critic
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Make memories with the actor mem unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out

    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        """Run actor and critic networks and return dists, values, and the
        concatenated (actor, critic) memories."""
        if self.use_lstm:
            # Use only the back half of memories for critic and actor
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
        dists, actor_mem_outs = self.get_dists(
            vec_inputs,
            vis_inputs,
            memories=actor_mem,
            sequence_length=sequence_length,
            masks=masks,
        )
        value_outputs, critic_mem_outs = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return dists, value_outputs, mem_out
|
|
|
|
|
|
|
################################################################################ |
|
|
|
######### Continuous xor Discrete cases ########## |
|
|
|
################################################################################ |
|
|
|
class SimpleActor(nn.Module, Actor): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|