浏览代码

Changes for speed test

/develop/add-fire/speedtest
Ervin Teng 5 年前
当前提交
f214836a
共有 4 个文件被更改,包括 40 次插入 和 2 次删除
  1. 1
      config/ppo/3DBall.yaml
  2. 4
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 8
      ml-agents/mlagents/trainers/ppo/trainer.py
  4. 29
      docs/test_env_perf.py

1
config/ppo/3DBall.yaml


summary_freq: 12000
use_recurrent: false
vis_encode_type: simple
threaded: false
reward_signals:
extrinsic:
strength: 1.0

4
ml-agents/mlagents/trainers/policy/torch_policy.py


EPSILON = 1e-7 # Small value to avoid divide by zero
torch.set_num_threads(1)
class TorchPolicy(Policy):
def __init__(

self.inference_dict: Dict[str, tf.Tensor] = {}
self.update_dict: Dict[str, tf.Tensor] = {}
# TF defaults to 32-bit, so we use the same here.
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_tensor_type(torch.FloatTensor)
reward_signal_configs = trainer_params["reward_signals"]
self.stats_name_to_update_name = {

8
ml-agents/mlagents/trainers/ppo/trainer.py


# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347
from collections import defaultdict
import time
import numpy as np
from mlagents.trainers.policy import Policy

from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
logger = get_logger(__name__)
TIMINGS = []
class PPOTrainer(RLTrainer):

self.seed = seed
self.framework = "torch"
self.policy: Policy = None # type: ignore
self.update_times = []
def _check_param_keys(self):
super()._check_param_keys()

buffer = self.update_buffer
max_num_batch = buffer_length // batch_size
for i in range(0, max_num_batch * batch_size, batch_size):
t1 = time.perf_counter()
t2 = time.perf_counter()
TIMINGS.append(t2 - t1)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)

29
docs/test_env_perf.py


from mlagents.trainers.tests.test_simple_rl import (
_check_environment_trains,
PPO_CONFIG,
generate_config,
)
from mlagents.trainers.tests.simple_test_envs import SimpleEnvironment
from mlagents.trainers.ppo.trainer import TIMINGS
import matplotlib.pyplot as plt
import numpy as np
# Name passed (as the only entry) to SimpleEnvironment below.
BRAIN_NAME = "1D"

if __name__ == "__main__":
    # Speed-test driver: train PPO on the simple test environment, then report
    # and plot the durations collected in TIMINGS (imported from the PPO
    # trainer module; presumably one entry per policy update -- confirm there).
    env = SimpleEnvironment([BRAIN_NAME], use_discrete=False)
    config = generate_config(
        PPO_CONFIG,
        override_vals={"batch_size": 256, "max_steps": 20000, "buffer_size": 1024},
    )
    try:
        _check_environment_trains(env, config)
    except Exception as exc:
        # Deliberately best-effort: the timings gathered so far are still
        # reported below. Surface the reason instead of discarding it so a
        # broken run is not mistaken for a valid measurement.
        print(f"Training run ended with an exception: {exc!r}")
    print(f"Mean update time {np.mean(TIMINGS)}")
    plt.plot(TIMINGS)
    plt.ylim((0, 0.006))
    plt.title("PyTorch w/ 3DBall Running, batch size 256, 32 hidden units, 1 layer")
    plt.ylabel("Update Time (s)")
    # Was a second plt.ylabel call, which overwrote the y-axis label above and
    # left the x-axis unlabeled.
    plt.xlabel("Update #")
    plt.show()
正在加载...
取消
保存