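# Copy the fixture config, then switch recurrent memories on or off for this run.
# sequence_length is the unroll length used to train the recurrent network;
# memory_size is the width of the recurrent state vector.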
trainer_settings = attr.evolve(dummy_config)
trainer_settings.network_settings.memory = (
    NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
    if use_rnn
    else None
)
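
# Build the optimizer under test. create_test_ppo_optimizer and its signature are
# assumed here: a local helper wrapping TorchPolicy / TorchPPOOptimizer construction.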
optimizer = create_test_ppo_optimizer(
    trainer_settings, use_rnn=use_rnn, use_discrete=use_discrete, use_visual=use_visual
)

# Time horizon is longer than sequence length, make sure to test
# process trajectory on multiple sequences in trajectory + some padding
time_horizon = 30
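# With sequence_length=16, 30 steps span two sequences: one full sequence of 16
# and a second of 14 that is padded up to 16.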
trajectory = make_fake_trajectory(
    length=time_horizon,
    observation_specs=optimizer.policy.behavior_spec.observation_specs,
    action_spec=optimizer.policy.behavior_spec.action_spec,
    max_step_complete=True,
)
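# Estimate values along the truncated trajectory (done=False). run_out maps each
# reward-signal name to per-step value estimates; all_memories holds the recurrent
# state at each step and is None when use_rnn is False.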
run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates(
    trajectory.to_agentbuffer(), trajectory.next_obs, done=False
)
for key, val in run_out.items():
    assert type(key) is str
    assert len(val) == time_horizon
if all_memories is not None:
    assert len(all_memories) == time_horizon
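
# Evaluate the same trajectory again, this time marked as ending in a terminal
# state (done=True) rather than being truncated mid-episode.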
run_out, final_value_out, _ = optimizer.get_trajectory_value_estimates(
    trajectory.to_agentbuffer(), trajectory.next_obs, done=True
)