from typing import Dict, cast, List, Tuple, Optional

import numpy as np

from mlagents.torch_utils import torch

from mlagents_envs.base_env import ObservationSpec, ActionSpec
from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import TrainerSettings, PPOSettings, NetworkSettings
from mlagents.trainers.torch.networks import MultiAgentNetworkBody, Critic
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil, GroupObsUtil


class TorchCOMAOptimizer(TorchOptimizer):
    class COMAValueNetwork(torch.nn.Module, Critic):
        def __init__(
            self,
            stream_names: List[str],
            observation_specs: List[ObservationSpec],
            network_settings: NetworkSettings,
            action_spec: ActionSpec,
        ):
            torch.nn.Module.__init__(self)
            self.network_body = MultiAgentNetworkBody(
                observation_specs, network_settings, action_spec
            )
            if network_settings.memory is not None:
                encoding_size = network_settings.memory.memory_size // 2
            else:
                encoding_size = network_settings.hidden_units
            # A single value output per reward stream.
            self.value_heads = ValueHeads(stream_names, encoding_size, 1)

        @property
        def memory_size(self) -> int:
            return self.network_body.memory_size

        def update_normalization(self, buffer: AgentBuffer) -> None:
            self.network_body.update_normalization(buffer)

        def baseline(
            self,
            self_obs: List[List[torch.Tensor]],
            obs: List[List[torch.Tensor]],
            actions: List[AgentAction],
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
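            """
            Baseline value estimate for a single agent: the encoder receives this
            agent's observations without its action (``obs_only``) and the
            groupmates' observations together with their actions, so the estimate
            does not condition on the agent's own action.
            """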
|
|
|
            encoding, memories = self.network_body(
                obs_only=self_obs,
                obs=obs,
                actions=actions,
                memories=memories,
                sequence_length=sequence_length,
            )
            value_outputs, critic_mem_out = self.forward(
                encoding, memories, sequence_length
            )
            return value_outputs, critic_mem_out

        def critic_pass(
            self,
            obs: List[List[torch.Tensor]],
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
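            """
            Centralized value estimate: all observations (this agent's and the
            groupmates') are passed through ``obs_only`` with no actions, giving
            the joint state value rather than the per-agent baseline.
            """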
|
|
|
            encoding, memories = self.network_body(
                obs_only=obs,
                obs=None,
                actions=None,
                memories=memories,
                sequence_length=sequence_length,
            )
            value_outputs, critic_mem_out = self.forward(
                encoding, memories, sequence_length
            )
            return value_outputs, critic_mem_out

        def forward(
            self,
            encoding: torch.Tensor,
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
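            # Maps an already-computed encoding to one value per reward stream;
            # memories are returned unchanged so callers can thread them through.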
|
|
|
            output = self.value_heads(encoding)
            return output, memories

    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        """
        Takes a Policy and a Dict of trainer parameters and creates an Optimizer
        around the policy.
        """
        super().__init__(policy, trainer_settings)

        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

        self._critic = TorchCOMAOptimizer.COMAValueNetwork(
            reward_signal_names,
            policy.behavior_spec.observation_specs,
            network_settings=trainer_settings.network_settings,
            action_spec=policy.behavior_spec.action_spec,
        )

        params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
        self.hyperparameters: PPOSettings = cast(
            PPOSettings, trainer_settings.hyperparameters
        )

        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss
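
    # Sketch of the PPO clipped surrogate objective that ppo_policy_loss below is
    # expected to compute, with ratio r_t = exp(log_probs - old_log_probs) and
    # advantages A_t:
    #   L_policy = -mean(min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t))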
|
|
|
|
|
|
|
    def ppo_policy_loss(
        self,
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
|
|
|
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
        returns = {}
        old_values = {}
        old_baseline_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)]
            )
            old_baseline_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.baseline_estimates_key(name)]
            )
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)]
            )

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        group_obs = GroupObsUtil.from_buffer(batch, n_obs)
        group_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
            for _groupmate_obs in group_obs
        ]
|
|
|
        group_actions = AgentAction.group_from_buffer(batch)
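        # The groupmates' observations and actions gathered above feed the
        # centralized critic (critic_pass) and the counterfactual baseline below.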
|
|
|
|
|
|
|
        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
|
|
|
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        all_obs = [current_obs] + group_obs
        values, _ = self.critic.critic_pass(
            all_obs, memories=memories, sequence_length=self.policy.sequence_length
        )
        baselines, _ = self.critic.baseline(
            [current_obs],
            group_obs,
            group_actions,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        baseline_loss = self.coma_value_loss(
            baselines, old_baseline_values, returns, decay_eps, loss_masks
        )
        value_loss = self.coma_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        policy_loss = self.ppo_policy_loss(

        )
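        # Total loss: the PPO policy surrogate, plus a half-weighted sum of the
        # centralized value loss and the counterfactual baseline loss, minus a
        # decaying entropy bonus.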
|
|
|
        loss = (
            policy_loss
            + 0.5 * (value_loss + baseline_loss)
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )

|
|
|
|
|
|
|
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules

    def get_trajectory_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        next_group_obs: List[List[np.ndarray]],
        done: bool,
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
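        """
        Runs the centralized critic and the per-agent baseline over a whole
        trajectory, and also evaluates the critic on the trajectory's successor
        observations, returning (value_estimates, baseline_estimates,
        next_value_estimates) with one entry per reward signal.
        """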
|
|
|
|
|
|
|
        n_obs = len(self.policy.behavior_spec.observation_specs)

        current_obs = ObsUtil.from_buffer(batch, n_obs)
        team_obs = GroupObsUtil.from_buffer(batch, n_obs)

        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
        team_obs = [
            [ModelUtils.list_to_tensor(obs) for obs in _teammate_obs]
            for _teammate_obs in team_obs
        ]

        team_actions = AgentAction.group_from_buffer(batch)
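        # Teammates' actions are only needed for the baseline, which conditions on
        # what the groupmates did while leaving out this agent's own action.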
|
|
|
|
|
|
|
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
        next_obs = [obs.unsqueeze(0) for obs in next_obs]

|
|
|
|
        next_group_obs = [
            ModelUtils.list_to_tensor_list(_list_obs) for _list_obs in next_group_obs
        ]
        # Expand dimensions of next critic obs
        next_group_obs = [
            [_obs.unsqueeze(0) for _obs in _list_obs] for _list_obs in next_group_obs
        ]

|
|
|
|
        memory = torch.zeros([1, 1, self.policy.m_size])
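        # The trajectory is evaluated as a single sequence from a zero initial
        # memory; sequence_length is set to the number of experiences in the batch.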
|
|
|
        all_obs = [current_obs] + team_obs if team_obs is not None else [current_obs]
        value_estimates, mem = self.critic.critic_pass(
            all_obs, memory, sequence_length=batch.num_experiences
        )

        baseline_estimates, mem = self.critic.baseline(
            [current_obs],
            team_obs,
            team_actions,
            memory,
            sequence_length=batch.num_experiences,
        )

        all_next_obs = (
            [next_obs] + next_group_obs if next_group_obs is not None else [next_obs]
        )

        next_value_estimates, mem = self.critic.critic_pass(
            all_next_obs, mem, sequence_length=batch.num_experiences
        )
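
        # next_value_estimates bootstrap the returns for the final step of the
        # trajectory; they are zeroed below when the episode terminated, unless a
        # reward signal is configured to ignore the done flag.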
|
|
|
|
|
|
|
        for name, estimate in baseline_estimates.items():
            baseline_estimates[name] = ModelUtils.to_numpy(estimate)

|
|
|
        for name, estimate in value_estimates.items():
            value_estimates[name] = ModelUtils.to_numpy(estimate)

        # The baseline and V should not be on the same done flag.
|
|
|
        for name, estimate in next_value_estimates.items():
            next_value_estimates[name] = ModelUtils.to_numpy(estimate)

        if done:
            for k in next_value_estimates:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimates[k][-1] = 0.0

|
|
|
        return (value_estimates, baseline_estimates, next_value_estimates)