|
|
|
|
|
|
if lstm: |
|
|
|
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) |
|
|
|
memories = torch.ones( |
|
|
|
( |
|
|
|
1, |
|
|
|
network_settings.memory.sequence_length, |
|
|
|
network_settings.memory.memory_size, |
|
|
|
) |
|
|
|
(1, network_settings.memory.sequence_length, actor.memory_size) |
|
|
|
) |
|
|
|
else: |
|
|
|
sample_obs = torch.ones((1, obs_size)) |
|
|
|