|
|
|
|
|
|
memories=memories, |
|
|
|
seq_len=policy.sequence_length, |
|
|
|
) |
|
|
|
assert log_probs.shape == (64, policy.action_spec.size) |
|
|
|
assert entropy.shape == (64, policy.action_spec.size) |
|
|
|
assert log_probs.shape == (64, policy.behavior_spec.action_spec.size) |
|
|
|
assert entropy.shape == (64, policy.behavior_spec.action_spec.size) |
|
|
|
for val in values.values(): |
|
|
|
assert val.shape == (64,) |
|
|
|
|
|
|
|
|
|
|
all_log_probs=not policy.use_continuous_act, |
|
|
|
) |
|
|
|
if discrete: |
|
|
|
assert log_probs.shape == (64, sum(policy.action_spec.discrete_branches)) |
|
|
|
assert log_probs.shape == ( |
|
|
|
64, |
|
|
|
sum(policy.behavior_spec.action_spec.discrete_branches), |
|
|
|
) |
|
|
|
assert log_probs.shape == (64, policy.action_spec.continuous_size) |
|
|
|
assert entropies.shape == (64, policy.action_spec.size) |
|
|
|
assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size) |
|
|
|
assert entropies.shape == (64, policy.behavior_spec.action_spec.size) |
|
|
|
|
|
|
|
if rnn: |
|
|
|
assert memories.shape == (1, 1, policy.m_size) |