浏览代码

Remove stuff in policy

/develop/jit/experiments
Ervin Teng 4 年前
当前提交
f59f35ea
共有 1 个文件被更改,包括 10 次插入10 次删除
  1. 20
      ml-agents/mlagents/trainers/policy/torch_policy.py

20
ml-agents/mlagents/trainers/policy/torch_policy.py


self.m_size = self.actor_critic.memory_size
self.actor_critic.to(default_device())
dummy_vec, dummy_vis, dummy_masks, dummy_mem = ModelUtils.create_dummy_input(
self
)
dist_val_dummy = (dummy_vec, dummy_vis, dummy_masks, dummy_mem)
critic_pass_dummy = (dummy_vec, dummy_vis, dummy_mem)
# example_in = {"get_dist_and_value": dist_val_dummy, "critic_pass": critic_pass_dummy}
# dummy_vec, dummy_vis, dummy_masks, dummy_mem = ModelUtils.create_dummy_input(
# self
# )
# dist_val_dummy = (dummy_vec, dummy_vis, dummy_masks, dummy_mem)
# critic_pass_dummy = (dummy_vec, dummy_vis, dummy_mem)
# # example_in = {"get_dist_and_value": dist_val_dummy, "critic_pass": critic_pass_dummy}
self.sample_actions = torch.jit.trace(
self.sample_actions,
dist_val_dummy,
)
# self.sample_actions = torch.jit.trace(
# self.sample_actions,
# dist_val_dummy,
# )
@property
def export_memory_size(self) -> int:

正在加载...
取消
保存