|
|
|
|
|
|
loss = torch.mean(loss) |
|
|
|
return loss |
|
|
|
|
|
|
|
def get_prediction(self, inputs: List[torch.Tensor]) -> torch.Tensor: |
|
|
|
prediction, _ = self.forward(inputs) |
|
|
|
prediction = self.surrogate_predictor(prediction) |
|
|
|
return prediction |
|
|
|
|
|
|
|
class ValueNetwork(nn.Module): |
|
|
|
def __init__( |
|
|
|
|
|
|
action_out_deprecated, |
|
|
|
) = self.action_model.get_action_out(encoding, masks) |
|
|
|
export_out = [self.version_number, self.memory_size_vector] |
|
|
|
if self.action_spec.continuous_size > 0: |
|
|
|
export_out += [cont_action_out, self.continuous_act_size_vector] |
|
|
|
if True: |
|
|
|
# export_out += [cont_action_out, self.continuous_act_size_vector] |
|
|
|
export_out += [self.network_body.get_prediction(inputs), torch.nn.Parameter( |
|
|
|
torch.Tensor([int(9)]), requires_grad=False |
|
|
|
)] |
|
|
|
# Only export deprecated nodes with non-hybrid action spec |
|
|
|
if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0: |
|
|
|
export_out += [ |
|
|
|
action_out_deprecated, |
|
|
|
self.is_continuous_int_deprecated, |
|
|
|
self.act_size_vector_deprecated, |
|
|
|
] |
|
|
|
# # Only export deprecated nodes with non-hybrid action spec |
|
|
|
# if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0: |
|
|
|
# export_out += [ |
|
|
|
# action_out_deprecated, |
|
|
|
# self.is_continuous_int_deprecated, |
|
|
|
# self.act_size_vector_deprecated, |
|
|
|
# ] |
|
|
|
return tuple(export_out) |
|
|
|
|
|
|
|
|
|
|
|