|
|
|
|
|
|
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs. |
|
|
|
""" |
|
|
|
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1) |
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
action_list = self.sample_action(dists) |
|
|
|
action_out = torch.stack(action_list, dim=-1) |
|
|
|
else: |
|
|
|
action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1) |
|
|
|
action_out = torch.cat([dist.action_out() for dist in dists], dim=1) |
|
|
|
return ( |
|
|
|
action_out, |
|
|
|
self.version_number, |
|
|
|