import pytest

from mlagents.torch_utils import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils, AgentAction

VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
# Not visible in this excerpt; assumed values consistent with the shape
# assertions below (four discrete branches, a 12-agent batch).
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
NUM_AGENTS = 12
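

# The tests in this module build their policies through a create_policy_mock()
# helper that does not appear in this excerpt. The sketch below reconstructs it
# from the mock_brain helpers it is assumed to use (setup_test_behavior_specs),
# and should be read as illustrative scaffolding rather than the verbatim
# original.
def create_policy_mock(
    dummy_config: TrainerSettings,
    use_rnn: bool = False,
    use_discrete: bool = True,
    use_visual: bool = False,
    seed: int = 0,
) -> TorchPolicy:
    mock_spec = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )
    trainer_settings = dummy_config
    # A memory section in the network settings turns on the recurrent network
    # body, which test_sample_actions checks via memories.shape.
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings() if use_rnn else None
    )
    return TorchPolicy(seed, mock_spec, trainer_settings)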


@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_policy_evaluate(rnn, visual, discrete):
    # Run one evaluate() pass over a batch of decision steps and check the
    # shape of the returned action for each action space.
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy.behavior_spec, num_agents=NUM_AGENTS
    )
    run_out = policy.evaluate(decision_step, list(decision_step.agent_id))
    if discrete:
        assert run_out["action"]["discrete_action"].shape == (
            NUM_AGENTS,
            len(DISCRETE_ACTION_SPACE),
        )
    else:
        assert run_out["action"]["continuous_action"].shape == (
            NUM_AGENTS,
            VECTOR_ACTION_SPACE,
        )


@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_evaluate_actions(rnn, visual, discrete):
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Simulate a 64-step rollout so a full batch of stored actions can be
    # re-evaluated against the current policy.
    buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    # AgentAction.from_dict wraps the buffer's continuous and/or discrete
    # actions in a single AgentAction, replacing manual per-space tensor
    # conversion.
    agent_action = AgentAction.from_dict(buffer)
    vis_obs = []
    for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors):
        vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        vis_obs.append(vis_ob)

    # Batch stored memories at sequence boundaries (reconstructed here from the
    # pattern these tests share; not visible in the excerpt).
    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][i])
        for i in range(0, len(buffer["memory"]), policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=agent_action,
        memories=memories,
        seq_len=policy.sequence_length,
    )
    if discrete:
        _size = policy.behavior_spec.action_spec.discrete_size
    else:
        _size = policy.behavior_spec.action_spec.continuous_size
    # Flattened log-probs carry one column per action dimension.
    assert log_probs.flatten().shape == (64, _size)
    assert entropy.shape == (64,)
    for val in values.values():
        assert val.shape == (64,)


@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_sample_actions(rnn, visual, discrete):
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])

    vis_obs = []
    for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors):
        vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        vis_obs.append(vis_ob)

    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][i])
        for i in range(0, len(buffer["memory"]), policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=policy.sequence_length,
    )
    if discrete:
        # One log-prob column per branch choice across all discrete branches.
        assert log_probs.all_discrete_tensor.shape == (
            64,
            sum(policy.behavior_spec.action_spec.discrete_branches),
        )
    else:
        assert log_probs.continuous_tensor.shape == (
            64,
            policy.behavior_spec.action_spec.continuous_size,
        )
    assert entropies.shape == (64,)

    if rnn:
        # Recurrent policies return updated memories shaped (1, 1, m_size).
        assert memories.shape == (1, 1, policy.m_size)
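

def test_list_to_tensor():
    # Illustrative sketch, not part of the original excerpt: a minimal check of
    # the ModelUtils.list_to_tensor helper used throughout these tests. It is
    # assumed here that the helper accepts nested Python lists (it wraps the
    # input with np.asanyarray and torch.as_tensor) and returns a torch.Tensor.
    unbatched = [[0.0, 1.0], [2.0, 3.0]]
    tensor = ModelUtils.list_to_tensor(unbatched)
    assert tuple(tensor.shape) == (2, 2)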