|
|
|
|
|
|
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size]) |
|
|
|
|
|
|
|
self.input_names = [] |
|
|
|
self.dynamic_axes = {"action": {0: 'batch'}, "action_probs": {0: 'batch'}} |
|
|
|
self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}} |
|
|
|
self.dynamic_axes.update({"vector_observation": {0: 'batch'}}) |
|
|
|
self.dynamic_axes.update({"vector_observation": {0: "batch"}}) |
|
|
|
self.dynamic_axes.update({"visual_observation": {0: 'batch'}}) |
|
|
|
self.dynamic_axes.update({"visual_observation": {0: "batch"}}) |
|
|
|
self.input_names.append("action_mask") |
|
|
|
self.dynamic_axes.update({"action_mask": {0: 'batch'}}) |
|
|
|
self.input_names.append("action_masks") |
|
|
|
self.dynamic_axes.update({"action_masks": {0: "batch"}}) |
|
|
|
self.dynamic_axes.update({"memories": {0: 'batch'}}) |
|
|
|
self.dynamic_axes.update({"memories": {0: "batch"}}) |
|
|
|
|
|
|
|
self.output_names = [ |
|
|
|
"action", |
|
|
|