|
|
|
|
|
|
self.m_size = self.actor_critic.memory_size |
|
|
|
|
|
|
|
self.actor_critic.to(default_device()) |
|
|
|
dummy_vec, dummy_vis, dummy_masks, dummy_mem = ModelUtils.create_dummy_input( |
|
|
|
self |
|
|
|
) |
|
|
|
dist_val_dummy = (dummy_vec, dummy_vis, dummy_masks, dummy_mem) |
|
|
|
critic_pass_dummy = (dummy_vec, dummy_vis, dummy_mem) |
|
|
|
# example_in = {"get_dist_and_value": dist_val_dummy, "critic_pass": critic_pass_dummy} |
|
|
|
# dummy_vec, dummy_vis, dummy_masks, dummy_mem = ModelUtils.create_dummy_input( |
|
|
|
# self |
|
|
|
# ) |
|
|
|
# dist_val_dummy = (dummy_vec, dummy_vis, dummy_masks, dummy_mem) |
|
|
|
# critic_pass_dummy = (dummy_vec, dummy_vis, dummy_mem) |
|
|
|
# # example_in = {"get_dist_and_value": dist_val_dummy, "critic_pass": critic_pass_dummy} |
|
|
|
self.sample_actions = torch.jit.trace( |
|
|
|
self.sample_actions, |
|
|
|
dist_val_dummy, |
|
|
|
) |
|
|
|
# self.sample_actions = torch.jit.trace( |
|
|
|
# self.sample_actions, |
|
|
|
# dist_val_dummy, |
|
|
|
# ) |
|
|
|
|
|
|
|
@property |
|
|
|
def export_memory_size(self) -> int: |
|
|
|