浏览代码

Fix LSTM tests

/develop/add-fire/sac-lst
Ervin Teng 4 年前
当前提交
d56e53bb
共有 1 个文件被更改,包括 6 次插入5 次删除
  1. 11
      ml-agents/mlagents/trainers/tests/torch/test_networks.py

11
ml-agents/mlagents/trainers/tests/torch/test_networks.py


obs_size = 4
seq_len = 16
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
for _ in range(100):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4))
for _ in range(200):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12))
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()

# memories isn't always set to None, the network should be able to
# deal with that.
# Test critic pass
value_out = actor.critic_pass([sample_obs], [], memories=memories)
value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories)
assert memories_out.shape == memories.shape
else:
assert value_out[stream].shape == (1,)

正在加载...
取消
保存