POCA trainer (#5005)

Co-authored-by: Ervin Teng <ervin@unity3d.com>
Co-authored-by: Ruo-Ping Dong <ruoping.dong@unity3d.com>
Co-authored-by: Chris Elion <chris.elion@unity3d.com>
Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>

Branch: develop/input-actuator-tanks
Committed via GitHub, 4 years ago
Current commit: 8f35bdd3

28 files changed, with 2222 insertions and 156 deletions.
Changed files (lines changed in parentheses):

- com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs (2)
- ml-agents/mlagents/trainers/buffer.py (6)
- ml-agents/mlagents/trainers/ghost/trainer.py (4)
- ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (28)
- ml-agents/mlagents/trainers/ppo/optimizer_torch.py (62)
- ml-agents/mlagents/trainers/settings.py (11)
- ml-agents/mlagents/trainers/stats.py (6)
- ml-agents/mlagents/trainers/tests/check_env_trains.py (6)
- ml-agents/mlagents/trainers/tests/dummy_config.py (19)
- ml-agents/mlagents/trainers/tests/mock_brain.py (2)
- ml-agents/mlagents/trainers/tests/simple_test_envs.py (206)
- ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (16)
- ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (11)
- ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (2)
- ml-agents/mlagents/trainers/tests/torch/test_networks.py (151)
- ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py (19)
- ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (79)
- ml-agents/mlagents/trainers/torch/agent_action.py (16)
- ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py (30)
- ml-agents/mlagents/trainers/torch/networks.py (320)
- ml-agents/mlagents/trainers/torch/utils.py (93)
- ml-agents/mlagents/trainers/trainer/trainer_factory.py (11)
- ml-agents/mlagents/trainers/trajectory.py (4)
- ml-agents/mlagents/trainers/tests/torch/test_poca.py (290)
- ml-agents/mlagents/trainers/poca/__init__.py (0)
- ml-agents/mlagents/trainers/poca/optimizer_torch.py (674)
- ml-agents/mlagents/trainers/poca/trainer.py (310)
ml-agents/mlagents/trainers/tests/torch/test_poca.py

import pytest

import numpy as np
import attr

from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import copy_buffer_fields
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests.dummy_config import (  # noqa: F401
    ppo_dummy_config,
    curiosity_dummy_config,
    gail_dummy_config,
)

from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil


@pytest.fixture
def dummy_config():
    # poca has the same hyperparameters as ppo for now
    return ppo_dummy_config()


VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 4

CONTINUOUS_ACTION_SPEC = ActionSpec.create_continuous(VECTOR_ACTION_SPACE)
DISCRETE_ACTION_SPEC = ActionSpec.create_discrete(tuple(DISCRETE_ACTION_SPACE))


def create_test_poca_optimizer(dummy_config, use_rnn, use_discrete, use_visual):
    mock_specs = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    trainer_settings = attr.evolve(dummy_config)
    trainer_settings.reward_signals = {
        RewardSignalType.EXTRINSIC: RewardSignalSettings(strength=1.0, gamma=0.99)
    }

    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
        if use_rnn
        else None
    )
    policy = TorchPolicy(0, mock_specs, trainer_settings, "test", False)
    optimizer = TorchPOCAOptimizer(policy, trainer_settings)
    return optimizer


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_poca_optimizer_update(dummy_config, rnn, visual, discrete):
    # Test evaluate
    optimizer = create_test_poca_optimizer(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Test update
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES,
        optimizer.policy.behavior_spec,
        memory_size=optimizer.policy.m_size,
        num_other_agents_in_group=NUM_AGENTS,
    )
    # Mock out reward signal eval
    copy_buffer_fields(
        update_buffer,
        BufferKey.ENVIRONMENT_REWARDS,
        [
            BufferKey.ADVANTAGES,
            RewardSignalUtil.returns_key("extrinsic"),
            RewardSignalUtil.value_estimates_key("extrinsic"),
            RewardSignalUtil.baseline_estimates_key("extrinsic"),
        ],
    )
    # Copy memories to critic memories
    copy_buffer_fields(
        update_buffer,
        BufferKey.MEMORY,
        [BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY],
    )

    return_stats = optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )
    # Make sure we have the right stats
    required_stats = [
        "Losses/Policy Loss",
        "Losses/Value Loss",
        "Policy/Learning Rate",
        "Policy/Epsilon",
        "Policy/Beta",
    ]
    for stat in required_stats:
        assert stat in return_stats.keys()


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_poca_get_value_estimates(dummy_config, rnn, visual, discrete):
    optimizer = create_test_poca_optimizer(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    time_horizon = 15
    trajectory = make_fake_trajectory(
        length=time_horizon,
        observation_specs=optimizer.policy.behavior_spec.observation_specs,
        action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC,
        max_step_complete=True,
        num_other_agents_in_group=NUM_AGENTS,
    )
    (
        value_estimates,
        baseline_estimates,
        value_next,
        value_memories,
        baseline_memories,
    ) = optimizer.get_trajectory_and_baseline_value_estimates(
        trajectory.to_agentbuffer(),
        trajectory.next_obs,
        trajectory.next_group_obs,
        done=False,
    )
    for key, val in value_estimates.items():
        assert type(key) is str
        assert len(val) == 15
    for key, val in baseline_estimates.items():
        assert type(key) is str
        assert len(val) == 15

    if value_memories is not None:
        assert len(value_memories) == 15
        assert len(baseline_memories) == 15

    (
        value_estimates,
        baseline_estimates,
        value_next,
        value_memories,
        baseline_memories,
    ) = optimizer.get_trajectory_and_baseline_value_estimates(
        trajectory.to_agentbuffer(),
        trajectory.next_obs,
        trajectory.next_group_obs,
        done=True,
    )
    for key, val in value_next.items():
        assert type(key) is str
        assert val == 0.0

    # Check if we ignore terminal states properly
    optimizer.reward_signals["extrinsic"].use_terminal_states = False
    (
        value_estimates,
        baseline_estimates,
        value_next,
        value_memories,
        baseline_memories,
    ) = optimizer.get_trajectory_and_baseline_value_estimates(
        trajectory.to_agentbuffer(),
        trajectory.next_obs,
        trajectory.next_group_obs,
        done=False,
    )
    for key, val in value_next.items():
        assert type(key) is str
        assert val != 0.0


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
# We need to test this separately from test_reward_signals.py to ensure no interactions
def test_ppo_optimizer_update_curiosity(
    dummy_config, curiosity_dummy_config, rnn, visual, discrete  # noqa: F811
):
    # Test evaluate
    dummy_config.reward_signals = curiosity_dummy_config
    optimizer = create_test_poca_optimizer(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Test update
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES,
        optimizer.policy.behavior_spec,
        memory_size=optimizer.policy.m_size,
    )
    # Mock out reward signal eval
    copy_buffer_fields(
        update_buffer,
        src_key=BufferKey.ENVIRONMENT_REWARDS,
        dst_keys=[
            BufferKey.ADVANTAGES,
            RewardSignalUtil.returns_key("extrinsic"),
            RewardSignalUtil.value_estimates_key("extrinsic"),
            RewardSignalUtil.baseline_estimates_key("extrinsic"),
            RewardSignalUtil.returns_key("curiosity"),
            RewardSignalUtil.value_estimates_key("curiosity"),
            RewardSignalUtil.baseline_estimates_key("curiosity"),
        ],
    )
    # Copy memories to critic memories
    copy_buffer_fields(
        update_buffer,
        BufferKey.MEMORY,
        [BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY],
    )

    optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )


# We need to test this separately from test_reward_signals.py to ensure no interactions
def test_ppo_optimizer_update_gail(gail_dummy_config, dummy_config):  # noqa: F811
    # Test evaluate
    dummy_config.reward_signals = gail_dummy_config
    config = ppo_dummy_config()
    optimizer = create_test_poca_optimizer(
        config, use_rnn=False, use_discrete=False, use_visual=False
    )
    # Test update
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec
    )
    # Mock out reward signal eval
    copy_buffer_fields(
        update_buffer,
        src_key=BufferKey.ENVIRONMENT_REWARDS,
        dst_keys=[
            BufferKey.ADVANTAGES,
            RewardSignalUtil.returns_key("extrinsic"),
            RewardSignalUtil.value_estimates_key("extrinsic"),
            RewardSignalUtil.baseline_estimates_key("extrinsic"),
            RewardSignalUtil.returns_key("gail"),
            RewardSignalUtil.value_estimates_key("gail"),
            RewardSignalUtil.baseline_estimates_key("gail"),
        ],
    )

    update_buffer[BufferKey.CONTINUOUS_LOG_PROBS] = np.ones_like(
        update_buffer[BufferKey.CONTINUOUS_ACTION]
    )
    optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )

    # Check if buffer size is too big
    update_buffer = mb.simulate_rollout(3000, optimizer.policy.behavior_spec)
    # Mock out reward signal eval
    copy_buffer_fields(
        update_buffer,
        src_key=BufferKey.ENVIRONMENT_REWARDS,
        dst_keys=[
            BufferKey.ADVANTAGES,
            RewardSignalUtil.returns_key("extrinsic"),
            RewardSignalUtil.value_estimates_key("extrinsic"),
            RewardSignalUtil.baseline_estimates_key("extrinsic"),
            RewardSignalUtil.returns_key("gail"),
            RewardSignalUtil.value_estimates_key("gail"),
            RewardSignalUtil.baseline_estimates_key("gail"),
        ],
    )
    optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )


if __name__ == "__main__":
    pytest.main()
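A quick way to read the tests above: everything `TorchPOCAOptimizer.update` needs is staged into the `AgentBuffer` first (advantages, per-signal returns, value and baseline estimates, and, when recurrent, critic/baseline memories). The sketch below is not part of the commit; it simply replays that setup outside of pytest, assuming the helpers defined in the test module above are importable, to show the minimal call sequence one might use when experimenting with the optimizer.

# Illustrative sketch only (assumes the helpers from test_poca.py above are importable).
from mlagents.trainers.tests.torch.test_poca import (
    create_test_poca_optimizer,
    BUFFER_INIT_SAMPLES,
    NUM_AGENTS,
)
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import copy_buffer_fields
from mlagents.trainers.tests.dummy_config import ppo_dummy_config
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil

# Build a POCA optimizer around a vector-observation, continuous-action policy.
optimizer = create_test_poca_optimizer(
    ppo_dummy_config(), use_rnn=False, use_discrete=False, use_visual=False
)

# Simulate a rollout with groupmates, then mock out the reward-signal bookkeeping
# exactly as test_poca_optimizer_update does.
update_buffer = mb.simulate_rollout(
    BUFFER_INIT_SAMPLES,
    optimizer.policy.behavior_spec,
    memory_size=optimizer.policy.m_size,
    num_other_agents_in_group=NUM_AGENTS,
)
copy_buffer_fields(
    update_buffer,
    BufferKey.ENVIRONMENT_REWARDS,
    [
        BufferKey.ADVANTAGES,
        RewardSignalUtil.returns_key("extrinsic"),
        RewardSignalUtil.value_estimates_key("extrinsic"),
        RewardSignalUtil.baseline_estimates_key("extrinsic"),
    ],
)

stats = optimizer.update(
    update_buffer,
    num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
)
print(stats["Losses/Policy Loss"], stats["Losses/Baseline Loss"])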
ml-agents/mlagents/trainers/poca/optimizer_torch.py

from typing import Dict, cast, List, Tuple, Optional
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import (
    ExtrinsicRewardProvider,
)
import numpy as np
import math
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import (
    AgentBuffer,
    BufferKey,
    RewardSignalUtil,
    AgentBufferField,
)

from mlagents_envs.timers import timed
from mlagents_envs.base_env import ObservationSpec, ActionSpec
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import (
    RewardSignalSettings,
    RewardSignalType,
    TrainerSettings,
    POCASettings,
)
from mlagents.trainers.torch.networks import Critic, MultiAgentNetworkBody
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil, GroupObsUtil
from mlagents.trainers.settings import NetworkSettings

from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)


class TorchPOCAOptimizer(TorchOptimizer):
    class POCAValueNetwork(torch.nn.Module, Critic):
        """
        The POCAValueNetwork uses the MultiAgentNetworkBody to compute the value
        and the POCA baseline for a variable number of agents in a group that all
        share the same observation and action space.
        """

        def __init__(
            self,
            stream_names: List[str],
            observation_specs: List[ObservationSpec],
            network_settings: NetworkSettings,
            action_spec: ActionSpec,
        ):
            torch.nn.Module.__init__(self)
            self.network_body = MultiAgentNetworkBody(
                observation_specs, network_settings, action_spec
            )
            if network_settings.memory is not None:
                encoding_size = network_settings.memory.memory_size // 2
            else:
                encoding_size = network_settings.hidden_units

            self.value_heads = ValueHeads(stream_names, encoding_size, 1)

        @property
        def memory_size(self) -> int:
            return self.network_body.memory_size

        def update_normalization(self, buffer: AgentBuffer) -> None:
            self.network_body.update_normalization(buffer)

        def baseline(
            self,
            obs_without_actions: List[torch.Tensor],
            obs_with_actions: Tuple[List[List[torch.Tensor]], List[AgentAction]],
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
        ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
            """
            The POCA baseline marginalizes the action of the agent associated with obs_without_actions.
            It calls the forward pass of the MultiAgentNetworkBody with the state-action
            pairs of the groupmates but only the state of the agent in question.
            :param obs_without_actions: The obs of the agent for which to compute the baseline.
            :param obs_with_actions: Tuple of observations and actions for all groupmates.
            :param memories: If using memory, a Tensor of initial memories.
            :param sequence_length: If using memory, the sequence length.

            :return: A Tuple of a Dict of reward stream to tensor and critic memories.
            """
            (obs, actions) = obs_with_actions
            encoding, memories = self.network_body(
                obs_only=[obs_without_actions],
                obs=obs,
                actions=actions,
                memories=memories,
                sequence_length=sequence_length,
            )
            value_outputs, critic_mem_out = self.forward(
                encoding, memories, sequence_length
            )
            return value_outputs, critic_mem_out

        def critic_pass(
            self,
            obs: List[List[torch.Tensor]],
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
        ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
            """
            A centralized value function. It calls the forward pass of MultiAgentNetworkBody
            with just the states of all agents.
            :param obs: List of observations for all agents in the group.
            :param memories: If using memory, a Tensor of initial memories.
            :param sequence_length: If using memory, the sequence length.
            :return: A Tuple of a Dict of reward stream to tensor and critic memories.
            """
            encoding, memories = self.network_body(
                obs_only=obs,
                obs=[],
                actions=[],
                memories=memories,
                sequence_length=sequence_length,
            )
            value_outputs, critic_mem_out = self.forward(
                encoding, memories, sequence_length
            )
            return value_outputs, critic_mem_out

        def forward(
            self,
            encoding: torch.Tensor,
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            output = self.value_heads(encoding)
            return output, memories

    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        """
        Takes a Policy and TrainerSettings and creates an Optimizer around the policy.
        :param policy: A TorchPolicy object that will be updated by this POCA Optimizer.
        :param trainer_settings: Trainer settings that specify the properties of the trainer.
        """
        super().__init__(policy, trainer_settings)
        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

        self._critic = TorchPOCAOptimizer.POCAValueNetwork(
            reward_signal_names,
            policy.behavior_spec.observation_specs,
            network_settings=trainer_settings.network_settings,
            action_spec=policy.behavior_spec.action_spec,
        )

        params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
        self.hyperparameters: POCASettings = cast(
            POCASettings, trainer_settings.hyperparameters
        )
        self.decay_learning_rate = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.learning_rate,
            1e-10,
            self.trainer_settings.max_steps,
        )
        self.decay_epsilon = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.epsilon,
            0.1,
            self.trainer_settings.max_steps,
        )
        self.decay_beta = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.beta,
            1e-5,
            self.trainer_settings.max_steps,
        )

        self.optimizer = torch.optim.Adam(
            params, lr=self.trainer_settings.hyperparameters.learning_rate
        )
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }

        self.stream_names = list(self.reward_signals.keys())
        self.value_memory_dict: Dict[str, torch.Tensor] = {}
        self.baseline_memory_dict: Dict[str, torch.Tensor] = {}

    def create_reward_signals(
        self, reward_signal_configs: Dict[RewardSignalType, RewardSignalSettings]
    ) -> None:
        """
        Create reward signals. Override the default to warn that Curiosity and
        GAIL are unsupported, and to make sure Extrinsic adds groupmate rewards.
        :param reward_signal_configs: Reward signal config.
        """
        for reward_signal in reward_signal_configs.keys():
            if reward_signal != RewardSignalType.EXTRINSIC:
                logger.warning(
                    f"Reward signal {reward_signal.value.capitalize()} is not supported with the POCA trainer; "
                    "results may be unexpected."
                )
        super().create_reward_signals(reward_signal_configs)
        # Make sure we add the groupmate rewards in POCA, so agents learn how to help each
        # other achieve individual rewards as well
        for reward_provider in self.reward_signals.values():
            if isinstance(reward_provider, ExtrinsicRewardProvider):
                reward_provider.add_groupmate_rewards = True

    @property
    def critic(self):
        return self._critic

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs an update on the model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        # Get decayed parameters
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
        returns = {}
        old_values = {}
        old_baseline_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)]
            )
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)]
            )
            old_baseline_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.baseline_estimates_key(name)]
            )

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
        groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)
        groupmate_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
            for _groupmate_obs in groupmate_obs
        ]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)
        groupmate_actions = AgentAction.group_from_buffer(batch)

        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)
        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
            )
        ]

        baseline_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.BASELINE_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.BASELINE_MEMORY]), self.policy.sequence_length
            )
        ]

        if len(value_memories) > 0:
            value_memories = torch.stack(value_memories).unsqueeze(0)
            baseline_memories = torch.stack(baseline_memories).unsqueeze(0)

        log_probs, entropy = self.policy.evaluate_actions(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        all_obs = [current_obs] + groupmate_obs
        values, _ = self.critic.critic_pass(
            all_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
        baselines, _ = self.critic.baseline(
            current_obs,
            groupmate_obs_and_actions,
            memories=baseline_memories,
            sequence_length=self.policy.sequence_length,
        )
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)

        baseline_loss = ModelUtils.trust_region_value_loss(
            baselines, old_baseline_values, returns, decay_eps, loss_masks
        )
        value_loss = ModelUtils.trust_region_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        policy_loss = ModelUtils.trust_region_policy_loss(
            ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
            log_probs,
            old_log_probs,
            loss_masks,
            decay_eps,
        )
        loss = (
            policy_loss
            + 0.5 * (value_loss + 0.5 * baseline_loss)
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Losses/Baseline Loss": baseline_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats

    def get_modules(self):
        modules = {"Optimizer:adam": self.optimizer, "Optimizer:critic": self._critic}
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules

    def _evaluate_by_sequence_team(
        self,
        self_obs: List[torch.Tensor],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        init_value_mem: torch.Tensor,
        init_baseline_mem: torch.Tensor,
    ) -> Tuple[
        Dict[str, torch.Tensor],
        Dict[str, torch.Tensor],
        AgentBufferField,
        AgentBufferField,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
        intermediate memories for the critic.
        :param self_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
            observations for this trajectory.
        :param obs: The observations of the agent's groupmates.
        :param actions: The actions of the agent's groupmates.
        :param init_value_mem: The memory that precedes this trajectory, of shape (1, 1, <mem_size>), i.e.
            what is returned as the output of a MemoryModule.
        :param init_baseline_mem: Same as init_value_mem, but for the baseline.
        :return: A Tuple of the value estimates as a Dict of [name, tensor], the baseline estimates as a
            Dict, AgentBufferFields of the initial value and baseline memories to be used during the
            update, and the final value and baseline memories at the end of the trajectory.
        """
        num_experiences = self_obs[0].shape[0]
        all_next_value_mem = AgentBufferField()
        all_next_baseline_mem = AgentBufferField()
        # In the buffer, the first sequence is the one that is padded. So if seq_len = 3 and the
        # trajectory is of length 10, the first sequence is [pad, pad, obs].
        # Compute the number of elements in this padded sequence.
        leftover = num_experiences % self.policy.sequence_length

        # Compute values for the potentially truncated initial sequence
        first_seq_len = leftover if leftover > 0 else self.policy.sequence_length

        self_seq_obs = []
        groupmate_seq_obs = []
        groupmate_seq_act = []
        seq_obs = []
        for _self_obs in self_obs:
            first_seq_obs = _self_obs[0:first_seq_len]
            seq_obs.append(first_seq_obs)
        self_seq_obs.append(seq_obs)

        for groupmate_obs, groupmate_action in zip(obs, actions):
            seq_obs = []
            for _obs in groupmate_obs:
                first_seq_obs = _obs[0:first_seq_len]
                seq_obs.append(first_seq_obs)
            groupmate_seq_obs.append(seq_obs)
            _act = groupmate_action.slice(0, first_seq_len)
            groupmate_seq_act.append(_act)

        # For the first sequence, the initial memory should be the one at the
        # beginning of this trajectory.
        for _ in range(first_seq_len):
            all_next_value_mem.append(ModelUtils.to_numpy(init_value_mem.squeeze()))
            all_next_baseline_mem.append(
                ModelUtils.to_numpy(init_baseline_mem.squeeze())
            )

        all_seq_obs = self_seq_obs + groupmate_seq_obs
        init_values, _value_mem = self.critic.critic_pass(
            all_seq_obs, init_value_mem, sequence_length=first_seq_len
        )
        all_values = {
            signal_name: [init_values[signal_name]]
            for signal_name in init_values.keys()
        }

        groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
        init_baseline, _baseline_mem = self.critic.baseline(
            self_seq_obs[0],
            groupmate_obs_and_actions,
            init_baseline_mem,
            sequence_length=first_seq_len,
        )
        all_baseline = {
            signal_name: [init_baseline[signal_name]]
            for signal_name in init_baseline.keys()
        }

        # Evaluate the remaining sequences, carrying over _mem after each sequence
        for seq_num in range(
            1, math.ceil((num_experiences) / (self.policy.sequence_length))
        ):
            for _ in range(self.policy.sequence_length):
                all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
                all_next_baseline_mem.append(
                    ModelUtils.to_numpy(_baseline_mem.squeeze())
                )

            start = seq_num * self.policy.sequence_length - (
                self.policy.sequence_length - leftover
            )
            end = (seq_num + 1) * self.policy.sequence_length - (
                self.policy.sequence_length - leftover
            )

            self_seq_obs = []
            groupmate_seq_obs = []
            groupmate_seq_act = []
            seq_obs = []
            for _self_obs in self_obs:
                seq_obs.append(_self_obs[start:end])
            self_seq_obs.append(seq_obs)

            for groupmate_obs, team_action in zip(obs, actions):
                seq_obs = []
                for _obs in groupmate_obs:
                    seq_obs.append(_obs[start:end])
                groupmate_seq_obs.append(seq_obs)
                _act = team_action.slice(start, end)
                groupmate_seq_act.append(_act)

            all_seq_obs = self_seq_obs + groupmate_seq_obs
            values, _value_mem = self.critic.critic_pass(
                all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length
            )
            for signal_name, _val in values.items():
                all_values[signal_name].append(_val)

            groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
            baselines, _baseline_mem = self.critic.baseline(
                self_seq_obs[0],
                groupmate_obs_and_actions,
                _baseline_mem,
                sequence_length=self.policy.sequence_length,
            )
            for signal_name, _val in baselines.items():
                all_baseline[signal_name].append(_val)
        # Create one tensor per reward signal
        all_value_tensors = {
            signal_name: torch.cat(value_list, dim=0)
            for signal_name, value_list in all_values.items()
        }
        all_baseline_tensors = {
            signal_name: torch.cat(baseline_list, dim=0)
            for signal_name, baseline_list in all_baseline.items()
        }
        next_value_mem = _value_mem
        next_baseline_mem = _baseline_mem
        return (
            all_value_tensors,
            all_baseline_tensors,
            all_next_value_mem,
            all_next_baseline_mem,
            next_value_mem,
            next_baseline_mem,
        )

    def get_trajectory_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        done: bool,
        agent_id: str = "",
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
        """
        Override the base class method. Unused in the trainer, but needed to make sure the class
        hierarchy is maintained. Assume that there are no group obs.
        """
        (
            value_estimates,
            _,
            next_value_estimates,
            all_next_value_mem,
            _,
        ) = self.get_trajectory_and_baseline_value_estimates(
            batch, next_obs, [], done, agent_id
        )

        return value_estimates, next_value_estimates, all_next_value_mem

    def get_trajectory_and_baseline_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        next_groupmate_obs: List[List[np.ndarray]],
        done: bool,
        agent_id: str = "",
    ) -> Tuple[
        Dict[str, np.ndarray],
        Dict[str, np.ndarray],
        Dict[str, float],
        Optional[AgentBufferField],
        Optional[AgentBufferField],
    ]:
        """
        Get value estimates, baseline estimates, and memories for a trajectory, in batch form.
        :param batch: An AgentBuffer that consists of a trajectory.
        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
            if this is not a terminal trajectory.
        :param next_groupmate_obs: the next observations from other members of the group.
        :param done: Set true if this is a terminal trajectory.
        :param agent_id: Agent ID of the agent that this trajectory belongs to.
        :returns: A Tuple of the value estimates as a Dict of [name, np.ndarray(trajectory_len)],
            the baseline estimates as a Dict, the final value estimate as a Dict of [name, float], and
            optionally (if using memories) AgentBufferFields of initial critic and baseline memories to
            be used during the update.
        """
        n_obs = len(self.policy.behavior_spec.observation_specs)

        current_obs = ObsUtil.from_buffer(batch, n_obs)
        groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)

        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
        groupmate_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
            for _groupmate_obs in groupmate_obs
        ]

        groupmate_actions = AgentAction.group_from_buffer(batch)

        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
        next_obs = [obs.unsqueeze(0) for obs in next_obs]

        next_groupmate_obs = [
            ModelUtils.list_to_tensor_list(_list_obs)
            for _list_obs in next_groupmate_obs
        ]
        # Expand dimensions of next critic obs
        next_groupmate_obs = [
            [_obs.unsqueeze(0) for _obs in _list_obs]
            for _list_obs in next_groupmate_obs
        ]

        if agent_id in self.value_memory_dict:
            # The agent_id should always be in both since they are added together
            _init_value_mem = self.value_memory_dict[agent_id]
            _init_baseline_mem = self.baseline_memory_dict[agent_id]
        else:
            _init_value_mem = (
                torch.zeros((1, 1, self.critic.memory_size))
                if self.policy.use_recurrent
                else None
            )
            _init_baseline_mem = (
                torch.zeros((1, 1, self.critic.memory_size))
                if self.policy.use_recurrent
                else None
            )

        all_obs = (
            [current_obs] + groupmate_obs
            if groupmate_obs is not None
            else [current_obs]
        )
        all_next_value_mem: Optional[AgentBufferField] = None
        all_next_baseline_mem: Optional[AgentBufferField] = None
        with torch.no_grad():
            if self.policy.use_recurrent:
                (
                    value_estimates,
                    baseline_estimates,
                    all_next_value_mem,
                    all_next_baseline_mem,
                    next_value_mem,
                    next_baseline_mem,
                ) = self._evaluate_by_sequence_team(
                    current_obs,
                    groupmate_obs,
                    groupmate_actions,
                    _init_value_mem,
                    _init_baseline_mem,
                )
            else:
                value_estimates, next_value_mem = self.critic.critic_pass(
                    all_obs, _init_value_mem, sequence_length=batch.num_experiences
                )
                groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
                baseline_estimates, next_baseline_mem = self.critic.baseline(
                    current_obs,
                    groupmate_obs_and_actions,
                    _init_baseline_mem,
                    sequence_length=batch.num_experiences,
                )
            # Store the memory for the next trajectory
            self.value_memory_dict[agent_id] = next_value_mem
            self.baseline_memory_dict[agent_id] = next_baseline_mem

            all_next_obs = (
                [next_obs] + next_groupmate_obs
                if next_groupmate_obs is not None
                else [next_obs]
            )

            next_value_estimates, _ = self.critic.critic_pass(
                all_next_obs, next_value_mem, sequence_length=1
            )

        for name, estimate in baseline_estimates.items():
            baseline_estimates[name] = ModelUtils.to_numpy(estimate)

        for name, estimate in value_estimates.items():
            value_estimates[name] = ModelUtils.to_numpy(estimate)

        # The baseline and V should not be conditioned on the same done flag.
        for name, estimate in next_value_estimates.items():
            next_value_estimates[name] = ModelUtils.to_numpy(estimate)

        if done:
            for k in next_value_estimates:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimates[k][-1] = 0.0

        return (
            value_estimates,
            baseline_estimates,
            next_value_estimates,
            all_next_value_mem,
            all_next_baseline_mem,
        )
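The index arithmetic in `_evaluate_by_sequence_team` is easier to follow with concrete numbers. The sketch below is not part of the commit; it is a standalone, plain-Python reproduction of the `leftover` / `first_seq_len` / `start` / `end` computation from the method above, for sequence_length = 3 and a 10-step trajectory, matching the padding comment in the code.

# Standalone sketch (no ML-Agents imports): how _evaluate_by_sequence_team carves a
# trajectory into sequences when the first sequence is the padded one. With
# sequence_length = 3 and a trajectory of length 10, leftover = 10 % 3 = 1, so the
# first sequence holds 1 real step and the remaining sequences are full length.
import math

sequence_length = 3
num_experiences = 10

leftover = num_experiences % sequence_length  # 1
first_seq_len = leftover if leftover > 0 else sequence_length  # 1

bounds = [(0, first_seq_len)]
for seq_num in range(1, math.ceil(num_experiences / sequence_length)):
    start = seq_num * sequence_length - (sequence_length - leftover)
    end = (seq_num + 1) * sequence_length - (sequence_length - leftover)
    bounds.append((start, end))

print(bounds)  # [(0, 1), (1, 4), (4, 7), (7, 10)]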
ml-agents/mlagents/trainers/poca/trainer.py

# # Unity ML-Agents Toolkit
# ## ML-Agents Learning (POCA)
# Contains an implementation of MA-POCA.

from collections import defaultdict
from typing import cast, Dict

import numpy as np

from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy import Policy
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, POCASettings

logger = get_logger(__name__)


class POCATrainer(RLTrainer):
    """The POCATrainer is an implementation of the MA-POCA algorithm."""

    def __init__(
        self,
        behavior_name: str,
        reward_buff_cap: int,
        trainer_settings: TrainerSettings,
        training: bool,
        load: bool,
        seed: int,
        artifact_path: str,
    ):
        """
        Responsible for collecting experiences and training the POCA model.
        :param behavior_name: The name of the behavior associated with the trainer config.
        :param reward_buff_cap: Max reward history to track in the reward buffer.
        :param trainer_settings: The parameters for the trainer.
        :param training: Whether the trainer is set for training.
        :param load: Whether the model should be loaded.
        :param seed: The seed the model will be initialized with.
        :param artifact_path: The directory within which to store artifacts from this trainer.
        """
        super().__init__(
            behavior_name,
            trainer_settings,
            training,
            load,
            artifact_path,
            reward_buff_cap,
        )
        self.hyperparameters: POCASettings = cast(
            POCASettings, self.trainer_settings.hyperparameters
        )
        self.seed = seed
        self.policy: TorchPolicy = None  # type: ignore
        self.collected_group_rewards: Dict[str, int] = defaultdict(lambda: 0)

    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update buffer.
        Processing involves calculating value and advantage targets for the model updating step.
        :param trajectory: The Trajectory tuple containing the steps to be processed.
        """
        super()._process_trajectory(trajectory)
        agent_id = trajectory.agent_id  # All the agents should have the same ID

        agent_buffer_trajectory = trajectory.to_agentbuffer()
        # Update the normalization
        if self.is_training:
            self.policy.update_normalization(agent_buffer_trajectory)

        # Get all value estimates
        (
            value_estimates,
            baseline_estimates,
            value_next,
            value_memories,
            baseline_memories,
        ) = self.optimizer.get_trajectory_and_baseline_value_estimates(
            agent_buffer_trajectory,
            trajectory.next_obs,
            trajectory.next_group_obs,
            trajectory.all_group_dones_reached
            and trajectory.done_reached
            and not trajectory.interrupted,
        )

        if value_memories is not None and baseline_memories is not None:
            agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)
            agent_buffer_trajectory[BufferKey.BASELINE_MEMORY].set(baseline_memories)

        for name, v in value_estimates.items():
            agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend(
                v
            )
            agent_buffer_trajectory[
                RewardSignalUtil.baseline_estimates_key(name)
            ].extend(baseline_estimates[name])
            self._stats_reporter.add_stat(
                f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Baseline Estimate",
                np.mean(baseline_estimates[name]),
            )
            self._stats_reporter.add_stat(
                f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate",
                np.mean(value_estimates[name]),
            )

        self.collected_rewards["environment"][agent_id] += np.sum(
            agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS]
        )
        self.collected_group_rewards[agent_id] += np.sum(
            agent_buffer_trajectory[BufferKey.GROUP_REWARD]
        )
        for name, reward_signal in self.optimizer.reward_signals.items():
            evaluate_result = (
                reward_signal.evaluate(agent_buffer_trajectory) * reward_signal.strength
            )
            agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend(
                evaluate_result
            )
            # Report the reward signals
            self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

        # Compute lambda returns and advantage
        tmp_advantages = []
        for name in self.optimizer.reward_signals:
            local_rewards = np.array(
                agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].get_batch(),
                dtype=np.float32,
            )

            baseline_estimate = agent_buffer_trajectory[
                RewardSignalUtil.baseline_estimates_key(name)
            ].get_batch()
            v_estimates = agent_buffer_trajectory[
                RewardSignalUtil.value_estimates_key(name)
            ].get_batch()

            lambd_returns = lambda_return(
                r=local_rewards,
                value_estimates=v_estimates,
                gamma=self.optimizer.reward_signals[name].gamma,
                lambd=self.hyperparameters.lambd,
                value_next=value_next[name],
            )

            local_advantage = np.array(lambd_returns) - np.array(baseline_estimate)

            agent_buffer_trajectory[RewardSignalUtil.returns_key(name)].set(
                lambd_returns
            )
            agent_buffer_trajectory[RewardSignalUtil.advantage_key(name)].set(
                local_advantage
            )
            tmp_advantages.append(local_advantage)

        # Get global advantages
        global_advantages = list(
            np.mean(np.array(tmp_advantages, dtype=np.float32), axis=0)
        )
        agent_buffer_trajectory[BufferKey.ADVANTAGES].set(global_advantages)

        # Append to update buffer
        agent_buffer_trajectory.resequence_and_append(
            self.update_buffer, training_length=self.policy.sequence_length
        )

        # If this was a terminal trajectory, append stats and reset reward collection
        if trajectory.done_reached:
            self._update_end_episode_stats(agent_id, self.optimizer)
            # Remove dead agents from group reward recording
            if not trajectory.all_group_dones_reached:
                self.collected_group_rewards.pop(agent_id)

        # If the whole team is done, report the remaining group rewards.
        if trajectory.all_group_dones_reached and trajectory.done_reached:
            self.stats_reporter.add_stat(
                "Environment/Group Cumulative Reward",
                self.collected_group_rewards.get(agent_id, 0),
                aggregation=StatsAggregationMethod.HISTOGRAM,
            )
            self.collected_group_rewards.pop(agent_id)

    def _is_ready_update(self):
        """
        Returns whether or not the trainer has enough elements to run a model update.
        :return: A boolean corresponding to whether or not update_model() can be run.
        """
        size_of_buffer = self.update_buffer.num_experiences
        return size_of_buffer > self.hyperparameters.buffer_size

    def _update_policy(self):
        """
        Uses the update buffer to update the policy.
        The reward signal generators must be updated in this method at their own pace.
        """
        buffer_length = self.update_buffer.num_experiences
        self.cumulative_returns_since_policy_update.clear()

        # Make sure batch_size is a multiple of sequence length. During training, we
        # will need to reshape the data into a batch_size x sequence_length tensor.
        batch_size = (
            self.hyperparameters.batch_size
            - self.hyperparameters.batch_size % self.policy.sequence_length
        )
        # Make sure there is at least one sequence
        batch_size = max(batch_size, self.policy.sequence_length)

        n_sequences = max(
            int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
        )

        advantages = np.array(
            self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32
        )
        self.update_buffer[BufferKey.ADVANTAGES].set(
            (advantages - advantages.mean()) / (advantages.std() + 1e-10)
        )
        num_epoch = self.hyperparameters.num_epoch
        batch_update_stats = defaultdict(list)
        for _ in range(num_epoch):
            self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
            buffer = self.update_buffer
            max_num_batch = buffer_length // batch_size
            for i in range(0, max_num_batch * batch_size, batch_size):
                update_stats = self.optimizer.update(
                    buffer.make_mini_batch(i, i + batch_size), n_sequences
                )
                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))

        if self.optimizer.bc_module:
            update_stats = self.optimizer.bc_module.update()
            for stat, val in update_stats.items():
                self._stats_reporter.add_stat(stat, val)
        self._clear_update_buffer()
        return True

    def create_torch_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
    ) -> TorchPolicy:
        """
        Creates a policy with a PyTorch backend and POCA hyperparameters.
        :param parsed_behavior_id: Behavior identifiers for the policy.
        :param behavior_spec: Specifications for policy construction.
        :return: The created policy.
        """
        policy = TorchPolicy(
            self.seed,
            behavior_spec,
            self.trainer_settings,
            condition_sigma_on_obs=False,  # Faster training for POCA
            separate_critic=True,  # Match network architecture with TF
        )
        return policy

    def create_poca_optimizer(self) -> TorchPOCAOptimizer:
        return TorchPOCAOptimizer(self.policy, self.trainer_settings)

    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
    ) -> None:
        """
        Adds policy to trainer.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param policy: Policy to associate with name_behavior_id.
        """
        if not isinstance(policy, TorchPolicy):
            raise RuntimeError(f"policy {policy} must be an instance of TorchPolicy.")
        self.policy = policy
        self.policies[parsed_behavior_id.behavior_id] = policy
        self.optimizer = self.create_poca_optimizer()
        for _reward_signal in self.optimizer.reward_signals.keys():
            self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)

        self.model_saver.register(self.policy)
        self.model_saver.register(self.optimizer)
        self.model_saver.initialize_or_load()

        # Needed to resume loads properly
        self.step = policy.get_current_step()

    def get_policy(self, name_behavior_id: str) -> Policy:
        """
        Gets the policy from the trainer associated with name_behavior_id.
        :param name_behavior_id: full identifier of the policy
        """
        return self.policy


def lambda_return(r, value_estimates, gamma=0.99, lambd=0.8, value_next=0.0):
    returns = np.zeros_like(r)
    returns[-1] = r[-1] + gamma * value_next
    for t in reversed(range(0, r.size - 1)):
        returns[t] = (
            gamma * lambd * returns[t + 1]
            + r[t]
            + (1 - lambd) * gamma * value_estimates[t + 1]
        )
    return returns
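The `lambda_return` helper above computes the TD(λ)-style return used as the target for the value and baseline heads. As a sanity check (not part of the commit, and assuming `lambda_return` as defined above is in scope), here is the recursion unrolled on a three-step toy trajectory:

# Worked example for lambda_return (illustrative only).
import numpy as np

r = np.array([1.0, 0.0, 1.0], dtype=np.float32)
v = np.array([0.5, 0.6, 0.7], dtype=np.float32)  # value_estimates
gamma, lambd, value_next = 0.9, 0.8, 0.2

# returns[2] = r[2] + gamma * value_next
#            = 1.0 + 0.9 * 0.2                               = 1.18
# returns[1] = gamma * lambd * returns[2] + r[1] + (1 - lambd) * gamma * v[2]
#            = 0.72 * 1.18 + 0.0 + 0.18 * 0.7                = 0.9756
# returns[0] = gamma * lambd * returns[1] + r[0] + (1 - lambd) * gamma * v[1]
#            = 0.72 * 0.9756 + 1.0 + 0.18 * 0.6              = 1.810432
print(lambda_return(r, v, gamma=gamma, lambd=lambd, value_next=value_next))
# -> approximately [1.810432, 0.9756, 1.18]

Note that `_process_trajectory` subtracts the baseline estimate, not the value estimate, from this return to form each agent's local advantage, and then averages the per-signal advantages into `BufferKey.ADVANTAGES` for the policy update.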