|
|
|
|
|
|
self.is_continuous_int = torch.nn.Parameter( |
|
|
|
torch.Tensor([int(act_type == ActionType.CONTINUOUS)]) |
|
|
|
) |
|
|
|
self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size)) |
|
|
|
self.act_size_vector = torch.nn.Parameter( |
|
|
|
torch.Tensor([sum(act_size)]), requires_grad=False |
|
|
|
) |
|
|
|
self.network_body = NetworkBody(observation_shapes, network_settings) |
|
|
|
if network_settings.memory is not None: |
|
|
|
self.encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
|
|
|
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs. |
|
|
|
""" |
|
|
|
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1) |
|
|
|
action_list = self.sample_action(dists) |
|
|
|
sampled_actions = torch.stack(action_list, dim=-1) |
|
|
|
action_out = sampled_actions |
|
|
|
action_list = self.sample_action(dists) |
|
|
|
action_out = torch.stack(action_list, dim=-1) |
|
|
|
action_out = dists[0].all_log_prob() |
|
|
|
action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1) |
|
|
|
return ( |
|
|
|
action_out, |
|
|
|
self.version_number, |
|
|
|