|
|
|
|
|
|
from typing import Dict, cast, List, Tuple, Optional

import numpy as np
import math

from mlagents.torch_utils import torch
from mlagents.trainers.buffer import (
    AgentBuffer,
    BufferKey,
    RewardSignalUtil,
    AgentBufferField,
)
from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.timers import timed
from mlagents_envs.base_env import ObservationSpec, ActionSpec
|
|
|
|
|
|
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules
|
|
|
|
|
|
|
    def _evaluate_by_sequence_team(
        self,
        self_obs: List[torch.Tensor],
        obs: List[List[torch.Tensor]],
        actions: List[torch.Tensor],
        init_value_mem: torch.Tensor,
        init_baseline_mem: torch.Tensor,
    ) -> Tuple[
        Dict[str, torch.Tensor],
        Dict[str, torch.Tensor],
        AgentBufferField,
        AgentBufferField,
        torch.Tensor,
        torch.Tensor,
    ]:
|
|
|
""" |
|
|
|
Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the |
|
|
|
intermediate memories for the critic. |
|
|
|
|
|
|
leftover = num_experiences % self.policy.sequence_length |
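        # If the trajectory length is not a multiple of sequence_length, the first
        # (padded) sequence only contains `leftover` real experiences.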
|
|
|
|
|
|
|
        # Compute values for the potentially truncated initial sequence
        first_seq_len = leftover if leftover > 0 else self.policy.sequence_length
|
|
|
        self_seq_obs = []
        team_seq_obs = []
        team_seq_act = []
        seq_obs = []
        for _self_obs in self_obs:
            first_seq_obs = _self_obs[0:first_seq_len]
            seq_obs.append(first_seq_obs)
        self_seq_obs.append(seq_obs)

        for team_obs, team_action in zip(obs, actions):
            seq_obs = []
            for _obs in team_obs:
                first_seq_obs = _obs[0:first_seq_len]
                seq_obs.append(first_seq_obs)
            team_seq_obs.append(seq_obs)
            _act = team_action[0:first_seq_len]
            team_seq_act.append(_act)
|
|
|
|
|
|
        # For the first sequence, the initial memory should be the one at the
        # beginning of this trajectory.
        for _ in range(first_seq_len):
            all_next_value_mem.append(ModelUtils.to_numpy(init_value_mem.squeeze()))
            all_next_baseline_mem.append(
                ModelUtils.to_numpy(init_baseline_mem.squeeze())
            )
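        # First pass over the (possibly truncated) initial sequence: the value estimate uses
        # self + team observations, while the baseline additionally conditions on teammate actions.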
|
|
|
        all_seq_obs = self_seq_obs + team_seq_obs
        init_values, _value_mem = self.critic.critic_pass(
            all_seq_obs, init_value_mem, sequence_length=first_seq_len
        )
        all_values = {
            signal_name: [init_values[signal_name]]
            for signal_name in init_values.keys()
        }

        init_baseline, _baseline_mem = self.critic.baseline(
            self_seq_obs,
            team_seq_obs,
            team_seq_act,
            init_baseline_mem,
            sequence_length=first_seq_len,
        )
        all_baseline = {
            signal_name: [init_baseline[signal_name]]
            for signal_name in init_baseline.keys()
        }
|
|
|
|
|
|
|
        # Evaluate the remaining sequences, carrying over _value_mem and _baseline_mem
        # after each one.
        for seq_num in range(
            1, math.ceil((num_experiences) / (self.policy.sequence_length))
        ):
|
|
|
            for _ in range(self.policy.sequence_length):
                all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
                all_next_baseline_mem.append(
                    ModelUtils.to_numpy(_baseline_mem.squeeze())
                )

            start = seq_num * self.policy.sequence_length - (
                self.policy.sequence_length - leftover
            )
            end = (seq_num + 1) * self.policy.sequence_length - (
                self.policy.sequence_length - leftover
            )
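            # Slice out the current full-length sequence for this agent and its teammates.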
|
|
|
|
|
|
|
            self_seq_obs = []
            team_seq_obs = []
            team_seq_act = []
            seq_obs = []
            for _self_obs in self_obs:
                seq_obs.append(_self_obs[start:end])
            self_seq_obs.append(seq_obs)
|
|
|
|
|
|
|
for team_obs, team_action in zip(obs, actions): |
|
|
|
seq_obs = [] |
|
|
|
                for _obs in team_obs:
                    seq_obs.append(_obs[start:end])
|
|
|
team_seq_obs.append(seq_obs) |
|
|
|
_act = team_action[start:end] |
|
|
|
team_seq_act.append(_act) |
|
|
|
|
|
|
|
all_seq_obs = self_seq_obs + team_seq_obs |
|
|
|
values, _value_mem = self.critic.critic_pass( |
|
|
|
all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length |
|
|
|
) |
|
|
|
            for signal_name, _val in values.items():
                all_values[signal_name].append(_val)

            baselines, _baseline_mem = self.critic.baseline(
                self_seq_obs,
                team_seq_obs,
                team_seq_act,
                _baseline_mem,
                sequence_length=self.policy.sequence_length,
            )
            for signal_name, _val in baselines.items():
                all_baseline[signal_name].append(_val)
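        # Once all sequences have been evaluated, concatenate the per-sequence
        # estimates into one tensor per reward signal.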
|
|
|
        all_value_tensors = {
            signal_name: torch.cat(value_list, dim=0)
            for signal_name, value_list in all_values.items()
        }
|
|
|
all_baseline_tensors = { |
|
|
|
signal_name: torch.cat(baseline_list, dim=0) |
|
|
|
for signal_name, baseline_list in all_baseline.items() |
|
|
|
} |
|
|
|
next_value_mem = _value_mem |
|
|
|
next_baseline_mem = _baseline_mem |
|
|
|
return ( |
|
|
|
all_value_tensors, |
|
|
|
all_baseline_tensors, |
|
|
|
            all_next_value_mem,
            all_next_baseline_mem,
            next_value_mem,
            next_baseline_mem,
|
|
|
) |
|
|
|
|
|
|
|
    def get_trajectory_and_baseline_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        next_group_obs: Optional[List[List[np.ndarray]]],
        done: bool,
        agent_id: str = "",
    ) -> Tuple[
        Dict[str, np.ndarray],
        Dict[str, np.ndarray],
        Dict[str, float],
        Optional[AgentBufferField],
        Optional[AgentBufferField],
    ]:
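        """
        Get value and baseline estimates for a trajectory, in batch form, along with the
        memories needed to update a recurrent critic.
        :param batch: An AgentBuffer containing the trajectory.
        :param next_obs: The observation that follows the trajectory, used for bootstrapping.
        :param next_group_obs: The next observations of the other group members, if any.
        :param done: Whether the trajectory ended in a terminal state.
        :param agent_id: The ID of the agent this trajectory belongs to.
        :return: Value estimates, baseline estimates, bootstrapped next-step value estimates,
            and (when using a recurrent critic) the value and baseline memories collected
            along the trajectory.
        """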
|
|
|
|
|
|
|
n_obs = len(self.policy.behavior_spec.observation_specs) |
|
|
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
all_obs = [current_obs] + team_obs if team_obs is not None else [current_obs] |
|
|
|
all_next_value_mem: Optional[AgentBufferField] = None |
|
|
|
all_next_baseline_mem: Optional[AgentBufferField] = None |
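        # With a recurrent critic, evaluate sequence-by-sequence so the intermediate
        # memories can be stored for the update; otherwise evaluate the whole trajectory at once.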
|
|
|
        if self.policy.use_recurrent:
            (
                value_estimates,
                baseline_estimates,
                all_next_value_mem,
                all_next_baseline_mem,
                next_value_mem,
                next_baseline_mem,
            ) = self._evaluate_by_sequence_team(
                current_obs, team_obs, team_actions, _init_value_mem, _init_baseline_mem
            )
|
|
|
        else:
            value_estimates, next_value_mem = self.critic.critic_pass(
                all_obs, _init_value_mem, sequence_length=batch.num_experiences
            )
            baseline_estimates, next_baseline_mem = self.critic.baseline(
                current_obs,
                team_obs,
                team_actions,
                _init_baseline_mem,
                sequence_length=batch.num_experiences,
            )
|
|
|
# Store the memory for the next trajectory |
|
|
|
        self.value_memory_dict[agent_id] = next_value_mem
        self.baseline_memory_dict[agent_id] = next_baseline_mem
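        # Bootstrap the value estimate from the observation that follows the trajectory.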
|
|
|
|
|
|
|
        all_next_obs = (
            [next_obs] + next_group_obs if next_group_obs is not None else [next_obs]
        )

        next_value_estimates, _ = self.critic.critic_pass(
            all_next_obs, next_value_mem, sequence_length=1
        )

        # If the trajectory ended in a terminal state, zero out the bootstrap value for
        # reward signals that respect the done flag.
        if done:
            for k in next_value_estimates:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimates[k][-1] = 0.0
|
|
|
|
|
|
|
|
|
|
return ( |
|
|
|
value_estimates, |
|
|
|
baseline_estimates, |
|
|
|
next_value_estimates, |
|
|
|
all_next_value_mem, |
|
|
|
all_next_baseline_mem, |
|
|
|
) |