
[cherry-pick] Fix group rewards for POCA, add warning for non-POCA trainers (#5120)

* Fix end episode for POCA, add warning for group reward if not POCA (#5113)

* Fix end episode for POCA, add warning for group reward if not POCA

* Add missing imports

* Use np.any, which is faster
Branch: release_15_branch
GitHub · 3 years ago
Current commit: 63169e2c
6 changed files, with 95 insertions and 12 deletions
  1. ml-agents/mlagents/trainers/poca/trainer.py (9 lines changed)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (3 lines changed)
  3. ml-agents/mlagents/trainers/sac/trainer.py (2 lines changed)
  4. ml-agents/mlagents/trainers/tests/mock_brain.py (15 lines changed)
  5. ml-agents/mlagents/trainers/tests/torch/test_poca.py (62 lines changed)
  6. ml-agents/mlagents/trainers/trainer/rl_trainer.py (16 lines changed)

ml-agents/mlagents/trainers/poca/trainer.py (9 lines changed)


        self._clear_update_buffer()
        return True

+    def end_episode(self) -> None:
+        """
+        A signal that the Episode has ended. The buffer must be reset.
+        Gets called only when the academy resets. For POCA, we should
+        also zero out the group rewards.
+        """
+        super().end_episode()
+        self.collected_group_rewards.clear()

    def create_torch_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
    ) -> TorchPolicy:
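For context, collected_group_rewards tracks a per-agent running sum of group rewards; the accumulation itself is not part of this hunk. An academy reset ends every episode at once without delivering per-agent terminal steps, so without this override the partial sums would leak into the next episode. The snippet below is a stand-alone sketch of that bookkeeping pattern, not the trainer's actual code; process_trajectory and end_episode here are simplified stand-ins.

from collections import defaultdict

import numpy as np

# Stand-in for the POCA trainer's per-agent bookkeeping: agent_id -> running
# sum of group rewards seen so far in the current episode.
collected_group_rewards = defaultdict(float)


def process_trajectory(agent_id, group_rewards):
    # Each processed trajectory adds its group rewards to the agent's total.
    collected_group_rewards[agent_id] += float(np.sum(group_rewards))


def end_episode():
    # An academy reset ends all episodes without per-agent terminal steps,
    # so the partial sums must be dropped to avoid leaking into the next episode.
    collected_group_rewards.clear()


process_trajectory("agent_0", [1.0] * 10)
assert collected_group_rewards["agent_0"] == 10.0
end_episode()
assert len(collected_group_rewards) == 0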

ml-agents/mlagents/trainers/ppo/trainer.py (3 lines changed)


        agent_id = trajectory.agent_id  # All the agents should have the same ID
        agent_buffer_trajectory = trajectory.to_agentbuffer()
+        # Check if we used group rewards, warn if so.
+        self._warn_if_group_reward(agent_buffer_trajectory)
        # Update the normalization
        if self.is_training:
            self.policy.update_normalization(agent_buffer_trajectory)

ml-agents/mlagents/trainers/sac/trainer.py (2 lines changed)


        agent_id = trajectory.agent_id  # All the agents should have the same ID
        agent_buffer_trajectory = trajectory.to_agentbuffer()
+        # Check if we used group rewards, warn if so.
+        self._warn_if_group_reward(agent_buffer_trajectory)
        # Update the normalization
        if self.is_training:

ml-agents/mlagents/trainers/tests/mock_brain.py (15 lines changed)


    max_step_complete: bool = False,
    memory_size: int = 10,
    num_other_agents_in_group: int = 0,
+    group_reward: float = 0.0,
+    is_terminal: bool = True,
) -> Trajectory:
    """
    Makes a fake trajectory of length length. If max_step_complete,

            interrupted=max_step,
            memory=memory,
            group_status=group_status,
-            group_reward=0,
+            group_reward=group_reward,

+    last_group_status = []
+    for _ in range(num_other_agents_in_group):
+        last_group_status.append(
+            AgentStatus(obs, reward, action, not max_step_complete and is_terminal)
+        )

-        done=not max_step_complete,
+        done=not max_step_complete and is_terminal,
        action=action,
        action_probs=action_probs,
        action_mask=action_mask,

-        group_status=group_status,
-        group_reward=0,
+        group_status=last_group_status,
+        group_reward=group_reward,
    )
    steps_list.append(last_experience)
    return Trajectory(
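With the two new parameters, a caller can fabricate a mid-episode trajectory for an agent that belongs to a group. The call below mirrors the one added in test_poca.py further down; the import paths follow the ones used in that test file and are otherwise an assumption about the repository layout.

from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
from mlagents_envs.base_env import ActionSpec

# A 10-step trajectory for an agent that shares a group with two teammates,
# receives a group reward of 1.0 every step, and has not yet reached a
# terminal step (the group episode is still running).
trajectory = mb.make_fake_trajectory(
    length=10,
    observation_specs=create_observation_specs_with_shapes([(1,)]),
    action_spec=ActionSpec.create_discrete((2,)),
    max_step_complete=False,
    num_other_agents_in_group=2,
    group_reward=1.0,
    is_terminal=False,
)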

ml-agents/mlagents/trainers/tests/torch/test_poca.py (62 lines changed)


from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
# Import to avoid circular import
from mlagents.trainers.trainer.trainer_factory import TrainerFactory  # noqa F401
from mlagents.trainers.poca.trainer import POCATrainer
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.policy.torch_policy import TorchPolicy

from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests.dummy_config import (  # noqa: F401
    ppo_dummy_config,
    create_observation_specs_with_shapes,
    poca_dummy_config,
)
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.settings import TrainerSettings
-from mlagents_envs.base_env import ActionSpec
+from mlagents_envs.base_env import ActionSpec, BehaviorSpec
    # poca has the same hyperparameters as ppo for now
-    return ppo_dummy_config()
+    return poca_dummy_config()
VECTOR_ACTION_SPACE = 2

@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
# We need to test this separately from test_reward_signals.py to ensure no interactions
-def test_ppo_optimizer_update_curiosity(
+def test_poca_optimizer_update_curiosity(
    dummy_config, curiosity_dummy_config, rnn, visual, discrete  # noqa: F811
):
    # Test evaluate

# We need to test this separately from test_reward_signals.py to ensure no interactions
-def test_ppo_optimizer_update_gail(gail_dummy_config, dummy_config):  # noqa: F811
+def test_poca_optimizer_update_gail(gail_dummy_config, dummy_config):  # noqa: F811
-    config = ppo_dummy_config()
+    config = poca_dummy_config()
    optimizer = create_test_poca_optimizer(
        config, use_rnn=False, use_discrete=False, use_visual=False
    )

        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )
+def test_poca_end_episode():
+    name_behavior_id = "test_trainer"
+    trainer = POCATrainer(
+        name_behavior_id,
+        10,
+        TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20),
+        True,
+        False,
+        0,
+        "mock_model_path",
+    )
+    behavior_spec = BehaviorSpec(
+        create_observation_specs_with_shapes([(1,)]), ActionSpec.create_discrete((2,))
+    )
+    parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
+    mock_policy = trainer.create_policy(parsed_behavior_id, behavior_spec)
+    trainer.add_policy(parsed_behavior_id, mock_policy)
+    trajectory_queue = AgentManagerQueue("testbrain")
+    policy_queue = AgentManagerQueue("testbrain")
+    trainer.subscribe_trajectory_queue(trajectory_queue)
+    trainer.publish_policy_queue(policy_queue)
+    time_horizon = 10
+    trajectory = mb.make_fake_trajectory(
+        length=time_horizon,
+        observation_specs=behavior_spec.observation_specs,
+        max_step_complete=False,
+        action_spec=behavior_spec.action_spec,
+        num_other_agents_in_group=2,
+        group_reward=1.0,
+        is_terminal=False,
+    )
+    trajectory_queue.put(trajectory)
+    trainer.advance()
+    # Test that some trajectories have been ingested
+    for reward in trainer.collected_group_rewards.values():
+        assert reward == 10
+    # Test end episode
+    trainer.end_episode()
+    assert len(trainer.collected_group_rewards.keys()) == 0

if __name__ == "__main__":
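The expected value in the assertion follows directly from the fake trajectory: ten steps, each carrying a group reward of 1.0, accumulate to 10 for the single agent whose trajectory is processed, and end_episode() then empties the dictionary. To exercise just this test, pytest's keyword filter can target it directly; this is a usage sketch, and the file path assumes the command is run from the repository root.

import pytest

# Run only the new end-of-episode test.
pytest.main(
    ["ml-agents/mlagents/trainers/tests/torch/test_poca.py", "-k", "test_poca_end_episode"]
)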

ml-agents/mlagents/trainers/trainer/rl_trainer.py (16 lines changed)


import abc
import time
import attr
+import numpy as np
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents.trainers.policy.checkpoint_manager import (

from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
from mlagents.trainers.optimizer import Optimizer
-from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,

        self.model_saver = self.create_model_saver(
            self.trainer_settings, self.artifact_path, self.load
        )
+        self._has_warned_group_rewards = False

    def end_episode(self) -> None:
        """

        )
        if step_after_process >= self._next_save_step and self.get_step != 0:
            self._checkpoint()

+    def _warn_if_group_reward(self, buffer: AgentBuffer) -> None:
+        """
+        Warn if the trainer receives a Group Reward but isn't a multi-agent trainer (e.g. POCA).
+        """
+        if not self._has_warned_group_rewards:
+            if np.any(buffer[BufferKey.GROUP_REWARD]):
+                logger.warning(
+                    "An agent received a Group Reward, but you are not using a multi-agent trainer. "
+                    "Please use the POCA trainer for best results."
+                )
+                self._has_warned_group_rewards = True

    def advance(self) -> None:
        """
