浏览代码

Fix network tests

/develop/add-fire/memoryclass
Ervin Teng 4 年前
当前提交
b107a8d5
共有 1 个文件被更改,包括 1 次插入5 次删除
  1. 6
      ml-agents/mlagents/trainers/tests/torch/test_networks.py

6
ml-agents/mlagents/trainers/tests/torch/test_networks.py


if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(
(
1,
network_settings.memory.sequence_length,
network_settings.memory.memory_size,
)
(1, network_settings.memory.sequence_length, actor.memory_size)
)
else:
sample_obs = torch.ones((1, obs_size))

正在加载...
取消
保存