|
|
|
|
|
|
import os |
|
|
|
import pytest |
|
|
|
from unittest.mock import patch |
|
|
|
|
|
|
|
import pytest |
|
|
|
from unittest.mock import patch |
|
|
|
|
|
|
|
from mlagents_envs.base_env import BehaviorSpec, ActionSpec |
|
|
|
|
|
|
|
import os |
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents.trainers.demonstrations.demonstration_provider import ( |
|
|
|
DemonstrationProvider, |
|
|
|
) |
|
|
|
from mlagents_envs.base_env import BehaviorSpec, ActionSpec |
|
|
|
from mlagents.trainers.settings import GAILSettings, RewardSignalType |
|
|
|
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|
|
|
create_agent_buffer, |
|
|
|
|
|
|
ACTIONSPEC_DISCRETE = ActionSpec.create_discrete((20,)) |
|
|
|
|
|
|
|
|
|
|
|
class MockDemonstrationProvider(DemonstrationProvider):
    """Test double that hands back a pre-built behavior spec and buffer.

    The spec and buffer passed to the constructor are stored as-is and
    returned unchanged by the accessors; nothing is loaded or computed.
    """

    def __init__(self, behavior_spec: BehaviorSpec, buffer: AgentBuffer) -> None:
        """Remember the spec and buffer that the accessors will return.

        :param behavior_spec: Value returned by ``get_behavior_spec``.
        :param buffer: Value returned by ``to_agentbuffer``.
        """
        self._behavior_spec = behavior_spec
        self._buffer = buffer

    def get_behavior_spec(self) -> BehaviorSpec:
        """Return the behavior spec supplied at construction."""
        return self._behavior_spec

    def pop_trajectories(self):
        """Unsupported on this mock.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

    def to_agentbuffer(self, training_length: int) -> AgentBuffer:
        """Return the buffer supplied at construction.

        :param training_length: Accepted for interface compatibility; this
            mock does not use it.
        """
        return self._buffer
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"behavior_spec", |
|
|
|
[BehaviorSpec(create_observation_specs_with_shapes([(8,)]), ACTIONSPEC_CONTINUOUS)], |
|
|
|
|
|
|
], |
|
|
|
) |
|
|
|
@pytest.mark.parametrize("use_actions", [False, True]) |
|
|
|
@patch( |
|
|
|
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" |
|
|
|
) |
|
|
|
@patch.object(GAILRewardProvider, "_get_demonstration_provider") |
|
|
|
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|
|
|
mock_demo_provider: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|
|
|
demo_to_buffer.return_value = None, buffer_expert |
|
|
|
mock_demo_provider.return_value = MockDemonstrationProvider( |
|
|
|
behavior_spec, buffer_expert |
|
|
|
) |
|
|
|
gail_settings = GAILSettings( |
|
|
|
demo_path="", learning_rate=0.005, use_vail=False, use_actions=use_actions |
|
|
|
) |
|
|
|
|
|
|
], |
|
|
|
) |
|
|
|
@pytest.mark.parametrize("use_actions", [False, True]) |
|
|
|
@patch( |
|
|
|
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" |
|
|
|
) |
|
|
|
@patch.object(GAILRewardProvider, "_get_demonstration_provider") |
|
|
|
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|
|
|
mock_demo_provider: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|
|
|
demo_to_buffer.return_value = None, buffer_expert |
|
|
|
mock_demo_provider.return_value = MockDemonstrationProvider( |
|
|
|
behavior_spec, buffer_expert |
|
|
|
) |
|
|
|
gail_settings = GAILSettings( |
|
|
|
demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions |
|
|
|
) |
|
|
|