Lots of test fixes

Branch: develop-newnormalization
Ervin Teng, 5 years ago
Commit 27c2a55b
10 files changed, 107 insertions and 114 deletions
  1. ml-agents/mlagents/trainers/ppo/policy.py (3 changes)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (1 change)
  3. ml-agents/mlagents/trainers/rl_trainer.py (38 changes)
  4. ml-agents/mlagents/trainers/sac/trainer.py (27 changes)
  5. ml-agents/mlagents/trainers/tests/mock_brain.py (2 changes)
  6. ml-agents/mlagents/trainers/tests/test_agent_processor.py (4 changes)
  7. ml-agents/mlagents/trainers/tests/test_ppo.py (58 changes)
  8. ml-agents/mlagents/trainers/tests/test_sac.py (38 changes)
  9. ml-agents/mlagents/trainers/tests/test_trainer_controller.py (22 changes)
  10. ml-agents/mlagents/trainers/tests/test_trajectory.py (28 changes)

ml-agents/mlagents/trainers/ppo/policy.py (3 changes)


) -> Dict[str, float]:
"""
Generates value estimates for bootstrapping.
- :param brain_info: BrainInfo to be used for bootstrapping.
- :param idx: Index in BrainInfo of agent.
+ :param experience: BootstrapExperience to be used for bootstrapping.
:param done: Whether or not this is the last element of the episode, in which case the value estimate will be 0.
:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
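For illustration, a minimal sketch of how the returned dictionary might be consumed, assuming a call signature of get_value_estimates(experience, done) as the docstring suggests; the variable names below are placeholders, not the actual PPO trainer code.

    # Sketch only: `policy` and `bootstrap_experience` are placeholders.
    value_estimates = policy.get_value_estimates(bootstrap_experience, done=False)
    # Expected shape of the result, one float per reward signal, e.g. {"extrinsic": 2.0}.
    # When done=True, each estimate is 0.0 so nothing is bootstrapped past the episode end.
    for name, estimate in value_estimates.items():
        print(f"{name} bootstrap value estimate: {estimate}")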

ml-agents/mlagents/trainers/ppo/trainer.py (1 change)


self.update_buffer, training_length=self.policy.sequence_length
)
# If this was a terminal trajectory, append stats and reset reward collection
if trajectory.steps[-1].done:
self.stats["Environment/Episode Length"].append(
self.episode_steps.get(agent_id, 0)

ml-agents/mlagents/trainers/rl_trainer.py (38 changes)


from collections import defaultdict
import numpy as np
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.components.reward_signals import RewardSignalResult

we're not training, this should be called instead of update_policy.
"""
self.update_buffer.reset_agent()
def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
"""
Takes the output of the last action and stores it in the training buffer.
We break this out from add_experiences since it is highly dependent
on the type of trainer.
:param take_action_outputs: The outputs of the Policy's get_action method.
:param agent_id: the Agent we're adding to.
:param agent_idx: the index of the Agent agent_id
"""
raise UnityTrainerException(
"The add_policy_outputs method was not implemented."
)
def add_rewards_outputs(
self,
rewards_out: AllRewardsOutput,
values: Dict[str, np.ndarray],
agent_id: str,
agent_idx: int,
agent_next_idx: int,
) -> None:
"""
Takes the value and evaluated rewards output of the last action and stores them
in the training buffer. We break this out from add_experiences since it is
highly dependent on the type of trainer.
:param rewards_out: AllRewardsOutput of environment and reward-signal rewards after evaluation
:param values: Dict of value estimates keyed by reward signal name
:param agent_id: the Agent we're adding to.
:param agent_idx: the index of the Agent agent_id in the current brain info
:param agent_next_idx: the index of the Agent agent_id in the next brain info
"""
raise UnityTrainerException(
"The add_rewards_outputs method was not implemented."
)
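For reference, here is a hypothetical override of these two hooks in the style of a PPO trainer. The buffer keys follow the code and assertions elsewhere in this commit (environment_rewards as in the SAC trainer below, extrinsic_rewards and extrinsic_value_estimates as asserted in test_ppo.py); this is a sketch of the intended contract, not the actual trainer implementation.

    from mlagents.trainers.rl_trainer import RLTrainer

    # Hypothetical subclass for illustration only; it is never instantiated here.
    class ExamplePPOLikeTrainer(RLTrainer):
        def add_policy_outputs(self, take_action_outputs, agent_id, agent_idx):
            # Store the action this agent actually took.
            actions = take_action_outputs["action"]
            self.processing_buffer[agent_id]["actions"].append(actions[agent_idx])

        def add_rewards_outputs(
            self, rewards_out, values, agent_id, agent_idx, agent_next_idx
        ):
            # Environment rewards are read from the *next* brain info, value
            # estimates from the current one (see the next_idx comment in test_ppo.py).
            self.processing_buffer[agent_id]["environment_rewards"].append(
                rewards_out.environment[agent_next_idx]
            )
            for name, result in rewards_out.reward_signals.items():
                self.processing_buffer[agent_id][name + "_rewards"].append(
                    result.scaled_reward[agent_next_idx]
                )
                self.processing_buffer[agent_id][name + "_value_estimates"].append(
                    values[name][agent_idx][0]
                )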

ml-agents/mlagents/trainers/sac/trainer.py (27 changes)


import numpy as np
from mlagents.envs.action_info import ActionInfoOutputs
- from mlagents.trainers.rl_trainer import RLTrainer, AllRewardsOutput
+ from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import (
Trajectory,
trajectory_to_agentbuffer,

"Experience replay buffer has {} experiences.".format(
self.update_buffer.num_experiences
)
)
def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
"""
Takes the output of the last action and stores it in the training buffer.
"""
actions = take_action_outputs["action"]
self.processing_buffer[agent_id]["actions"].append(actions[agent_idx])
def add_rewards_outputs(
self,
rewards_out: AllRewardsOutput,
values: Dict[str, np.ndarray],
agent_id: str,
agent_idx: int,
agent_next_idx: int,
) -> None:
"""
Takes the rewards output of the last action and stores the environment reward in the training buffer.
"""
self.processing_buffer[agent_id]["environment_rewards"].append(
rewards_out.environment[agent_next_idx]
)
def process_trajectory(self, trajectory: Trajectory) -> None:
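The body of process_trajectory is truncated in this hunk. Based on the trajectory imports above and the assertions in test_sac.py below, the general flow looks roughly like the sketch that follows; the buffer-append helper, the reward-bookkeeping keys, and the step reward field are assumptions, not the actual SACTrainer code.

    from mlagents.trainers.trajectory import Trajectory, trajectory_to_agentbuffer

    # Rough sketch of the trajectory-processing flow; not the upstream implementation.
    def process_trajectory_sketch(trainer, trajectory: Trajectory) -> None:
        agent_id = trajectory.steps[0].agent_id
        # Flatten the list of AgentExperiences into an AgentBuffer.
        agent_buffer = trajectory_to_agentbuffer(trajectory)
        # Append it to the experience replay (update) buffer; the exact AgentBuffer
        # call used upstream is not visible in this diff, so a hypothetical helper is used.
        trainer.append_to_update_buffer(agent_buffer)  # hypothetical helper
        # Accumulate rewards for reporting; the "environment" key is illustrative.
        trainer.collected_rewards["environment"][agent_id] += sum(
            step.reward for step in trajectory.steps  # assumes AgentExperience has a reward field
        )
        # On a terminal step, record episode stats and reset the per-agent totals,
        # which is what the collected_rewards assertions in the tests check.
        if trajectory.steps[-1].done:
            trainer.stats["Environment/Cumulative Reward"].append(
                trainer.collected_rewards["environment"][agent_id]
            )
            trainer.collected_rewards["environment"][agent_id] = 0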

ml-agents/mlagents/trainers/tests/mock_brain.py (2 changes)


)
else:
buffer["action_probs"].append(
- np.ones(buffer[0]["actions"][0].shape, dtype=np.float32)
+ np.ones(buffer["actions"][0].shape, dtype=np.float32)
)
buffer["actions_pre"].append(
np.ones(buffer["actions"][0].shape, dtype=np.float32)

ml-agents/mlagents/trainers/tests/test_agent_processor.py (4 changes)


return mock_policy
- @mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_policy_outputs")
- @mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_rewards_outputs")
- def test_agentprocessor(add_policy_outputs, add_rewards_outputs, num_vis_obs):
+ def test_agentprocessor(num_vis_obs):
policy = create_mock_policy()
trainer = mock.Mock()
processor = AgentProcessor(trainer, policy, time_horizon=5)

ml-agents/mlagents/trainers/tests/test_ppo.py (58 changes)


from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.rl_trainer import AllRewardsOutput
from mlagents.trainers.components.reward_signals import RewardSignalResult
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
@pytest.fixture

trainer.update_policy()
- def test_add_rewards_output(dummy_config):
+ def test_process_trajectory(dummy_config):
brain_params = BrainParameters(
brain_name="test_brain",
vector_observation_space_size=1,

dummy_config["summary_path"] = "./summaries/test_trainer_summary"
dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
trainer = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0", False)
rewardsout = AllRewardsOutput(
reward_signals={
"extrinsic": RewardSignalResult(
scaled_reward=np.array([1.0, 1.0], dtype=np.float32),
unscaled_reward=np.array([1.0, 1.0], dtype=np.float32),
)
},
environment=np.array([1.0, 1.0], dtype=np.float32),
trajectory = make_fake_trajectory(
length=15, max_step_complete=True, vec_obs_size=1, num_vis_obs=0, action_space=2
values = {"extrinsic": np.array([[2.0]], dtype=np.float32)}
agent_id = "123"
idx = 0
# make sure that we're grabbing from the next_idx for rewards. If we're not, the test will fail.
next_idx = 1
trainer.add_rewards_outputs(
rewardsout,
values=values,
agent_id=agent_id,
agent_idx=idx,
agent_next_idx=next_idx,
trainer.process_trajectory(trajectory)
# Check that trainer put trajectory in update buffer
assert trainer.update_buffer.num_experiences == 15
# Check that GAE worked
assert (
"advantages" in trainer.update_buffer
and "discounted_returns" in trainer.update_buffer
)
# Check that the stats are being collected as episode isn't complete
for reward in trainer.collected_rewards.values():
for agent in reward.values():
assert agent > 0
# Add a terminal trajectory
trajectory = make_fake_trajectory(
length=15,
max_step_complete=False,
vec_obs_size=1,
num_vis_obs=0,
action_space=2,
assert trainer.processing_buffer[agent_id]["extrinsic_value_estimates"][0] == 2.0
assert trainer.processing_buffer[agent_id]["extrinsic_rewards"][0] == 1.0
trainer.process_trajectory(trajectory)
# Check that the stats are reset as episode is finished
for reward in trainer.collected_rewards.values():
for agent in reward.values():
assert agent == 0
assert len(trainer.stats["Environment/Cumulative Reward"]) > 0
if __name__ == "__main__":

ml-agents/mlagents/trainers/tests/test_sac.py (38 changes)


from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import make_brain_parameters
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
@pytest.fixture

# Wipe Trainer and try to load
trainer2 = SACTrainer(mock_brain, 1, trainer_params, True, True, 0, 0)
assert trainer2.update_buffer.num_experiences == buffer_len
def test_process_trajectory(dummy_config):
brain_params = make_brain_parameters(
discrete_action=False, visual_inputs=0, vec_obs_size=6
)
dummy_config["summary_path"] = "./summaries/test_trainer_summary"
dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
trajectory = make_fake_trajectory(
length=15, max_step_complete=True, vec_obs_size=6, num_vis_obs=0, action_space=2
)
trainer.process_trajectory(trajectory)
# Check that trainer put trajectory in update buffer
assert trainer.update_buffer.num_experiences == 15
# Check that the stats are being collected as episode isn't complete
for reward in trainer.collected_rewards.values():
for agent in reward.values():
assert agent > 0
# Add a terminal trajectory
trajectory = make_fake_trajectory(
length=15,
max_step_complete=False,
vec_obs_size=6,
num_vis_obs=0,
action_space=2,
)
trainer.process_trajectory(trajectory)
# Check that the stats are reset as episode is finished
for reward in trainer.collected_rewards.values():
for agent in reward.values():
assert agent == 0
assert len(trainer.stats["Environment/Cumulative Reward"]) > 0
if __name__ == "__main__":

ml-agents/mlagents/trainers/tests/test_trainer_controller.py (22 changes)


import yaml
import pytest
- from mlagents.trainers.trainer_controller import TrainerController
+ from mlagents.trainers.trainer_controller import TrainerController, AgentManager
from mlagents.envs.subprocess_env_manager import EnvironmentStep
from mlagents.envs.sampler_class import SamplerManager

trainer_mock.parameters = {"some": "parameter"}
trainer_mock.write_tensorboard_text = MagicMock()
processor_mock = MagicMock()
tc.managers = {"testbrain": AgentManager(processor=processor_mock)}
return tc, trainer_mock

env_mock.reset.return_value = [old_step_info]
tc.advance(env_mock)
- trainer_mock.add_experiences.assert_called_once_with(
+ processor_mock = tc.managers[brain_name].processor
+ processor_mock.add_experiences.assert_called_once_with(
trainer_mock.process_experiences.assert_called_once_with(
new_step_info.previous_all_brain_info[brain_name],
new_step_info.current_all_brain_info[brain_name],
)
trainer_mock.update_policy.assert_called_once()
trainer_mock.increment_step.assert_called_once()

tc.advance(env_mock)
env_mock.reset.assert_not_called()
env_mock.step.assert_called_once()
- trainer_mock.add_experiences.assert_called_once_with(
+ processor_mock = tc.managers[brain_name].processor
+ processor_mock.add_experiences.assert_called_once_with(
)
trainer_mock.process_experiences.assert_called_once_with(
new_step_info.previous_all_brain_info[brain_name],
new_step_info.current_all_brain_info[brain_name],
)
trainer_mock.clear_update_buffer.assert_called_once()
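These tests only rely on AgentManager exposing a processor attribute that receives the experiences for one brain. A minimal stand-in with the same surface, offered as an assumption about its shape rather than the actual class:

    from typing import Any, NamedTuple

    # Minimal stand-in mirroring how the tests above use AgentManager.
    class AgentManagerSketch(NamedTuple):
        processor: Any  # an AgentProcessor in production, a MagicMock in these tests

    managers = {"testbrain": AgentManagerSketch(processor=object())}
    assert managers["testbrain"].processor is not None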

ml-agents/mlagents/trainers/tests/test_trajectory.py (28 changes)


from mlagents.trainers.trajectory import (
AgentExperience,
BootstrapExperience,
Trajectory,
split_obs,
trajectory_to_agentbuffer,

ACTION_SIZE = 4
- def make_fake_trajectory(length: int, max_step_complete: bool = False) -> Trajectory:
+ def make_fake_trajectory(
+ length: int,
+ max_step_complete: bool = False,
+ vec_obs_size: int = VEC_OBS_SIZE,
+ num_vis_obs: int = 1,
+ action_space: int = ACTION_SIZE,
+ ) -> Trajectory:
"""
Makes a fake trajectory of length length. If max_step_complete,
the trajectory is terminated by a max step rather than a done.

- obs = [np.ones((84, 84, 3)), np.ones(VEC_OBS_SIZE)]
+ obs = []
+ for i in range(num_vis_obs):
+ obs.append(np.ones((84, 84, 3)))
+ obs.append(np.ones(vec_obs_size))
- action = np.zeros(ACTION_SIZE)
- action_probs = np.ones(ACTION_SIZE)
- action_pre = np.zeros(ACTION_SIZE)
- action_mask = np.ones(ACTION_SIZE)
- prev_action = np.ones(ACTION_SIZE)
+ action = np.zeros(action_space)
+ action_probs = np.ones(action_space)
+ action_pre = np.zeros(action_space)
+ action_mask = np.ones(action_space)
+ prev_action = np.ones(action_space)
max_step = False
memory = np.ones(10)
agent_id = "test_agent"

agent_id=agent_id,
)
steps_list.append(last_experience)
- bootstrap_step = experience
- return Trajectory(steps=steps_list, bootstrap_step=bootstrap_step)
+ bootstrap_experience = BootstrapExperience(obs=obs, agent_id=agent_id)
+ return Trajectory(steps=steps_list, bootstrap_step=bootstrap_experience)
@pytest.mark.parametrize("num_visual_obs", [0, 1, 2])
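Together with the parametrized test above, the widened signature lets tests build trajectories with arbitrary observation and action shapes. A usage example mirroring the calls in test_ppo.py and test_sac.py (length presumably maps one-to-one to AgentExperience steps, as the num_experiences == 15 assertions suggest):

    # Build a vector-observation-only trajectory, as the PPO and SAC tests above do.
    trajectory = make_fake_trajectory(
        length=15,
        max_step_complete=True,  # terminated by hitting the max step rather than done
        vec_obs_size=1,
        num_vis_obs=0,           # no visual observations
        action_space=2,
    )
    assert len(trajectory.steps) == 15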
