浏览代码

more use of item() and additional tests

/develop/torch-to-np
vincentpierre 4 年前
当前提交
fdd343b2
共有 3 个文件被更改,包括 16 次插入和 17 次删除
  1. 6
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  2. 21
      ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py
  3. 6
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py

6
ml-agents/mlagents/trainers/sac/optimizer_torch.py


self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
update_stats = {
"Losses/Policy Loss": policy_loss.item(),
"Losses/Value Loss": ModelUtils.to_numpy(value_loss),
"Losses/Q1 Loss": ModelUtils.to_numpy(q1_loss),
"Losses/Q2 Loss": ModelUtils.to_numpy(q2_loss),
"Losses/Value Loss": value_loss.item(),
"Losses/Q1 Loss": q1_loss.item(),
"Losses/Q2 Loss": q2_loss.item(),
"Policy/Entropy Coeff": torch.exp(self._log_ent_coef).item(),
"Policy/Learning Rate": decay_lr,
}

21
ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py


import pytest
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
import os
from mlagents.trainers.policy.torch_policy import TorchPolicy

default_num_epoch=3,
)
return bc_module
def assert_stats_are_float(stats):
    """Assert that every value in *stats* is an instance of ``float``.

    Args:
        stats: A mapping (anything exposing ``.items()``) from stat name
            to the reported value.

    Raises:
        AssertionError: If any value is not a ``float`` instance. The
            message names the offending key and the type actually found,
            so a failing test points directly at the bad stat.
    """
    for name, value in stats.items():
        # Report the key and actual type; a bare isinstance failure gives
        # no hint about which entry was wrong.
        assert isinstance(value, float), (
            f"stat {name!r} has type {type(value).__name__}, expected float"
        )
# Test default values

)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
assert_stats_are_float(stats)
# Test with constant pretraining learning rate

)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
assert_stats_are_float(stats)
old_learning_rate = bc_module.current_lr
_ = bc_module.update()

)
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
assert_stats_are_float(stats)
# Test with discrete control and visual observations

)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
assert_stats_are_float(stats)
# Test with discrete control, visual observations and RNN

)
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
assert_stats_are_float(stats)
if __name__ == "__main__":

6
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py


buffer = create_agent_buffer(behavior_spec, 5)
for _ in range(200):
curiosity_rp.update(buffer)
prediction = ModelUtils.to_numpy(curiosity_rp._network.predict_action(buffer)[0])
target = buffer["actions"][0]
error = float(torch.mean((prediction - target) ** 2))
prediction = curiosity_rp._network.predict_action(buffer)[0]
target = torch.tensor(buffer["actions"][0])
error = torch.mean((prediction - target) ** 2).item()
assert error < 0.001

正在加载...
取消
保存