|
|
|
|
|
|
return (tensor.T * masks).sum() / torch.clamp( |
|
|
|
(torch.ones_like(tensor.T) * masks).float().sum(), min=1.0 |
|
|
|
) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def create_dummy_input(policy): |
|
|
|
batch_dim = [1] |
|
|
|
seq_len_dim = [1] |
|
|
|
dummy_vec_obs = [torch.zeros(batch_dim + [policy.vec_obs_size])] |
|
|
|
# create input shape of NCHW |
|
|
|
# (It's NHWC in self.policy.behavior_spec.observation_shapes) |
|
|
|
dummy_vis_obs = [ |
|
|
|
torch.zeros(batch_dim + [shape[2], shape[0], shape[1]]) |
|
|
|
for shape in policy.behavior_spec.observation_shapes |
|
|
|
if len(shape) == 3 |
|
|
|
] |
|
|
|
dummy_masks = torch.ones(batch_dim + [sum(policy.actor_critic.act_size)]) |
|
|
|
dummy_memories = torch.zeros( |
|
|
|
batch_dim + seq_len_dim + [policy.export_memory_size] |
|
|
|
) |
|
|
|
|
|
|
|
return dummy_vec_obs, torch.Tensor(dummy_vis_obs), dummy_masks, dummy_memories |