|
|
|
|
|
|
# cause problem to barracuda import. |
|
|
|
self.policy = policy |
|
|
|
batch_dim = [1] |
|
|
|
seq_len_dim = [1] |
|
|
|
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])] |
|
|
|
# create input shape of NCHW |
|
|
|
# (It's NHWC in self.policy.behavior_spec.observation_shapes) |
|
|
|
|
|
|
if len(shape) == 3 |
|
|
|
] |
|
|
|
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)]) |
|
|
|
dummy_memories = torch.zeros( |
|
|
|
batch_dim + seq_len_dim + [self.policy.export_memory_size] |
|
|
|
) |
|
|
|
# Assume sequence length is 1 |
|
|
|
dummy_memories = torch.zeros(batch_dim + [self.policy.export_memory_size]) |
|
|
|
|
|
|
|
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories) |
|
|
|
|
|
|
|