|
|
|
|
|
|
encoding, memories = self.network_body( |
|
|
|
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
dists = self.distribution(encoding, masks) |
|
|
|
return dists, memories |
|
|
|
continuous_dists, discrete_dists = self.distribution(encoding, masks) |
|
|
|
return continuous_dists, discrete_dists, memories |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
|
|
|
""" |
|
|
|
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs. |
|
|
|
""" |
|
|
|
# TODO: This is bad right now |
|
|
|
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1) |
|
|
|
action_out = torch.cat([dist.exported_model_output() for dist in dists], dim=1) |
|
|
|
# TODO: How this is written depends on how the inference model is structured |
|
|
|
continuous_dists, discrete_dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1) |
|
|
|
action_out = torch.cat([dist.exported_model_output() for dist in continuous_dists + discrete_dists], dim=1) |
|
|
|
return ( |
|
|
|
action_out, |
|
|
|
self.version_number, |
|
|
|
|
|
|
encoding, memories = self.network_body( |
|
|
|
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
dists = self.distribution(encoding, masks) |
|
|
|
continuous_dists, discrete_dists = self.distribution(encoding, masks) |
|
|
|
return dists, value_outputs, memories |
|
|
|
return continuous_dists, discrete_dists, value_outputs, memories |
|
|
|
|
|
|
|
|
|
|
|
class SeparateActorCritic(HybridSimpleActor, ActorCritic): |
|
|
|
|
|
|
else: |
|
|
|
critic_mem = None |
|
|
|
actor_mem = None |
|
|
|
dists, actor_mem_outs = self.get_dists( |
|
|
|
continuous_dists, discrete_dists, actor_mem_outs = self.get_dists( |
|
|
|
vec_inputs, |
|
|
|
vis_inputs, |
|
|
|
memories=actor_mem, |
|
|
|
|
|
|
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1) |
|
|
|
else: |
|
|
|
mem_out = None |
|
|
|
return dists, value_outputs, mem_out |
|
|
|
return continuous_dists, discrete_dists, value_outputs, mem_out |
|
|
|
|
|
|
|
|
|
|
|
################################################################################ |
|
|
|