浏览代码

-

/exp-vince
vincentpierre 4 年前
当前提交
a8137478
共有 2 个文件被更改,包括 14 次插入4 次删除
  1. 8
      ml-agents/mlagents/trainers/learn.py
  2. 10
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

8
ml-agents/mlagents/trainers/learn.py


target_frame_rate=engine_settings.target_frame_rate,
capture_frame_rate=engine_settings.capture_frame_rate,
)
# env_manager = SubprocessEnvManager(
# env_factory, engine_config, env_settings.num_envs
# )
env_manager = SubprocessEnvManager(
env_factory, engine_config, env_settings.num_envs
)
env_manager = SimpleEnvManager(env_factory(0, []), env_parameter_manager)
# env_manager = SimpleEnvManager(env_factory(0, []), env_parameter_manager)
trainer_factory = TrainerFactory(
trainer_config=options.behaviors,

10
ml-agents/mlagents/trainers/trainer/rl_trainer.py


from pympler import muppy, summary
import psutil
import os
import torch
import gc
class RLTrainer(Trainer): # pylint: disable=abstract-method

diff = summary.get_diff( self.past_sum, sum1)
summary.print_(diff)
self.past_sum = sum1
tmp = 0
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
tmp+=1
except:
pass
print("Total number of tensors", tmp)

正在加载...
取消
保存