
Some experimental stuff

/develop/jit/experiments
Ervin Teng, 4 years ago
Current commit: fdc887a1
3 files changed, 34 insertions and 1 deletion
  1. ml-agents/mlagents/trainers/learn.py (5 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (11 changes)
  3. ml-agents/mlagents/trainers/torch/utils.py (19 changes)

ml-agents/mlagents/trainers/learn.py (5 changes)


 import os
 import numpy as np
 import json
+import cProfile
 from typing import Callable, Optional, List

 def main():
-    run_cli(parse_command_line())
+    with cProfile.Profile() as pr:
+        run_cli(parse_command_line())
+    pr.dump_stats("pytorch_sac.prof")

 # For python debugger to directly run this script
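The context-manager form of cProfile.Profile() used here requires Python 3.8+. Once a run finishes, the dumped profile can be inspected with the standard-library pstats module; a minimal sketch (the filename pytorch_sac.prof comes from the hunk above):

import pstats

# Load the profile dumped by cProfile and print the 20 entries with the
# highest cumulative time.
stats = pstats.Stats("pytorch_sac.prof")
stats.sort_stats("cumulative").print_stats(20)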

ml-agents/mlagents/trainers/policy/torch_policy.py (11 changes)


         self.m_size = self.actor_critic.memory_size
         self.actor_critic.to(default_device())

+        dummy_vec, dummy_vis, dummy_masks, dummy_mem = ModelUtils.create_dummy_input(
+            self
+        )
+        dist_val_dummy = (dummy_vec, dummy_vis, dummy_masks, dummy_mem)
+        critic_pass_dummy = (dummy_vec, dummy_vis, dummy_mem)
+        # example_in = {"get_dist_and_value": dist_val_dummy, "critic_pass": critic_pass_dummy}
+        self.sample_actions = torch.jit.trace(
+            self.sample_actions,
+            dist_val_dummy,
+        )
+
     @property
     def export_memory_size(self) -> int:
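For context, torch.jit.trace records the tensor operations executed for one set of example inputs and returns a traced callable that replays them; data-dependent Python control flow is not captured. A minimal standalone sketch (not ML-Agents code) of the same pattern the commit applies to self.sample_actions with the dummy inputs:

import torch

def add_and_scale(x, y):
    # Toy function to trace; tracing records these tensor operations.
    return (x + y) * 2.0

example_inputs = (torch.zeros(1, 3), torch.zeros(1, 3))
traced = torch.jit.trace(add_and_scale, example_inputs)
print(traced(torch.ones(1, 3), torch.ones(1, 3)))  # tensor([[4., 4., 4.]])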

ml-agents/mlagents/trainers/torch/utils.py (19 changes)


         return (tensor.T * masks).sum() / torch.clamp(
             (torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
         )

+    @staticmethod
+    def create_dummy_input(policy):
+        batch_dim = [1]
+        seq_len_dim = [1]
+        dummy_vec_obs = [torch.zeros(batch_dim + [policy.vec_obs_size])]
+        # create input shape of NCHW
+        # (It's NHWC in policy.behavior_spec.observation_shapes)
+        dummy_vis_obs = [
+            torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
+            for shape in policy.behavior_spec.observation_shapes
+            if len(shape) == 3
+        ]
+        dummy_masks = torch.ones(batch_dim + [sum(policy.actor_critic.act_size)])
+        dummy_memories = torch.zeros(
+            batch_dim + seq_len_dim + [policy.export_memory_size]
+        )
+        return dummy_vec_obs, torch.Tensor(dummy_vis_obs), dummy_masks, dummy_memories
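The shape[2], shape[0], shape[1] reordering above turns the (H, W, C) shapes stored in behavior_spec.observation_shapes into the NCHW layout PyTorch convolutions expect. An illustration with a hypothetical 84x84 RGB observation (not part of the commit):

import torch

hwc_shape = (84, 84, 3)  # hypothetical (H, W, C) visual observation shape
# Prepend a batch dimension and reorder (H, W, C) -> (C, H, W), as create_dummy_input does.
dummy = torch.zeros([1] + [hwc_shape[2], hwc_shape[0], hwc_shape[1]])
print(dummy.shape)  # torch.Size([1, 3, 84, 84])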