
POCA trainer (#5005)

Co-authored-by: Ervin Teng <ervin@unity3d.com>
Co-authored-by: Ruo-Ping Dong <ruoping.dong@unity3d.com>
Co-authored-by: Chris Elion <chris.elion@unity3d.com>
Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
/develop/input-actuator-tanks
GitHub 4 years ago
Current commit 8f35bdd3
28 changed files with 2,222 additions and 156 deletions
  1. com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs (2)
  2. ml-agents/mlagents/trainers/buffer.py (6)
  3. ml-agents/mlagents/trainers/ghost/trainer.py (4)
  4. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (28)
  5. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (62)
  6. ml-agents/mlagents/trainers/settings.py (11)
  7. ml-agents/mlagents/trainers/stats.py (6)
  8. ml-agents/mlagents/trainers/tests/check_env_trains.py (6)
  9. ml-agents/mlagents/trainers/tests/dummy_config.py (19)
  10. ml-agents/mlagents/trainers/tests/mock_brain.py (2)
  11. ml-agents/mlagents/trainers/tests/simple_test_envs.py (206)
  12. ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (16)
  13. ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (11)
  14. ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (2)
  15. ml-agents/mlagents/trainers/tests/torch/test_networks.py (151)
  16. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py (19)
  17. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (79)
  18. ml-agents/mlagents/trainers/torch/agent_action.py (16)
  19. ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py (30)
  20. ml-agents/mlagents/trainers/torch/networks.py (320)
  21. ml-agents/mlagents/trainers/torch/utils.py (93)
  22. ml-agents/mlagents/trainers/trainer/trainer_factory.py (11)
  23. ml-agents/mlagents/trainers/trajectory.py (4)
  24. ml-agents/mlagents/trainers/tests/torch/test_poca.py (290)
  25. ml-agents/mlagents/trainers/poca/__init__.py (0)
  26. ml-agents/mlagents/trainers/poca/optimizer_torch.py (674)
  27. ml-agents/mlagents/trainers/poca/trainer.py (310)

2
com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs


/// <summary>
/// A basic class implementation of MultiAgentGroup.
/// </summary>
internal class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable
public class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable
{
readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId();
HashSet<Agent> m_Agents = new HashSet<Agent>();

6
ml-agents/mlagents/trainers/buffer.py


MASKS = "masks"
MEMORY = "memory"
CRITIC_MEMORY = "critic_memory"
BASELINE_MEMORY = "poca_baseline_memory"
PREV_ACTION = "prev_action"
ADVANTAGES = "advantages"

VALUE_ESTIMATES = "value_estimates"
RETURNS = "returns"
ADVANTAGE = "advantage"
BASELINES = "baselines"
AgentBufferKey = Union[

@staticmethod
def advantage_key(name: str) -> AgentBufferKey:
return RewardSignalKeyPrefix.ADVANTAGE, name
@staticmethod
def baseline_estimates_key(name: str) -> AgentBufferKey:
return RewardSignalKeyPrefix.BASELINES, name
class AgentBufferField(list):
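The new prefix keys give each reward stream its own baseline and advantage slot in the buffer. A minimal sketch of how the tuple keys compose, assuming (as the tests later in this diff do for baseline_estimates_key) that both helpers live on RewardSignalUtil and that AgentBuffer fields auto-create on first access:

# Sketch: per-reward-stream buffer keys are (prefix, name) tuples.
from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil

buf = AgentBuffer()
buf[BufferKey.ADVANTAGES].append(0.5)                                  # flat, stream-agnostic key
buf[RewardSignalUtil.baseline_estimates_key("extrinsic")].append(0.1)  # (RewardSignalKeyPrefix.BASELINES, "extrinsic")
buf[RewardSignalUtil.advantage_key("extrinsic")].append(0.4)           # (RewardSignalKeyPrefix.ADVANTAGE, "extrinsic")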

4
ml-agents/mlagents/trainers/ghost/trainer.py


"""
if trajectory.done_reached:
# Assumption is that final reward is >0/0/<0 for win/draw/loss
final_reward = trajectory.steps[-1].reward
final_reward = (
trajectory.steps[-1].reward + trajectory.steps[-1].group_reward
)
result = 0.5
if final_reward > 0:
result = 1.0
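With the group reward folded into final_reward, the ELO result for the trajectory still follows the >0/0/<0 win/draw/loss convention noted above. A small sketch of that mapping with illustrative reward values:

# result used for the ELO update: 1.0 win, 0.5 draw, 0.0 loss
final_reward = 0.8 + 1.0  # last step's individual reward + group reward (example values)
result = 0.5
if final_reward > 0:
    result = 1.0
elif final_reward < 0:
    result = 0.0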

28
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.settings import (
TrainerSettings,
RewardSignalSettings,
RewardSignalType,
)
from mlagents.trainers.torch.utils import ModelUtils

def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
pass
def create_reward_signals(self, reward_signal_configs):
def create_reward_signals(
self, reward_signal_configs: Dict[RewardSignalType, RewardSignalSettings]
) -> None:
"""
Create reward signals
:param reward_signal_configs: Reward signal config.

)
def _evaluate_by_sequence(
self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
self, tensor_obs: List[torch.Tensor], initial_memory: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
"""
Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the

# Compute values for the potentially truncated initial sequence
seq_obs = []
first_seq_len = self.policy.sequence_length
first_seq_len = leftover if leftover > 0 else self.policy.sequence_length
if leftover > 0:
first_seq_len = leftover
first_seq_obs = _obs[0:first_seq_len]
seq_obs.append(first_seq_obs)

seq_obs = []
for _ in range(self.policy.sequence_length):
all_next_memories.append(ModelUtils.to_numpy(_mem.squeeze()))
start = seq_num * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
end = (seq_num + 1) * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
start = seq_num * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
end = (seq_num + 1) * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
seq_obs.append(_obs[start:end])
values, _mem = self.critic.critic_pass(
seq_obs, _mem, sequence_length=self.policy.sequence_length
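The leftover arithmetic above is easiest to see with concrete numbers. The sketch below is a worked example of the slicing formulas in the hunk (illustrative lengths, simplified loop bounds, assumes a non-zero leftover):

sequence_length = 4
num_experiences = 15
leftover = num_experiences % sequence_length                   # 3
first_seq_len = leftover if leftover > 0 else sequence_length
chunks = [(0, first_seq_len)]                                  # truncated first chunk: obs[0:3]
for seq_num in range(1, num_experiences // sequence_length + 1):
    start = seq_num * sequence_length - (sequence_length - leftover)
    end = (seq_num + 1) * sequence_length - (sequence_length - leftover)
    chunks.append((start, end))
print(chunks)  # [(0, 3), (3, 7), (7, 11), (11, 15)]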

62
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


def critic(self):
return self._critic
def ppo_value_loss(
self,
values: Dict[str, torch.Tensor],
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluates value loss for PPO.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():
old_val_tensor = old_values[name]
returns_tensor = returns[name]
clipped_value_estimate = old_val_tensor + torch.clamp(
head - old_val_tensor, -1 * epsilon, epsilon
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss
def ppo_policy_loss(
self,
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluate PPO policy loss.
:param advantages: Computed advantages.
:param log_probs: Current policy probabilities
:param old_log_probs: Past policy probabilities
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)
decay_epsilon = self.hyperparameters.epsilon
r_theta = torch.exp(log_probs - old_log_probs)
p_opt_a = r_theta * advantage
p_opt_b = (
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b), loss_masks
)
return policy_loss
@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""

old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
value_loss = self.ppo_value_loss(
value_loss = ModelUtils.trust_region_value_loss(
policy_loss = self.ppo_policy_loss(
policy_loss = ModelUtils.trust_region_policy_loss(
decay_eps,
)
loss = (
policy_loss

11
ml-agents/mlagents/trainers/settings.py


return self.steps_per_update
# POCA uses the same hyperparameters as PPO
POCASettings = PPOSettings
# INTRINSIC REWARD SIGNALS #############################################################
class RewardSignalType(Enum):
EXTRINSIC: str = "extrinsic"

class TrainerType(Enum):
PPO: str = "ppo"
SAC: str = "sac"
POCA: str = "poca"
_mapping = {TrainerType.PPO: PPOSettings, TrainerType.SAC: SACSettings}
_mapping = {
TrainerType.PPO: PPOSettings,
TrainerType.SAC: SACSettings,
TrainerType.POCA: POCASettings,
}
return _mapping[self]
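Because POCASettings is only an alias, a "poca" run is configured with exactly the PPO hyperparameter fields. A quick sketch; the to_settings name is assumed to be the enum helper whose body (return _mapping[self]) appears above:

from mlagents.trainers.settings import TrainerType, PPOSettings, POCASettings

assert POCASettings is PPOSettings                      # same hyperparameter dataclass
assert TrainerType("poca") is TrainerType.POCA          # the yaml value "poca" selects the new trainer
assert TrainerType.POCA.to_settings() is POCASettings   # assumed helper name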

6
ml-agents/mlagents/trainers/stats.py


log_info.append(f"Rank: {self.rank}")
log_info.append(f"Mean Reward: {stats_summary.mean:0.3f}")
log_info.append(f"Std of Reward: {stats_summary.std:0.3f}")
if "Environment/Group Cumulative Reward" in values:
group_stats_summary = values["Environment/Group Cumulative Reward"]
log_info.append(f"Mean Group Reward: {group_stats_summary.mean:0.3f}")
else:
log_info.append(f"Std of Reward: {stats_summary.std:0.3f}")
log_info.append(is_training)
if self.self_play and "Self-play/ELO" in values:

6
ml-agents/mlagents/trainers/tests/check_env_trains.py


self, category: str, values: Dict[str, StatsSummary], step: int
) -> None:
for val, stats_summary in values.items():
if val == "Environment/Cumulative Reward":
if (
val == "Environment/Cumulative Reward"
or val == "Environment/Group Cumulative Reward"
):
print(step, val, stats_summary.aggregated_value)
self._last_reward_summary[category] = stats_summary.aggregated_value

19
ml-agents/mlagents/trainers/tests/dummy_config.py


import copy
import os
from mlagents.trainers.settings import (
POCASettings,
TrainerSettings,
PPOSettings,
SACSettings,

threaded=False,
)
_POCA_CONFIG = TrainerSettings(
trainer_type=TrainerType.POCA,
hyperparameters=POCASettings(
learning_rate=5.0e-3,
learning_rate_schedule=ScheduleType.CONSTANT,
batch_size=16,
buffer_size=64,
),
network_settings=NetworkSettings(num_layers=1, hidden_units=32),
summary_freq=500,
max_steps=3000,
threaded=False,
)
def ppo_dummy_config():
return copy.deepcopy(_PPO_CONFIG)

return copy.deepcopy(_SAC_CONFIG)
def poca_dummy_config():
return copy.deepcopy(_POCA_CONFIG)
@pytest.fixture

2
ml-agents/mlagents/trainers/tests/mock_brain.py


behavior_spec: BehaviorSpec,
memory_size: int = 10,
exclude_key_list: List[str] = None,
num_other_agents_in_group: int = 0,
) -> AgentBuffer:
trajectory = make_fake_trajectory(
length,

num_other_agents_in_group=num_other_agents_in_group,
)
buffer = trajectory.to_agentbuffer()
# If a key_list was given, remove those keys

206
ml-agents/mlagents/trainers/tests/simple_test_envs.py


return (decision_step, terminal_step)
class MultiAgentEnvironment(BaseEnv):
"""
The MultiAgentEnvironment maintains a set of SimpleEnvironments, one per agent.
When sending DecisionSteps and TerminalSteps to the trainers, it first batches the
decision steps from the individual environments. When setting actions, it indexes the
batched ActionTuple to obtain the ActionTuple for each individual agent.
"""
def __init__(
self,
brain_names,
step_size=STEP_SIZE,
num_visual=0,
num_vector=1,
num_var_len=0,
vis_obs_size=VIS_OBS_SIZE,
vec_obs_size=OBS_SIZE,
var_len_obs_size=VAR_LEN_SIZE,
action_sizes=(1, 0),
num_agents=2,
):
super().__init__()
self.envs = {}
self.dones = {}
self.just_died = set()
self.names = brain_names
self.final_rewards: Dict[str, List[float]] = {}
for name in brain_names:
self.final_rewards[name] = []
for i in range(num_agents):
name_and_num = name + str(i)
self.envs[name_and_num] = SimpleEnvironment(
[name],
step_size,
num_visual,
num_vector,
num_var_len,
vis_obs_size,
vec_obs_size,
var_len_obs_size,
action_sizes,
)
self.dones[name_and_num] = False
self.envs[name_and_num].reset()
# All envs have the same behavior spec, so just get the last one.
self.behavior_spec = self.envs[name_and_num].behavior_spec
self.action_spec = self.envs[name_and_num].action_spec
self.num_agents = num_agents
@property
def all_done(self):
return all(self.dones.values())
@property
def behavior_specs(self):
behavior_dict = {}
for n in self.names:
behavior_dict[n] = self.behavior_spec
return BehaviorMapping(behavior_dict)
def set_action_for_agent(self, behavior_name, agent_id, action):
pass
def set_actions(self, behavior_name, action):
# The ActionTuple contains the actions for all n_agents. This
# slices the ActionTuple into an action tuple for each environment
# and sets it. The index j is used to ignore agents that have already
# reached done.
j = 0
for i in range(self.num_agents):
_act = ActionTuple()
name_and_num = behavior_name + str(i)
env = self.envs[name_and_num]
if not self.dones[name_and_num]:
if self.action_spec.continuous_size > 0:
_act.add_continuous(action.continuous[j : j + 1])
if self.action_spec.discrete_size > 0:
_disc_list = [action.discrete[j, :]]
_act.add_discrete(np.array(_disc_list))
j += 1
env.action[behavior_name] = _act
def get_steps(self, behavior_name):
# This gets the individual DecisionSteps and TerminalSteps
# from the envs and merges them into a batch to be sent
# to the AgentProcessor.
dec_vec_obs = []
dec_reward = []
dec_group_reward = []
dec_agent_id = []
dec_group_id = []
ter_vec_obs = []
ter_reward = []
ter_group_reward = []
ter_agent_id = []
ter_group_id = []
interrupted = []
action_mask = None
terminal_step = TerminalSteps.empty(self.behavior_spec)
decision_step = None
for i in range(self.num_agents):
name_and_num = behavior_name + str(i)
env = self.envs[name_and_num]
_dec, _term = env.step_result[behavior_name]
if not self.dones[name_and_num]:
dec_agent_id.append(i)
dec_group_id.append(1)
if len(dec_vec_obs) > 0:
for j, obs in enumerate(_dec.obs):
dec_vec_obs[j] = np.concatenate((dec_vec_obs[j], obs), axis=0)
else:
for obs in _dec.obs:
dec_vec_obs.append(obs)
dec_reward.append(_dec.reward[0])
dec_group_reward.append(_dec.group_reward[0])
if _dec.action_mask is not None:
if action_mask is None:
action_mask = []
if len(action_mask) > 0:
action_mask[0] = np.concatenate(
(action_mask[0], _dec.action_mask[0]), axis=0
)
else:
action_mask.append(_dec.action_mask[0])
if len(_term.reward) > 0 and name_and_num in self.just_died:
ter_agent_id.append(i)
ter_group_id.append(1)
if len(ter_vec_obs) > 0:
for j, obs in enumerate(_term.obs):
ter_vec_obs[j] = np.concatenate((ter_vec_obs[j], obs), axis=0)
else:
for obs in _term.obs:
ter_vec_obs.append(obs)
ter_reward.append(_term.reward[0])
ter_group_reward.append(_term.group_reward[0])
interrupted.append(False)
self.just_died.remove(name_and_num)
decision_step = DecisionSteps(
dec_vec_obs,
dec_reward,
dec_agent_id,
action_mask,
dec_group_id,
dec_group_reward,
)
terminal_step = TerminalSteps(
ter_vec_obs,
ter_reward,
interrupted,
ter_agent_id,
ter_group_id,
ter_group_reward,
)
return (decision_step, terminal_step)
def step(self) -> None:
# Steps all environments and calls reset if all agents are done.
for name in self.names:
for i in range(self.num_agents):
name_and_num = name + str(i)
# Does not step the env if done
if not self.dones[name_and_num]:
env = self.envs[name_and_num]
# Reproducing part of env step to intercept Dones
assert all(action is not None for action in env.action.values())
done = env._take_action(name)
reward = env._compute_reward(name, done)
self.dones[name_and_num] = done
if done:
self.just_died.add(name_and_num)
if self.all_done:
env.step_result[name] = env._make_batched_step(
name, done, 0.0, reward
)
self.final_rewards[name].append(reward)
self.reset()
elif done:
# This agent has finished but others are still running.
# This gives a reward of the time penalty if this agent
# is successful and the negative env reward if it fails.
ceil_reward = min(-TIME_PENALTY, reward)
env.step_result[name] = env._make_batched_step(
name, done, ceil_reward, 0.0
)
self.final_rewards[name].append(reward)
else:
env.step_result[name] = env._make_batched_step(
name, done, reward, 0.0
)
def reset(self) -> None: # type: ignore
for name in self.names:
for i in range(self.num_agents):
name_and_num = name + str(i)
self.dones[name_and_num] = False
@property
def reset_parameters(self) -> Dict[str, str]:
return {}
def close(self):
pass
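A rough usage sketch of the environment above, in the style of the trainer tests further down in this diff (the brain name and loop length are illustrative only):

env = MultiAgentEnvironment(["test_brain"], action_sizes=(1, 0), num_agents=2)
for _ in range(10):
    decision_steps, _terminal_steps = env.get_steps("test_brain")
    # One batched ActionTuple covers every agent in the group that is still active.
    actions = env.action_spec.random_action(len(decision_steps))
    env.set_actions("test_brain", actions)
    env.step()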
class RecordEnvironment(SimpleEnvironment):
def __init__(
self,

16
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py


from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.settings import TrainerSettings, PPOSettings, SACSettings
from mlagents.trainers.settings import (
TrainerSettings,
PPOSettings,
SACSettings,
POCASettings,
)
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.torch.test_policy import create_policy_mock
from mlagents.trainers.torch.utils import ModelUtils

@pytest.mark.parametrize(
"optimizer",
[(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
ids=["ppo", "sac"],
[
(TorchPPOOptimizer, PPOSettings),
(TorchSACOptimizer, SACSettings),
(TorchPOCAOptimizer, POCASettings),
],
ids=["ppo", "sac", "poca"],
)
def test_load_save_optimizer(tmp_path, optimizer):
OptimizerClass, HyperparametersClass = optimizer

11
ml-agents/mlagents/trainers/tests/torch/test_agent_action.py


assert (agent_1_act.discrete_tensor[3:] == 0).all()
def test_slice():
# Both continuous and discrete
aa = AgentAction(
torch.tensor([[1.0], [1.0], [1.0]]),
[torch.tensor([2, 1, 0]), torch.tensor([1, 2, 0])],
)
saa = aa.slice(0, 2)
assert saa.continuous_tensor.shape == (2, 1)
assert saa.discrete_tensor.shape == (2, 2)
def test_to_flat():
# Both continuous and discrete
aa = AgentAction(

2
ml-agents/mlagents/trainers/tests/torch/test_hybrid.py


buffer_init_steps=0,
)
config = attr.evolve(
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=6000
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=4000
)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)

151
ml-agents/mlagents/trainers/tests/torch/test_networks.py


import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.agent_action import AgentAction
MultiAgentNetworkBody,
ValueNetwork,
SimpleActor,
SharedActorCritic,

def test_networkbody_lstm():
torch.manual_seed(0)
obs_size = 4
seq_len = 16
seq_len = 6
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
)

create_observation_specs_with_shapes(obs_shapes), network_settings
)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
sample_obs = torch.ones((1, seq_len, obs_size))
sample_obs = torch.ones((seq_len, obs_size))
for _ in range(200):
encoded, _ = networkbody([sample_obs], memories=torch.ones(1, seq_len, 12))
for _ in range(300):
encoded, _ = networkbody(
[sample_obs], memories=torch.ones(1, 1, 12), sequence_length=seq_len
)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()

assert _enc == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("with_actions", [True, False], ids=["actions", "no_actions"])
def test_multinetworkbody_vector(with_actions):
torch.manual_seed(0)
obs_size = 4
act_size = 2
n_agents = 3
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
networkbody = MultiAgentNetworkBody(
create_observation_specs_with_shapes(obs_shapes), network_settings, action_spec
)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = [[0.1 * torch.ones((1, obs_size))] for _ in range(n_agents)]
# simulate baseline in POCA
sample_act = [
AgentAction(
0.1 * torch.ones((1, 2)), [0.1 * torch.ones(1) for _ in range(act_size)]
)
for _ in range(n_agents - 1)
]
for _ in range(300):
if with_actions:
encoded, _ = networkbody(
obs_only=sample_obs[:1], obs=sample_obs[1:], actions=sample_act
)
else:
encoded, _ = networkbody(obs_only=sample_obs, obs=[], actions=[])
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten().tolist():
assert _enc == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("with_actions", [True, False], ids=["actions", "no_actions"])
def test_multinetworkbody_lstm(with_actions):
torch.manual_seed(0)
obs_size = 4
act_size = 2
seq_len = 16
n_agents = 3
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
)
obs_shapes = [(obs_size,)]
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
networkbody = MultiAgentNetworkBody(
create_observation_specs_with_shapes(obs_shapes), network_settings, action_spec
)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
sample_obs = [[0.1 * torch.ones((seq_len, obs_size))] for _ in range(n_agents)]
# simulate baseline in POCA
sample_act = [
AgentAction(
0.1 * torch.ones((seq_len, 2)),
[0.1 * torch.ones(seq_len) for _ in range(act_size)],
)
for _ in range(n_agents - 1)
]
for _ in range(300):
if with_actions:
encoded, _ = networkbody(
obs_only=sample_obs[:1],
obs=sample_obs[1:],
actions=sample_act,
memories=torch.ones(1, 1, 12),
sequence_length=seq_len,
)
else:
encoded, _ = networkbody(
obs_only=sample_obs,
obs=[],
actions=[],
memories=torch.ones(1, 1, 12),
sequence_length=seq_len,
)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten().tolist():
assert _enc == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("with_actions", [True, False], ids=["actions", "no_actions"])
def test_multinetworkbody_visual(with_actions):
torch.manual_seed(0)
act_size = 2
n_agents = 3
obs_size = 4
vis_obs_size = (84, 84, 3)
network_settings = NetworkSettings()
obs_shapes = [(obs_size,), vis_obs_size]
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
networkbody = MultiAgentNetworkBody(
create_observation_specs_with_shapes(obs_shapes), network_settings, action_spec
)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = [
[0.1 * torch.ones((1, obs_size))] + [0.1 * torch.ones((1, 84, 84, 3))]
for _ in range(n_agents)
]
# simulate baseline in POCA
sample_act = [
AgentAction(
0.1 * torch.ones((1, 2)), [0.1 * torch.ones(1) for _ in range(act_size)]
)
for _ in range(n_agents - 1)
]
for _ in range(300):
if with_actions:
encoded, _ = networkbody(
obs_only=sample_obs[:1], obs=sample_obs[1:], actions=sample_act
)
else:
encoded, _ = networkbody(obs_only=sample_obs, obs=[], actions=[])
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten().tolist():
assert _enc == pytest.approx(1.0, abs=0.1)
def test_valuenetwork():
torch.manual_seed(0)
obs_size = 4

act_size = 2
mask = torch.ones([1, act_size * 2])
stream_names = [f"stream_name{n}" for n in range(4)]
# action_spec = ActionSpec.create_continuous(act_size[0])
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
if shared:
actor = critic = SharedActorCritic(

19
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py


from mlagents.trainers.buffer import BufferKey
import numpy as np
from mlagents.trainers.torch.components.reward_providers import (
ExtrinsicRewardProvider,
create_reward_provider,

extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings)
generated_rewards = extrinsic_rp.evaluate(buffer)
assert (generated_rewards == reward).all()
# Test group rewards. Rewards should be double the environment rewards, but shouldn't count
# the groupmate rewards.
buffer[BufferKey.GROUP_REWARD] = buffer[BufferKey.ENVIRONMENT_REWARDS]
# 2 agents with identical rewards
buffer[BufferKey.GROUPMATE_REWARDS].set(
[np.ones(1, dtype=np.float32) * reward] * 2
for _ in range(buffer.num_experiences)
)
generated_rewards = extrinsic_rp.evaluate(buffer)
assert (generated_rewards == 2 * reward).all()
# Test groupmate rewards. Total reward should be indiv_reward + 2 * teammate_reward + group_reward
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings)
extrinsic_rp.add_groupmate_rewards = True
generated_rewards = extrinsic_rp.evaluate(buffer)
assert (generated_rewards == 4 * reward).all()

79
ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py


from mlagents.trainers.tests.simple_test_envs import (
SimpleEnvironment,
MultiAgentEnvironment,
MemoryEnvironment,
RecordEnvironment,
)

ActionSpecProto,
)
from mlagents.trainers.tests.dummy_config import ppo_dummy_config, sac_dummy_config
from mlagents.trainers.tests.dummy_config import (
ppo_dummy_config,
sac_dummy_config,
poca_dummy_config,
)
from mlagents.trainers.tests.check_env_trains import (
check_environment_trains,
default_reward_processor,

PPO_TORCH_CONFIG = ppo_dummy_config()
SAC_TORCH_CONFIG = sac_dummy_config()
POCA_TORCH_CONFIG = poca_dummy_config()
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
def test_simple_poca(action_sizes):
env = MultiAgentEnvironment([BRAIN_NAME], action_sizes=action_sizes, num_agents=2)
config = attr.evolve(POCA_TORCH_CONFIG)
check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("num_visual", [1, 2])
def test_visual_poca(num_visual):
env = MultiAgentEnvironment(
[BRAIN_NAME], action_sizes=(0, 1), num_agents=2, num_visual=num_visual
)
new_hyperparams = attr.evolve(
POCA_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
)
config = attr.evolve(POCA_TORCH_CONFIG, hyperparameters=new_hyperparams)
check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("num_var_len", [1, 2])
@pytest.mark.parametrize("num_vector", [0, 1])
@pytest.mark.parametrize("num_vis", [0, 1])
def test_var_len_obs_poca(num_vis, num_vector, num_var_len):
env = MultiAgentEnvironment(
[BRAIN_NAME],
action_sizes=(0, 1),
num_visual=num_vis,
num_vector=num_vector,
num_var_len=num_var_len,
step_size=0.2,
num_agents=2,
)
new_hyperparams = attr.evolve(
POCA_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
)
config = attr.evolve(POCA_TORCH_CONFIG, hyperparameters=new_hyperparams)
check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("is_multiagent", [True, False])
def test_recurrent_poca(action_sizes, is_multiagent):
if is_multiagent:
# This is not a recurrent environment, just check if LSTM doesn't crash
env = MultiAgentEnvironment(
[BRAIN_NAME], action_sizes=action_sizes, num_agents=2
)
else:
# Actually test LSTM here
env = MemoryEnvironment([BRAIN_NAME], action_sizes=action_sizes)
new_network_settings = attr.evolve(
POCA_TORCH_CONFIG.network_settings,
memory=NetworkSettings.MemorySettings(memory_size=16),
)
new_hyperparams = attr.evolve(
POCA_TORCH_CONFIG.hyperparameters,
learning_rate=1.0e-3,
batch_size=64,
buffer_size=128,
)
config = attr.evolve(
POCA_TORCH_CONFIG,
hyperparameters=new_hyperparams,
network_settings=new_network_settings,
max_steps=500 if is_multiagent else 6000,
)
check_environment_trains(
env, {BRAIN_NAME: config}, success_threshold=None if is_multiagent else 0.9
)
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])

16
ml-agents/mlagents/trainers/torch/agent_action.py


else:
return torch.empty(0)
def slice(self, start: int, end: int) -> "AgentAction":
"""
Returns an AgentAction with the continuous and discrete tensors sliced
from index start to index end.
"""
_cont = None
_disc_list = []
if self.continuous_tensor is not None:
_cont = self.continuous_tensor[start:end]
if self.discrete_list is not None and len(self.discrete_list) > 0:
for _disc in self.discrete_list:
_disc_list.append(_disc[start:end])
return AgentAction(_cont, _disc_list)
def to_action_tuple(self, clip: bool = False) -> ActionTuple:
"""
Returns an ActionTuple

:return: Tensor of flattened actions.
"""
# if there are any discrete actions, create one-hot
if self.discrete_list is not None and self.discrete_list:
if self.discrete_list is not None and len(self.discrete_list) > 0:
discrete_oh = ModelUtils.actions_to_onehot(
self.discrete_tensor, discrete_branches
)

30
ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py


from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import RewardSignalSettings
"""
Evaluates extrinsic reward. For a single agent, this equals the individual reward
given to the agent. For the POCA algorithm, we want not only the individual reward
but also the group (team) reward and the individual rewards of the other agents in the group.
"""
def __init__(self, specs: BehaviorSpec, settings: RewardSignalSettings) -> None:
super().__init__(specs, settings)
self.add_groupmate_rewards = False
return np.array(mini_batch[BufferKey.ENVIRONMENT_REWARDS], dtype=np.float32)
indiv_rewards = np.array(
mini_batch[BufferKey.ENVIRONMENT_REWARDS], dtype=np.float32
)
total_rewards = indiv_rewards
if BufferKey.GROUPMATE_REWARDS in mini_batch and self.add_groupmate_rewards:
groupmate_rewards_list = mini_batch[BufferKey.GROUPMATE_REWARDS]
groupmate_rewards_sum = np.array(
[sum(_rew) for _rew in groupmate_rewards_list], dtype=np.float32
)
total_rewards += groupmate_rewards_sum
if BufferKey.GROUP_REWARD in mini_batch:
group_rewards = np.array(
mini_batch[BufferKey.GROUP_REWARD], dtype=np.float32
)
# Add all the group rewards to the individual rewards
total_rewards += group_rewards
return total_rewards
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
return {}
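A worked example of the composition implemented above, matching the expectations in test_extrinsic.py earlier in this diff (numbers are illustrative):

# Per-step extrinsic reward seen by POCA, for one agent with reward r = 0.1:
r = 0.1
indiv = r                    # BufferKey.ENVIRONMENT_REWARDS
group = r                    # BufferKey.GROUP_REWARD, always added when present
groupmates = [r, r]          # BufferKey.GROUPMATE_REWARDS, two teammates
total_without_groupmates = indiv + group                 # 2 * r
total_with_groupmates = indiv + sum(groupmates) + group  # 4 * r, when add_groupmate_rewards=True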

320
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.trainers.torch.action_model import ActionModel
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil

get_zero_entities_mask,
)
from mlagents.trainers.exception import UnityTrainerException
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

EPSILON = 1e-7
class NetworkBody(nn.Module):
class ObservationEncoder(nn.Module):
network_settings: NetworkSettings,
encoded_act_size: int = 0,
h_size: int,
vis_encode_type: EncoderType,
normalize: bool = False,
"""
Returns an ObservationEncoder that can process and encode a set of observations.
Will use an RSA if needed for variable length observations.
"""
self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None
else 0
)
observation_specs,
self.h_size,
network_settings.vis_encode_type,
normalize=self.normalize,
observation_specs, h_size, vis_encode_type, normalize=normalize
entity_num_max: int = 0
var_processors = [p for p in self.processors if isinstance(p, EntityEmbedding)]
for processor in var_processors:
entity_max: int = processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(var_processors) > 0:
if sum(self.embedding_sizes):
self.x_self_encoder = LinearEncoder(
sum(self.embedding_sizes),
1,
self.h_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.h_size) ** 0.5,
)
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
total_enc_size = sum(self.embedding_sizes) + self.h_size
self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
self.processors, self.embedding_sizes, h_size
)
if self.rsa is not None:
total_enc_size = sum(self.embedding_sizes) + h_size
total_enc_size += encoded_act_size
self.linear_encoder = LinearEncoder(
total_enc_size, network_settings.num_layers, self.h_size
)
self.normalize = normalize
self._total_enc_size = total_enc_size
if self.use_lstm:
self.lstm = LSTM(self.h_size, self.m_size)
else:
self.lstm = None # type: ignore
@property
def total_enc_size(self) -> int:
"""
Returns the total encoding size for this ObservationEncoder.
"""
return self._total_enc_size
def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))

def copy_normalization(self, other_network: "NetworkBody") -> None:
def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
for n1, n2 in zip(self.processors, other_network.processors):
for n1, n2 in zip(self.processors, other_encoder.processors):
@property
def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0
def forward(
self,
inputs: List[torch.Tensor],
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""
Encode observations using a list of processors and an RSA.
:param inputs: List of Tensors corresponding to a set of obs.
:param processors: a ModuleList of the input processors to be applied to these obs.
:param rsa: Optionally, an RSA to use for variable length obs.
:param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
"""
encodes = []
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

input_exist = True
else:
input_exist = False
if len(var_len_processor_inputs) > 0:
if len(var_len_processor_inputs) > 0 and self.rsa is not None:
processed_self = self.x_self_encoder(encoded_self) if input_exist else None
processed_self = (
self.x_self_encoder(encoded_self)
if input_exist and self.x_self_encoder is not None
else None
)
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)

encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
if not input_exist:
raise Exception(
raise UnityTrainerException(
return encoded_self
class NetworkBody(nn.Module):
def __init__(
self,
observation_specs: List[ObservationSpec],
network_settings: NetworkSettings,
encoded_act_size: int = 0,
):
super().__init__()
self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None
else 0
)
self.observation_encoder = ObservationEncoder(
observation_specs,
self.h_size,
network_settings.vis_encode_type,
self.normalize,
)
self.processors = self.observation_encoder.processors
total_enc_size = self.observation_encoder.total_enc_size
total_enc_size += encoded_act_size
self.linear_encoder = LinearEncoder(
total_enc_size, network_settings.num_layers, self.h_size
)
if self.use_lstm:
self.lstm = LSTM(self.h_size, self.m_size)
else:
self.lstm = None # type: ignore
def update_normalization(self, buffer: AgentBuffer) -> None:
self.observation_encoder.update_normalization(buffer)
def copy_normalization(self, other_network: "NetworkBody") -> None:
self.observation_encoder.copy_normalization(other_network.observation_encoder)
@property
def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0
def forward(
self,
inputs: List[torch.Tensor],
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encoded_self = self.observation_encoder(inputs)
if actions is not None:
encoded_self = torch.cat([encoded_self, actions], dim=1)
encoding = self.linear_encoder(encoded_self)

return encoding, memories
class MultiAgentNetworkBody(torch.nn.Module):
"""
A network body that uses a self attention layer to handle state
and action input from a potentially variable number of agents that
share the same observation and action space.
"""
def __init__(
self,
observation_specs: List[ObservationSpec],
network_settings: NetworkSettings,
action_spec: ActionSpec,
):
super().__init__()
self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None
else 0
)
self.action_spec = action_spec
self.observation_encoder = ObservationEncoder(
observation_specs,
self.h_size,
network_settings.vis_encode_type,
self.normalize,
)
self.processors = self.observation_encoder.processors
# Modules for multi-agent self-attention
obs_only_ent_size = self.observation_encoder.total_enc_size
q_ent_size = (
obs_only_ent_size
+ sum(self.action_spec.discrete_branches)
+ self.action_spec.continuous_size
)
self.obs_encoder = EntityEmbedding(obs_only_ent_size, None, self.h_size)
self.obs_action_encoder = EntityEmbedding(q_ent_size, None, self.h_size)
self.self_attn = ResidualSelfAttention(self.h_size)
self.linear_encoder = LinearEncoder(
self.h_size,
network_settings.num_layers,
self.h_size,
kernel_gain=(0.125 / self.h_size) ** 0.5,
)
if self.use_lstm:
self.lstm = LSTM(self.h_size, self.m_size)
else:
self.lstm = None # type: ignore
@property
def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0
def update_normalization(self, buffer: AgentBuffer) -> None:
self.observation_encoder.update_normalization(buffer)
def copy_normalization(self, other_network: "MultiAgentNetworkBody") -> None:
self.observation_encoder.copy_normalization(other_network.observation_encoder)
def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor:
"""
Get attention masks by grabbing an arbitrary obs across all the agents.
Since these are raw obs, the padded values are still NaN.
"""
only_first_obs = [_all_obs[0] for _all_obs in obs_tensors]
# Just get the first element in each obs regardless of its dimension. This will speed up
# searching for NaNs.
only_first_obs_flat = torch.stack(
[_obs.flatten(start_dim=1)[:, 0] for _obs in only_first_obs], dim=1
)
# Get the mask from NaNs
attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor)
return attn_mask
def _copy_and_remove_nans_from_obs(
self, all_obs: List[List[torch.Tensor]], attention_mask: torch.Tensor
) -> List[List[torch.Tensor]]:
"""
Helper function to remove NaNs from observations using an attention mask.
"""
obs_with_no_nans = []
for i_agent, single_agent_obs in enumerate(all_obs):
no_nan_obs = []
for obs in single_agent_obs:
new_obs = obs.clone()
new_obs[
attention_mask.type(torch.BoolTensor)[:, i_agent], ::
] = 0.0  # Remove NaNs fast
no_nan_obs.append(new_obs)
obs_with_no_nans.append(no_nan_obs)
return obs_with_no_nans
def forward(
self,
obs_only: List[List[torch.Tensor]],
obs: List[List[torch.Tensor]],
actions: List[AgentAction],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns the encoding of the combined observations (and, if memory is enabled,
the memories as well).
:param obs_only: Observations to be processed that do not have corresponding actions.
These are encoded with the obs_encoder.
:param obs: Observations to be processed that do have corresponding actions.
After concatenation with actions, these are processed with obs_action_encoder.
:param actions: After concatenation with obs, these are processed with obs_action_encoder.
:param memories: If using memory, a Tensor of initial memories.
:param sequence_length: If using memory, the sequence length.
"""
self_attn_masks = []
self_attn_inputs = []
concat_f_inp = []
if obs:
obs_attn_mask = self._get_masks_from_nans(obs)
obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
for inputs, action in zip(obs, actions):
encoded = self.observation_encoder(inputs)
cat_encodes = [
encoded,
action.to_flat(self.action_spec.discrete_branches),
]
concat_f_inp.append(torch.cat(cat_encodes, dim=1))
f_inp = torch.stack(concat_f_inp, dim=1)
self_attn_masks.append(obs_attn_mask)
self_attn_inputs.append(self.obs_action_encoder(None, f_inp))
concat_encoded_obs = []
if obs_only:
obs_only_attn_mask = self._get_masks_from_nans(obs_only)
obs_only = self._copy_and_remove_nans_from_obs(obs_only, obs_only_attn_mask)
for inputs in obs_only:
encoded = self.observation_encoder(inputs)
concat_encoded_obs.append(encoded)
g_inp = torch.stack(concat_encoded_obs, dim=1)
self_attn_masks.append(obs_only_attn_mask)
self_attn_inputs.append(self.obs_encoder(None, g_inp))
encoded_entity = torch.cat(self_attn_inputs, dim=1)
encoded_state = self.self_attn(encoded_entity, self_attn_masks)
encoding = self.linear_encoder(encoded_state)
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])
encoding, memories = self.lstm(encoding, memories)
encoding = encoding.reshape([-1, self.m_size // 2])
return encoding, memories
class Critic(abc.ABC):
@abc.abstractmethod
def update_normalization(self, buffer: AgentBuffer) -> None:

end = 0
vis_index = 0
var_len_index = 0
for i, enc in enumerate(self.network_body.processors):
for i, enc in enumerate(self.network_body.observation_encoder.processors):
vec_size = self.network_body.embedding_sizes[i]
vec_size = self.network_body.observation_encoder.embedding_sizes[i]
end = start + vec_size
inputs.append(concatenated_vec_obs[:, start:end])
start = end
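The obs_only/obs/actions split in MultiAgentNetworkBody.forward is what lets a single module serve both the centralized value and the counterfactual baseline. A rough sketch with hypothetical per-agent observation and action lists (multi_net is a MultiAgentNetworkBody; variable names are illustrative):

# Value pass: every agent contributes observations only.
value_encoding, _ = multi_net(obs_only=all_agent_obs, obs=[], actions=[])
# Baseline pass for agent i: its obs without an action, plus the groupmates'
# observations paired with the actions they actually took.
baseline_encoding, _ = multi_net(
    obs_only=[agent_i_obs],
    obs=groupmate_obs,
    actions=groupmate_actions,
)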

93
ml-agents/mlagents/trainers/torch/utils.py


from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict
from mlagents.trainers.torch.layers import LinearEncoder, Initialization
import numpy as np
from mlagents.trainers.torch.encoders import (

VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.attention import EntityEmbedding
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty

alpha=tau,
out=target_param.data,
)
@staticmethod
def create_residual_self_attention(
input_processors: nn.ModuleList, embedding_sizes: List[int], hidden_size: int
) -> Tuple[Optional[ResidualSelfAttention], Optional[LinearEncoder]]:
"""
Creates an RSA if there are variable length observations found in the input processors.
:param input_processors: A ModuleList of input processors as returned by the function
create_input_processors().
:param embedding sizes: A List of embedding sizes as returned by create_input_processors().
:param hidden_size: The hidden size to use for the RSA.
:returns: A Tuple of the RSA itself and a self encoder.
Returns None for both if no var len inputs are detected.
"""
rsa, x_self_encoder = None, None
entity_num_max: int = 0
var_processors = [p for p in input_processors if isinstance(p, EntityEmbedding)]
for processor in var_processors:
entity_max: int = processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(var_processors) > 0:
if sum(embedding_sizes):
x_self_encoder = LinearEncoder(
sum(embedding_sizes),
1,
hidden_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / hidden_size) ** 0.5,
)
rsa = ResidualSelfAttention(hidden_size, entity_num_max)
return rsa, x_self_encoder
@staticmethod
def trust_region_value_loss(
values: Dict[str, torch.Tensor],
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluates value loss, clipping to stay within a trust region of old value estimates.
Used for PPO and POCA.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():
old_val_tensor = old_values[name]
returns_tensor = returns[name]
clipped_value_estimate = old_val_tensor + torch.clamp(
head - old_val_tensor, -1 * epsilon, epsilon
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss
@staticmethod
def trust_region_policy_loss(
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
epsilon: float,
) -> torch.Tensor:
"""
Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
:param advantages: Computed advantages.
:param log_probs: Current policy probabilities
:param old_log_probs: Past policy probabilities
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)
r_theta = torch.exp(log_probs - old_log_probs)
p_opt_a = r_theta * advantage
p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b), loss_masks
)
return policy_loss
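Both optimizers now share these helpers (the PPO hunk above swaps its local ppo_value_loss/ppo_policy_loss calls for them). A small self-contained sketch with toy tensors, assuming ml-agents and its torch dependency are installed:

from mlagents.torch_utils import torch
from mlagents.trainers.torch.utils import ModelUtils

# Toy batch of 4 experiences, one reward stream, clipping epsilon = 0.2.
values = {"extrinsic": torch.tensor([0.5, 0.6, 0.7, 0.8])}
old_values = {"extrinsic": torch.tensor([0.4, 0.6, 0.9, 0.8])}
returns = {"extrinsic": torch.ones(4)}
masks = torch.ones(4, dtype=torch.bool)
v_loss = ModelUtils.trust_region_value_loss(values, old_values, returns, 0.2, masks)

advantages = torch.tensor([1.0, -0.5, 0.3, 0.0])
log_probs = torch.zeros(4, 1)      # current policy log-probs
old_log_probs = torch.zeros(4, 1)  # log-probs stored at collection time
p_loss = ModelUtils.trust_region_policy_loss(advantages, log_probs, old_log_probs, masks, 0.2)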

11
ml-agents/mlagents/trainers/trainer/trainer_factory.py


from mlagents.trainers.trainer import Trainer
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.poca.trainer import POCATrainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType

if trainer_type == TrainerType.PPO:
trainer = PPOTrainer(
brain_name,
min_lesson_length,
trainer_settings,
train_model,
load_model,
seed,
trainer_artifact_path,
)
elif trainer_type == TrainerType.POCA:
trainer = POCATrainer(
brain_name,
min_lesson_length,
trainer_settings,

4
ml-agents/mlagents/trainers/trajectory.py


return self.steps[-1].done
@property
def teammate_dones_reached(self) -> bool:
def all_group_dones_reached(self) -> bool:
Returns true if all teammates are done at the end of the trajectory.
Returns true if all other agents in this trajectory are done at the end of the trajectory.
Combine with done_reached to check if the whole team is done.
"""
return all(_status.done for _status in self.steps[-1].group_status)

290
ml-agents/mlagents/trainers/tests/torch/test_poca.py


import pytest
import numpy as np
import attr
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import copy_buffer_fields
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests.dummy_config import ( # noqa: F401
ppo_dummy_config,
curiosity_dummy_config,
gail_dummy_config,
)
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
@pytest.fixture
def dummy_config():
# poca has the same hyperparameters as ppo for now
return ppo_dummy_config()
VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 4
CONTINUOUS_ACTION_SPEC = ActionSpec.create_continuous(VECTOR_ACTION_SPACE)
DISCRETE_ACTION_SPEC = ActionSpec.create_discrete(tuple(DISCRETE_ACTION_SPACE))
def create_test_poca_optimizer(dummy_config, use_rnn, use_discrete, use_visual):
mock_specs = mb.setup_test_behavior_specs(
use_discrete,
use_visual,
vector_action_space=DISCRETE_ACTION_SPACE
if use_discrete
else VECTOR_ACTION_SPACE,
vector_obs_space=VECTOR_OBS_SPACE,
)
trainer_settings = attr.evolve(dummy_config)
trainer_settings.reward_signals = {
RewardSignalType.EXTRINSIC: RewardSignalSettings(strength=1.0, gamma=0.99)
}
trainer_settings.network_settings.memory = (
NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
if use_rnn
else None
)
policy = TorchPolicy(0, mock_specs, trainer_settings, "test", False)
optimizer = TorchPOCAOptimizer(policy, trainer_settings)
return optimizer
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_poca_optimizer_update(dummy_config, rnn, visual, discrete):
# Test evaluate
optimizer = create_test_poca_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
# Test update
update_buffer = mb.simulate_rollout(
BUFFER_INIT_SAMPLES,
optimizer.policy.behavior_spec,
memory_size=optimizer.policy.m_size,
num_other_agents_in_group=NUM_AGENTS,
)
# Mock out reward signal eval
copy_buffer_fields(
update_buffer,
BufferKey.ENVIRONMENT_REWARDS,
[
BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
],
)
# Copy memories to critic memories
copy_buffer_fields(
update_buffer,
BufferKey.MEMORY,
[BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY],
)
return_stats = optimizer.update(
update_buffer,
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
)
# Make sure we have the right stats
required_stats = [
"Losses/Policy Loss",
"Losses/Value Loss",
"Policy/Learning Rate",
"Policy/Epsilon",
"Policy/Beta",
]
for stat in required_stats:
assert stat in return_stats.keys()
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_poca_get_value_estimates(dummy_config, rnn, visual, discrete):
optimizer = create_test_poca_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
time_horizon = 15
trajectory = make_fake_trajectory(
length=time_horizon,
observation_specs=optimizer.policy.behavior_spec.observation_specs,
action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC,
max_step_complete=True,
num_other_agents_in_group=NUM_AGENTS,
)
(
value_estimates,
baseline_estimates,
value_next,
value_memories,
baseline_memories,
) = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,
trajectory.next_group_obs,
done=False,
)
for key, val in value_estimates.items():
assert type(key) is str
assert len(val) == 15
for key, val in baseline_estimates.items():
assert type(key) is str
assert len(val) == 15
if value_memories is not None:
assert len(value_memories) == 15
assert len(baseline_memories) == 15
(
value_estimates,
baseline_estimates,
value_next,
value_memories,
baseline_memories,
) = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,
trajectory.next_group_obs,
done=True,
)
for key, val in value_next.items():
assert type(key) is str
assert val == 0.0
# Check if we ignore terminal states properly
optimizer.reward_signals["extrinsic"].use_terminal_states = False
(
value_estimates,
baseline_estimates,
value_next,
value_memories,
baseline_memories,
) = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,
trajectory.next_group_obs,
done=False,
)
for key, val in value_next.items():
assert type(key) is str
assert val != 0.0
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
# We need to test this separately from test_reward_signals.py to ensure no interactions
def test_ppo_optimizer_update_curiosity(
dummy_config, curiosity_dummy_config, rnn, visual, discrete # noqa: F811
):
# Test evaluate
dummy_config.reward_signals = curiosity_dummy_config
optimizer = create_test_poca_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
# Test update
update_buffer = mb.simulate_rollout(
BUFFER_INIT_SAMPLES,
optimizer.policy.behavior_spec,
memory_size=optimizer.policy.m_size,
)
# Mock out reward signal eval
copy_buffer_fields(
update_buffer,
src_key=BufferKey.ENVIRONMENT_REWARDS,
dst_keys=[
BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
RewardSignalUtil.returns_key("curiosity"),
RewardSignalUtil.value_estimates_key("curiosity"),
RewardSignalUtil.baseline_estimates_key("curiosity"),
],
)
# Copy memories to critic memories
copy_buffer_fields(
update_buffer,
BufferKey.MEMORY,
[BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY],
)
optimizer.update(
update_buffer,
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
)
# We need to test this separately from test_reward_signals.py to ensure no interactions
def test_ppo_optimizer_update_gail(gail_dummy_config, dummy_config): # noqa: F811
# Test evaluate
dummy_config.reward_signals = gail_dummy_config
config = ppo_dummy_config()
optimizer = create_test_poca_optimizer(
config, use_rnn=False, use_discrete=False, use_visual=False
)
# Test update
update_buffer = mb.simulate_rollout(
BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec
)
# Mock out reward signal eval
copy_buffer_fields(
update_buffer,
src_key=BufferKey.ENVIRONMENT_REWARDS,
dst_keys=[
BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
RewardSignalUtil.returns_key("gail"),
RewardSignalUtil.value_estimates_key("gail"),
RewardSignalUtil.baseline_estimates_key("gail"),
],
)
update_buffer[BufferKey.CONTINUOUS_LOG_PROBS] = np.ones_like(
update_buffer[BufferKey.CONTINUOUS_ACTION]
)
optimizer.update(
update_buffer,
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
)
# Check if buffer size is too big
update_buffer = mb.simulate_rollout(3000, optimizer.policy.behavior_spec)
# Mock out reward signal eval
copy_buffer_fields(
update_buffer,
src_key=BufferKey.ENVIRONMENT_REWARDS,
dst_keys=[
BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
RewardSignalUtil.returns_key("gail"),
RewardSignalUtil.value_estimates_key("gail"),
RewardSignalUtil.baseline_estimates_key("gail"),
],
)
optimizer.update(
update_buffer,
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
)
if __name__ == "__main__":
pytest.main()

0
ml-agents/mlagents/trainers/poca/__init__.py

674
ml-agents/mlagents/trainers/poca/optimizer_torch.py


from typing import Dict, cast, List, Tuple, Optional
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import (
ExtrinsicRewardProvider,
)
import numpy as np
import math
from mlagents.torch_utils import torch
from mlagents.trainers.buffer import (
AgentBuffer,
BufferKey,
RewardSignalUtil,
AgentBufferField,
)
from mlagents_envs.timers import timed
from mlagents_envs.base_env import ObservationSpec, ActionSpec
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import (
RewardSignalSettings,
RewardSignalType,
TrainerSettings,
POCASettings,
)
from mlagents.trainers.torch.networks import Critic, MultiAgentNetworkBody
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil, GroupObsUtil
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.logging_util import get_logger
logger = get_logger(__name__)
class TorchPOCAOptimizer(TorchOptimizer):
class POCAValueNetwork(torch.nn.Module, Critic):
"""
The POCAValueNetwork uses the MultiAgentNetworkBody to compute the value
and POCA baseline for a variable number of agents in a group that all
share the same observation and action space.
"""
def __init__(
self,
stream_names: List[str],
observation_specs: List[ObservationSpec],
network_settings: NetworkSettings,
action_spec: ActionSpec,
):
torch.nn.Module.__init__(self)
self.network_body = MultiAgentNetworkBody(
observation_specs, network_settings, action_spec
)
if network_settings.memory is not None:
encoding_size = network_settings.memory.memory_size // 2
else:
encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, 1)
@property
def memory_size(self) -> int:
return self.network_body.memory_size
def update_normalization(self, buffer: AgentBuffer) -> None:
self.network_body.update_normalization(buffer)
def baseline(
self,
obs_without_actions: List[torch.Tensor],
obs_with_actions: Tuple[List[List[torch.Tensor]], List[AgentAction]],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
The POCA baseline marginalizes the action of the agent associated with self_obs.
It calls the forward pass of the MultiAgentNetworkBody with the state action
pairs of groupmates but just the state of the agent in question.
:param obs_without_actions: The obs of the agent for which to compute the baseline.
:param obs_with_actions: Tuple of observations and actions for all groupmates.
:param memories: If using memory, a Tensor of initial memories.
:param sequence_length: If using memory, the sequence length.
:return: A Tuple of Dict of reward stream to tensor and critic memories.
"""
(obs, actions) = obs_with_actions
encoding, memories = self.network_body(
obs_only=[obs_without_actions],
obs=obs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
value_outputs, critic_mem_out = self.forward(
encoding, memories, sequence_length
)
return value_outputs, critic_mem_out
def critic_pass(
self,
obs: List[List[torch.Tensor]],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
A centralized value function. It calls the forward pass of MultiAgentNetworkBody
with just the states of all agents.
:param obs: List of observations for all agents in group
:param memories: If using memory, a Tensor of initial memories.
:param sequence_length: If using memory, the sequence length.
:return: A Tuple of Dict of reward stream to tensor and critic memories.
"""
encoding, memories = self.network_body(
obs_only=obs,
obs=[],
actions=[],
memories=memories,
sequence_length=sequence_length,
)
value_outputs, critic_mem_out = self.forward(
encoding, memories, sequence_length
)
return value_outputs, critic_mem_out
def forward(
self,
encoding: torch.Tensor,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
output = self.value_heads(encoding)
return output, memories
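# Note on the two heads above: critic_pass() feeds the observations of every agent
# in the group through the MultiAgentNetworkBody via obs_only and returns the
# centralized value estimate, while baseline() additionally feeds the groupmates'
# observation/action pairs but deliberately omits the action of the agent being
# evaluated, which is what makes the returned estimate a counterfactual baseline.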
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
"""
Takes a Policy and trainer settings and creates a POCA Optimizer around the policy.
:param policy: A TorchPolicy object that will be updated by this POCA Optimizer.
:param trainer_settings: Trainer settings that specify the properties of the trainer.
"""
super().__init__(policy, trainer_settings)
reward_signal_configs = trainer_settings.reward_signals
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
self._critic = TorchPOCAOptimizer.POCAValueNetwork(
reward_signal_names,
policy.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=policy.behavior_spec.action_spec,
)
params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
self.hyperparameters: POCASettings = cast(
POCASettings, trainer_settings.hyperparameters
)
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.trainer_settings.max_steps,
)
self.decay_epsilon = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.epsilon,
0.1,
self.trainer_settings.max_steps,
)
self.decay_beta = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.beta,
1e-5,
self.trainer_settings.max_steps,
)
self.optimizer = torch.optim.Adam(
params, lr=self.trainer_settings.hyperparameters.learning_rate
)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
self.stream_names = list(self.reward_signals.keys())
self.value_memory_dict: Dict[str, torch.Tensor] = {}
self.baseline_memory_dict: Dict[str, torch.Tensor] = {}
def create_reward_signals(
self, reward_signal_configs: Dict[RewardSignalType, RewardSignalSettings]
) -> None:
"""
Create reward signals. Override default to provide warnings for Curiosity and
GAIL, and make sure Extrinsic adds team rewards.
:param reward_signal_configs: Reward signal config.
"""
for reward_signal in reward_signal_configs.keys():
if reward_signal != RewardSignalType.EXTRINSIC:
logger.warning(
f"Reward signal {reward_signal.value.capitalize()} is not supported with the POCA trainer; "
"results may be unexpected."
)
super().create_reward_signals(reward_signal_configs)
# Make sure we add the groupmate rewards in POCA, so agents learn how to help each
# other achieve individual rewards as well
for reward_provider in self.reward_signals.values():
if isinstance(reward_provider, ExtrinsicRewardProvider):
reward_provider.add_groupmate_rewards = True
@property
def critic(self):
return self._critic
@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
Performs update on model.
:param batch: Batch of experiences.
:param num_sequences: Number of sequences to process.
:return: Results of update.
"""
# Get decayed parameters
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
returns = {}
old_values = {}
old_baseline_values = {}
for name in self.reward_signals:
old_values[name] = ModelUtils.list_to_tensor(
batch[RewardSignalUtil.value_estimates_key(name)]
)
returns[name] = ModelUtils.list_to_tensor(
batch[RewardSignalUtil.returns_key(name)]
)
old_baseline_values[name] = ModelUtils.list_to_tensor(
batch[RewardSignalUtil.baseline_estimates_key(name)]
)
n_obs = len(self.policy.behavior_spec.observation_specs)
current_obs = ObsUtil.from_buffer(batch, n_obs)
# Convert to tensors
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)
groupmate_obs = [
[ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
for _groupmate_obs in groupmate_obs
]
act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
actions = AgentAction.from_buffer(batch)
groupmate_actions = AgentAction.group_from_buffer(batch)
memories = [
ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
]
if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
value_memories = [
ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
for i in range(
0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
)
]
baseline_memories = [
ModelUtils.list_to_tensor(batch[BufferKey.BASELINE_MEMORY][i])
for i in range(
0, len(batch[BufferKey.BASELINE_MEMORY]), self.policy.sequence_length
)
]
if len(value_memories) > 0:
value_memories = torch.stack(value_memories).unsqueeze(0)
baseline_memories = torch.stack(baseline_memories).unsqueeze(0)
log_probs, entropy = self.policy.evaluate_actions(
current_obs,
masks=act_masks,
actions=actions,
memories=memories,
seq_len=self.policy.sequence_length,
)
all_obs = [current_obs] + groupmate_obs
values, _ = self.critic.critic_pass(
all_obs,
memories=value_memories,
sequence_length=self.policy.sequence_length,
)
groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
baselines, _ = self.critic.baseline(
current_obs,
groupmate_obs_and_actions,
memories=baseline_memories,
sequence_length=self.policy.sequence_length,
)
old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
baseline_loss = ModelUtils.trust_region_value_loss(
baselines, old_baseline_values, returns, decay_eps, loss_masks
)
value_loss = ModelUtils.trust_region_value_loss(
values, old_values, returns, decay_eps, loss_masks
)
policy_loss = ModelUtils.trust_region_policy_loss(
ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
log_probs,
old_log_probs,
loss_masks,
decay_eps,
)
loss = (
policy_loss
+ 0.5 * (value_loss + 0.5 * baseline_loss)
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)
# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
update_stats = {
# NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
# TODO: After PyTorch is default, change to something more correct.
"Losses/Policy Loss": torch.abs(policy_loss).item(),
"Losses/Value Loss": value_loss.item(),
"Losses/Baseline Loss": baseline_loss.item(),
"Policy/Learning Rate": decay_lr,
"Policy/Epsilon": decay_eps,
"Policy/Beta": decay_bet,
}
for reward_provider in self.reward_signals.values():
update_stats.update(reward_provider.update(batch))
return update_stats
def get_modules(self):
modules = {"Optimizer:adam": self.optimizer, "Optimizer:critic": self._critic}
for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
return modules
def _evaluate_by_sequence_team(
self,
self_obs: List[torch.Tensor],
obs: List[List[torch.Tensor]],
actions: List[AgentAction],
init_value_mem: torch.Tensor,
init_baseline_mem: torch.Tensor,
) -> Tuple[
Dict[str, torch.Tensor],
Dict[str, torch.Tensor],
AgentBufferField,
AgentBufferField,
torch.Tensor,
torch.Tensor,
]:
"""
Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
intermediate memories for the critic.
:param self_obs: A List of tensors of shape (trajectory_len, <obs_dim>) containing this agent's
observations for the trajectory.
:param obs: Observations of the agent's groupmates, one List of tensors per groupmate.
:param actions: Actions taken by the groupmates, one AgentAction per groupmate.
:param init_value_mem: The critic memory that precedes this trajectory, of shape (1, 1, <mem_size>),
i.e. what is returned as the output of a memory module.
:param init_baseline_mem: The baseline memory that precedes this trajectory, same shape as init_value_mem.
:return: A Tuple of the value and baseline estimates as Dicts of [name, tensor], the AgentBufferFields
of initial value and baseline memories to be used during the update, and the final value and baseline
memories at the end of the trajectory.
"""
num_experiences = self_obs[0].shape[0]
all_next_value_mem = AgentBufferField()
all_next_baseline_mem = AgentBufferField()
# In the buffer, the first sequence is the one that is padded. So if seq_len = 3 and
# the trajectory is of length 10, the first sequence is [pad, pad, obs].
# Compute the number of elements in this padded seq.
leftover = num_experiences % self.policy.sequence_length
# Compute values for the potentially truncated initial sequence
first_seq_len = leftover if leftover > 0 else self.policy.sequence_length
self_seq_obs = []
groupmate_seq_obs = []
groupmate_seq_act = []
seq_obs = []
for _self_obs in self_obs:
first_seq_obs = _self_obs[0:first_seq_len]
seq_obs.append(first_seq_obs)
self_seq_obs.append(seq_obs)
for groupmate_obs, groupmate_action in zip(obs, actions):
seq_obs = []
for _obs in groupmate_obs:
first_seq_obs = _obs[0:first_seq_len]
seq_obs.append(first_seq_obs)
groupmate_seq_obs.append(seq_obs)
_act = groupmate_action.slice(0, first_seq_len)
groupmate_seq_act.append(_act)
# For the first sequence, the initial memory should be the one at the
# beginning of this trajectory.
for _ in range(first_seq_len):
all_next_value_mem.append(ModelUtils.to_numpy(init_value_mem.squeeze()))
all_next_baseline_mem.append(
ModelUtils.to_numpy(init_baseline_mem.squeeze())
)
all_seq_obs = self_seq_obs + groupmate_seq_obs
init_values, _value_mem = self.critic.critic_pass(
all_seq_obs, init_value_mem, sequence_length=first_seq_len
)
all_values = {
signal_name: [init_values[signal_name]]
for signal_name in init_values.keys()
}
groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
init_baseline, _baseline_mem = self.critic.baseline(
self_seq_obs[0],
groupmate_obs_and_actions,
init_baseline_mem,
sequence_length=first_seq_len,
)
all_baseline = {
signal_name: [init_baseline[signal_name]]
for signal_name in init_baseline.keys()
}
# Evaluate the remaining sequences, carrying over _value_mem and _baseline_mem
# after each sequence
for seq_num in range(
1, math.ceil(num_experiences / self.policy.sequence_length)
):
for _ in range(self.policy.sequence_length):
all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
all_next_baseline_mem.append(
ModelUtils.to_numpy(_baseline_mem.squeeze())
)
start = seq_num * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
end = (seq_num + 1) * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
self_seq_obs = []
groupmate_seq_obs = []
groupmate_seq_act = []
seq_obs = []
for _self_obs in self_obs:
seq_obs.append(_self_obs[start:end])
self_seq_obs.append(seq_obs)
for groupmate_obs, groupmate_action in zip(obs, actions):
seq_obs = []
for _obs in groupmate_obs:
seq_obs.append(_obs[start:end])
groupmate_seq_obs.append(seq_obs)
_act = groupmate_action.slice(start, end)
groupmate_seq_act.append(_act)
all_seq_obs = self_seq_obs + groupmate_seq_obs
values, _value_mem = self.critic.critic_pass(
all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length
)
for signal_name, _val in values.items():
all_values[signal_name].append(_val)
groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
baselines, _baseline_mem = self.critic.baseline(
self_seq_obs[0],
groupmate_obs_and_actions,
_baseline_mem,
sequence_length=self.policy.sequence_length,
)
for signal_name, _val in baselines.items():
all_baseline[signal_name].append(_val)
# Create one tensor per reward signal
all_value_tensors = {
signal_name: torch.cat(value_list, dim=0)
for signal_name, value_list in all_values.items()
}
all_baseline_tensors = {
signal_name: torch.cat(baseline_list, dim=0)
for signal_name, baseline_list in all_baseline.items()
}
next_value_mem = _value_mem
next_baseline_mem = _baseline_mem
return (
all_value_tensors,
all_baseline_tensors,
all_next_value_mem,
all_next_baseline_mem,
next_value_mem,
next_baseline_mem,
)
def get_trajectory_value_estimates(
self,
batch: AgentBuffer,
next_obs: List[np.ndarray],
done: bool,
agent_id: str = "",
) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
"""
Override base class method. Unused in the trainer, but needed to make sure the class hierarchy is maintained.
Assume that there are no group obs.
"""
(
value_estimates,
_,
next_value_estimates,
all_next_value_mem,
_,
) = self.get_trajectory_and_baseline_value_estimates(
batch, next_obs, [], done, agent_id
)
return value_estimates, next_value_estimates, all_next_value_mem
def get_trajectory_and_baseline_value_estimates(
self,
batch: AgentBuffer,
next_obs: List[np.ndarray],
next_groupmate_obs: List[List[np.ndarray]],
done: bool,
agent_id: str = "",
) -> Tuple[
Dict[str, np.ndarray],
Dict[str, np.ndarray],
Dict[str, float],
Optional[AgentBufferField],
Optional[AgentBufferField],
]:
"""
Get value estimates, baseline estimates, and memories for a trajectory, in batch form.
:param batch: An AgentBuffer that consists of a trajectory.
:param next_obs: the next observation (after the trajectory). Used for bootstrapping
if this is not a terminal trajectory.
:param next_groupmate_obs: the next observations from other members of the group.
:param done: Set true if this is a terminal trajectory.
:param agent_id: Agent ID of the agent that this trajectory belongs to.
:returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
the baseline estimates as a Dict, the final value estimate as a Dict of [name, float], and
optionally (if using memories) an AgentBufferField of initial critic and baseline memories to be used
during update.
"""
n_obs = len(self.policy.behavior_spec.observation_specs)
current_obs = ObsUtil.from_buffer(batch, n_obs)
groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
groupmate_obs = [
[ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
for _groupmate_obs in groupmate_obs
]
groupmate_actions = AgentAction.group_from_buffer(batch)
next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
next_obs = [obs.unsqueeze(0) for obs in next_obs]
next_groupmate_obs = [
ModelUtils.list_to_tensor_list(_list_obs)
for _list_obs in next_groupmate_obs
]
# Expand dimensions of next critic obs
next_groupmate_obs = [
[_obs.unsqueeze(0) for _obs in _list_obs]
for _list_obs in next_groupmate_obs
]
if agent_id in self.value_memory_dict:
# The agent_id should always be in both since they are added together
_init_value_mem = self.value_memory_dict[agent_id]
_init_baseline_mem = self.baseline_memory_dict[agent_id]
else:
_init_value_mem = (
torch.zeros((1, 1, self.critic.memory_size))
if self.policy.use_recurrent
else None
)
_init_baseline_mem = (
torch.zeros((1, 1, self.critic.memory_size))
if self.policy.use_recurrent
else None
)
all_obs = (
[current_obs] + groupmate_obs
if groupmate_obs is not None
else [current_obs]
)
all_next_value_mem: Optional[AgentBufferField] = None
all_next_baseline_mem: Optional[AgentBufferField] = None
with torch.no_grad():
if self.policy.use_recurrent:
(
value_estimates,
baseline_estimates,
all_next_value_mem,
all_next_baseline_mem,
next_value_mem,
next_baseline_mem,
) = self._evaluate_by_sequence_team(
current_obs,
groupmate_obs,
groupmate_actions,
_init_value_mem,
_init_baseline_mem,
)
else:
value_estimates, next_value_mem = self.critic.critic_pass(
all_obs, _init_value_mem, sequence_length=batch.num_experiences
)
groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
baseline_estimates, next_baseline_mem = self.critic.baseline(
current_obs,
groupmate_obs_and_actions,
_init_baseline_mem,
sequence_length=batch.num_experiences,
)
# Store the memory for the next trajectory
self.value_memory_dict[agent_id] = next_value_mem
self.baseline_memory_dict[agent_id] = next_baseline_mem
all_next_obs = (
[next_obs] + next_groupmate_obs
if next_groupmate_obs is not None
else [next_obs]
)
next_value_estimates, _ = self.critic.critic_pass(
all_next_obs, next_value_mem, sequence_length=1
)
for name, estimate in baseline_estimates.items():
baseline_estimates[name] = ModelUtils.to_numpy(estimate)
for name, estimate in value_estimates.items():
value_estimates[name] = ModelUtils.to_numpy(estimate)
# The baseline and value estimate should not share the same done flag
for name, estimate in next_value_estimates.items():
next_value_estimates[name] = ModelUtils.to_numpy(estimate)
if done:
for k in next_value_estimates:
if not self.reward_signals[k].ignore_done:
next_value_estimates[k][-1] = 0.0
return (
value_estimates,
baseline_estimates,
next_value_estimates,
all_next_value_mem,
all_next_baseline_mem,
)
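For reference, update() above combines a clipped policy surrogate with clipped regressions for both the centralized value and the baseline, weighted as policy_loss + 0.5 * (value_loss + 0.5 * baseline_loss), minus an entropy bonus. The sketch below is a minimal, self-contained illustration of that weighting; the clipped value loss is an assumption about what ModelUtils.trust_region_value_loss computes for each reward stream, and the function names are illustrative rather than part of the ML-Agents API.

import torch

def clipped_value_loss(values, old_values, returns, epsilon, masks):
    # PPO-style clipped regression: penalize the worse of the clipped and
    # unclipped squared errors, averaged over unmasked (non-padded) steps.
    clipped = old_values + torch.clamp(values - old_values, -epsilon, epsilon)
    loss = torch.max((returns - values) ** 2, (returns - clipped) ** 2)
    masks = masks.float()
    return (loss * masks).sum() / torch.clamp(masks.sum(), min=1.0)

def poca_total_loss(policy_loss, value_loss, baseline_loss, entropy, masks, beta):
    # Mirrors the weighting used in TorchPOCAOptimizer.update above.
    masks = masks.float()
    masked_entropy = (entropy * masks).sum() / torch.clamp(masks.sum(), min=1.0)
    return policy_loss + 0.5 * (value_loss + 0.5 * baseline_loss) - beta * masked_entropy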

310
ml-agents/mlagents/trainers/poca/trainer.py


# # Unity ML-Agents Toolkit
# ## ML-Agents Learning (POCA)
# Contains an implementation of MA-POCA.
from collections import defaultdict
from typing import cast, Dict
import numpy as np
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy import Policy
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, POCASettings
logger = get_logger(__name__)
class POCATrainer(RLTrainer):
"""The POCATrainer is an implementation of the MA-POCA algorithm."""
def __init__(
self,
behavior_name: str,
reward_buff_cap: int,
trainer_settings: TrainerSettings,
training: bool,
load: bool,
seed: int,
artifact_path: str,
):
"""
Responsible for collecting experiences and training POCA model.
:param behavior_name: The name of the behavior associated with trainer config
:param reward_buff_cap: Max reward history to track in the reward buffer
:param trainer_settings: The parameters for the trainer.
:param training: Whether the trainer is set for training.
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super().__init__(
behavior_name,
trainer_settings,
training,
load,
artifact_path,
reward_buff_cap,
)
self.hyperparameters: POCASettings = cast(
POCASettings, self.trainer_settings.hyperparameters
)
self.seed = seed
self.policy: TorchPolicy = None # type: ignore
self.collected_group_rewards: Dict[str, int] = defaultdict(lambda: 0)
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
Takes a trajectory and processes it, putting it into the update buffer.
Processing involves calculating value and advantage targets for model updating step.
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
super()._process_trajectory(trajectory)
agent_id = trajectory.agent_id # All the agents should have the same ID
agent_buffer_trajectory = trajectory.to_agentbuffer()
# Update the normalization
if self.is_training:
self.policy.update_normalization(agent_buffer_trajectory)
# Get all value estimates
(
value_estimates,
baseline_estimates,
value_next,
value_memories,
baseline_memories,
) = self.optimizer.get_trajectory_and_baseline_value_estimates(
agent_buffer_trajectory,
trajectory.next_obs,
trajectory.next_group_obs,
trajectory.all_group_dones_reached
and trajectory.done_reached
and not trajectory.interrupted,
)
if value_memories is not None and baseline_memories is not None:
agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)
agent_buffer_trajectory[BufferKey.BASELINE_MEMORY].set(baseline_memories)
for name, v in value_estimates.items():
agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend(
v
)
agent_buffer_trajectory[
RewardSignalUtil.baseline_estimates_key(name)
].extend(baseline_estimates[name])
self._stats_reporter.add_stat(
f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Baseline Estimate",
np.mean(baseline_estimates[name]),
)
self._stats_reporter.add_stat(
f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate",
np.mean(value_estimates[name]),
)
self.collected_rewards["environment"][agent_id] += np.sum(
agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS]
)
self.collected_group_rewards[agent_id] += np.sum(
agent_buffer_trajectory[BufferKey.GROUP_REWARD]
)
for name, reward_signal in self.optimizer.reward_signals.items():
evaluate_result = (
reward_signal.evaluate(agent_buffer_trajectory) * reward_signal.strength
)
agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend(
evaluate_result
)
# Report the reward signals
self.collected_rewards[name][agent_id] += np.sum(evaluate_result)
# Compute lambda returns and advantage
tmp_advantages = []
for name in self.optimizer.reward_signals:
local_rewards = np.array(
agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].get_batch(),
dtype=np.float32,
)
baseline_estimate = agent_buffer_trajectory[
RewardSignalUtil.baseline_estimates_key(name)
].get_batch()
v_estimates = agent_buffer_trajectory[
RewardSignalUtil.value_estimates_key(name)
].get_batch()
lambd_returns = lambda_return(
r=local_rewards,
value_estimates=v_estimates,
gamma=self.optimizer.reward_signals[name].gamma,
lambd=self.hyperparameters.lambd,
value_next=value_next[name],
)
local_advantage = np.array(lambd_returns) - np.array(baseline_estimate)
agent_buffer_trajectory[RewardSignalUtil.returns_key(name)].set(
lambd_returns
)
agent_buffer_trajectory[RewardSignalUtil.advantage_key(name)].set(
local_advantage
)
tmp_advantages.append(local_advantage)
# Get global advantages
global_advantages = list(
np.mean(np.array(tmp_advantages, dtype=np.float32), axis=0)
)
agent_buffer_trajectory[BufferKey.ADVANTAGES].set(global_advantages)
# Append to update buffer
agent_buffer_trajectory.resequence_and_append(
self.update_buffer, training_length=self.policy.sequence_length
)
# If this was a terminal trajectory, append stats and reset reward collection
if trajectory.done_reached:
self._update_end_episode_stats(agent_id, self.optimizer)
# Remove dead agents from group reward recording
if not trajectory.all_group_dones_reached:
self.collected_group_rewards.pop(agent_id)
# If the whole team is done, average the remaining group rewards.
if trajectory.all_group_dones_reached and trajectory.done_reached:
self.stats_reporter.add_stat(
"Environment/Group Cumulative Reward",
self.collected_group_rewards.get(agent_id, 0),
aggregation=StatsAggregationMethod.HISTOGRAM,
)
self.collected_group_rewards.pop(agent_id)
def _is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not update_model() can be run
"""
size_of_buffer = self.update_buffer.num_experiences
return size_of_buffer > self.hyperparameters.buffer_size
def _update_policy(self):
"""
Uses demonstration_buffer to update the policy.
The reward signal generators must be updated in this method at their own pace.
"""
buffer_length = self.update_buffer.num_experiences
self.cumulative_returns_since_policy_update.clear()
# Make sure batch_size is a multiple of sequence length. During training, we
# will need to reshape the data into a batch_size x sequence_length tensor.
batch_size = (
self.hyperparameters.batch_size
- self.hyperparameters.batch_size % self.policy.sequence_length
)
# Make sure there is at least one sequence
batch_size = max(batch_size, self.policy.sequence_length)
n_sequences = max(
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
advantages = np.array(
self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32
)
self.update_buffer[BufferKey.ADVANTAGES].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
)
num_epoch = self.hyperparameters.num_epoch
batch_update_stats = defaultdict(list)
for _ in range(num_epoch):
self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer
max_num_batch = buffer_length // batch_size
for i in range(0, max_num_batch * batch_size, batch_size):
update_stats = self.optimizer.update(
buffer.make_mini_batch(i, i + batch_size), n_sequences
)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
for stat, stat_list in batch_update_stats.items():
self._stats_reporter.add_stat(stat, np.mean(stat_list))
if self.optimizer.bc_module:
update_stats = self.optimizer.bc_module.update()
for stat, val in update_stats.items():
self._stats_reporter.add_stat(stat, val)
self._clear_update_buffer()
return True
def create_torch_policy(
self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
) -> TorchPolicy:
"""
Creates a policy with a PyTorch backend and POCA hyperparameters
:param parsed_behavior_id:
:param behavior_spec: specifications for policy construction
:return policy
"""
policy = TorchPolicy(
self.seed,
behavior_spec,
self.trainer_settings,
condition_sigma_on_obs=False, # Faster training for POCA
separate_critic=True, # Match network architecture with TF
)
return policy
def create_poca_optimizer(self) -> TorchPOCAOptimizer:
return TorchPOCAOptimizer(self.policy, self.trainer_settings)
def add_policy(
self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
) -> None:
"""
Adds policy to trainer.
:param parsed_behavior_id: Behavior identifiers that the policy should belong to.
:param policy: Policy to associate with name_behavior_id.
"""
if not isinstance(policy, TorchPolicy):
raise RuntimeError(f"policy {policy} must be an instance of TorchPolicy.")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
self.optimizer = self.create_poca_optimizer()
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
self.model_saver.register(self.policy)
self.model_saver.register(self.optimizer)
self.model_saver.initialize_or_load()
# Needed to resume loads properly
self.step = policy.get_current_step()
def get_policy(self, name_behavior_id: str) -> Policy:
"""
Gets policy from trainer associated with name_behavior_id
:param name_behavior_id: full identifier of policy
"""
return self.policy
def lambda_return(r, value_estimates, gamma=0.99, lambd=0.8, value_next=0.0):
returns = np.zeros_like(r)
returns[-1] = r[-1] + gamma * value_next
for t in reversed(range(0, r.size - 1)):
returns[t] = (
gamma * lambd * returns[t + 1]
+ r[t]
+ (1 - lambd) * gamma * value_estimates[t + 1]
)
return returns
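In _process_trajectory above, the advantage for each reward stream is this lambda return minus the baseline estimate, and BufferKey.ADVANTAGES stores the mean of those per-stream advantages. A small worked example of the recursion (a sketch, assuming the mlagents package from this branch is importable; the numbers are purely illustrative):

import numpy as np
from mlagents.trainers.poca.trainer import lambda_return

rewards = np.array([0.0, 0.0, 1.0], dtype=np.float32)
values = np.array([0.5, 0.6, 0.7], dtype=np.float32)  # V(s_t) estimates for the trajectory
returns = lambda_return(rewards, values, gamma=0.99, lambd=0.8, value_next=0.0)
# Working backwards through the recursion:
#   returns[2] = 1.0 + 0.99 * 0.0                                  = 1.0
#   returns[1] = 0.99 * 0.8 * 1.0    + 0.0 + 0.2 * 0.99 * 0.7      ≈ 0.9306
#   returns[0] = 0.99 * 0.8 * 0.9306 + 0.0 + 0.2 * 0.99 * 0.6      ≈ 0.8558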