def __init__(self, policy):
    self.policy = policy

    # Dummy inputs used only to trace the policy graph for ONNX export;
    # the batch dimension is fixed to 1 here and marked as dynamic below.
    dummy_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
    # vis_obs_shape may be a tuple, so convert it to a list before concatenating.
    # The guard is an assumed completion of the truncated line: skip the dummy
    # visual observation when the policy has no visual observations.
    dummy_vis_obs = (
        [torch.zeros([1] + list(self.policy.vis_obs_shape))]
        if self.policy.vis_obs_shape
        else []
    )
    dummy_masks = torch.ones([1] + self.policy.actor_critic.act_size)
    dummy_memories = torch.zeros([1] + [self.policy.m_size])

    # The leading "vector_observation"/"visual_observation" entries are assumed
    # names for the two observation inputs built above; the remaining entries
    # appear verbatim in the original.
    self.input_names = ["vector_observation", "visual_observation",
                        "action_mask", "memories"]
    self.dynamic_axes = {"vector_observation": [0], "visual_observation": [0],
                         "action_mask": [0], "memories": [0], "action": [0],
                         "action_probs": [0]}
    self.dummy_input = (dummy_vec_obs, dummy_vis_obs,
                        dummy_masks, dummy_memories)
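    # Illustrative sketch (not taken from this file): the dummy inputs and the
    # name/axis metadata built above are the pieces torch.onnx.export() needs
    # to trace the policy into an .onnx file. The output path, output names,
    # and opset_version below are assumptions for the example only.
    #
    #   torch.onnx.export(
    #       self.policy.actor_critic,                 # module to trace
    #       self.dummy_input,                         # example args defined above
    #       "exported_policy.onnx",                   # hypothetical output path
    #       opset_version=9,                          # assumed opset
    #       input_names=self.input_names,
    #       output_names=["action", "action_probs"],  # assumed output names
    #       dynamic_axes=self.dynamic_axes,           # keep batch dimension dynamic
    #   )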

def export_policy_model(self, output_filepath: str) -> None:
    """