|
|
|
|
|
|
obs_size = 4 |
|
|
|
seq_len = 16 |
|
|
|
network_settings = NetworkSettings( |
|
|
|
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4) |
|
|
|
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12) |
|
|
|
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3) |
|
|
|
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4) |
|
|
|
for _ in range(100): |
|
|
|
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4)) |
|
|
|
for _ in range(200): |
|
|
|
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12)) |
|
|
|
# Try to force output to 1 |
|
|
|
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape)) |
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
|
# memories isn't always set to None, the network should be able to |
|
|
|
# deal with that. |
|
|
|
# Test critic pass |
|
|
|
value_out = actor.critic_pass([sample_obs], [], memories=memories) |
|
|
|
value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories) |
|
|
|
assert memories_out.shape == memories.shape |
|
|
|
else: |
|
|
|
assert value_out[stream].shape == (1,) |
|
|
|
|
|
|
|