|
|
|
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.action_spec = action_spec |
|
|
|
self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) |
|
|
|
self.version_number = torch.nn.Parameter( |
|
|
|
torch.Tensor([2.0]), requires_grad=False |
|
|
|
) |
|
|
|
torch.Tensor([int(self.action_spec.is_continuous())]) |
|
|
|
torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False |
|
|
|
) |
|
|
|
self.continuous_act_size_vector = torch.nn.Parameter( |
|
|
|
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False |
|
|
|
|
|
|
self.encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
else: |
|
|
|
self.encoding_size = network_settings.hidden_units |
|
|
|
self.memory_size_vector = torch.nn.Parameter( |
|
|
|
torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False |
|
|
|
) |
|
|
|
|
|
|
|
self.action_model = ActionModel( |
|
|
|
self.encoding_size, |
|
|
|
|
|
|
disc_action_out, |
|
|
|
action_out_deprecated, |
|
|
|
) = self.action_model.get_action_out(encoding, masks) |
|
|
|
export_out = [ |
|
|
|
self.version_number, |
|
|
|
torch.Tensor([self.network_body.memory_size]), |
|
|
|
] |
|
|
|
export_out = [self.version_number, self.memory_size_vector] |
|
|
|
if self.action_spec.continuous_size > 0: |
|
|
|
export_out += [cont_action_out, self.continuous_act_size_vector] |
|
|
|
if self.action_spec.discrete_size > 0: |
|
|
|