from typing import Dict, Optional, Tuple, List

import torch
import numpy as np

from mlagents_envs.base_env import DecisionSteps

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.components.bc.module import BCModule
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.utils import ModelUtils

# Assumed import: the factory used by create_reward_signals below to build a
# reward provider from its settings.
from mlagents.trainers.torch.components.reward_providers import create_reward_provider


class TorchOptimizer(Optimizer):
    def create_reward_signals(self, reward_signal_configs) -> None:
        # Name each reward provider by its string value so duplicates can be told
        # apart later.
        # NOTE: this loop and the factory call are reconstructed around the
        # surviving argument list; only the arguments themselves come from the
        # original file.
        for reward_signal, settings in reward_signal_configs.items():
            self.reward_signals[reward_signal.value] = create_reward_provider(
                reward_signal, self.policy.behavior_spec, settings
            )

    def get_value_estimates(
        self, decision_requests: DecisionSteps, idx: int, done: bool
    ) -> Dict[str, float]:
        """
        Generates value estimates for bootstrapping.
        :param decision_requests: DecisionSteps for the agents currently requesting a decision.
        :param idx: Index of the agent within the DecisionSteps.
        :param done: Whether or not this is the last element of the episode,
        in which case the value estimate will be 0.
        :return: The value estimate dictionary with key being the name of the reward signal
        and the value the corresponding value estimate.
        """
        vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)

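        # Evaluate the critic on this single agent's observations; expand_dims adds
        # a batch dimension of 1.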
        value_estimates = self.policy.actor_critic.critic_pass(
            np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
            np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
        )

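        # critic_pass returns one value per reward signal; convert to plain floats.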
        value_estimates = {k: float(v) for k, v in value_estimates.items()}

        # If we're done, reassign all of the value estimates that need terminal states.
        if done:
            for k in value_estimates:
                if not self.reward_signals[k].ignore_done:
                    value_estimates[k] = 0.0

        return value_estimates

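    # Hypothetical call site, for illustration only:
    #     optimizer.get_value_estimates(decision_steps, idx=0, done=False)
    # returns one bootstrap value per enabled reward signal, e.g. {"extrinsic": 1.23}.
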
    def get_trajectory_value_estimates(
        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
        """
        Evaluates the critic over an entire trajectory and over the observation
        that follows it, returning per-step value estimates and a bootstrap value
        for every reward signal.
        """
        # NOTE: the observation preparation below is reconstructed; it assumes the
        # buffer stores observations under the "vector_obs" / "visual_obs<N>" keys.
        vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        if self.policy.use_vis_obs:
            visual_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_encoders
            ):
                visual_obs.append(
                    ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                )
        else:
            visual_obs = []

        # Final observation of the trajectory as tensors with a batch dimension of 1.
        next_obs = [ModelUtils.list_to_tensor(obs).unsqueeze(0) for obs in next_obs]

        # Zeroed recurrent memories for the trajectory pass and the bootstrap pass.
        memory = torch.zeros([1, 1, self.policy.m_size])
        next_memory = torch.zeros([1, 1, self.policy.m_size])

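        # Run the critic over the whole trajectory as a single sequence to get
        # per-step value estimates for every reward signal.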
        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
            vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
        )

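        # Evaluate the critic once more on the observation that follows the
        # trajectory to obtain the bootstrap value for its final step.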
        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
            next_obs, next_obs, next_memory, sequence_length=1
        )

        # Assumption: per-step estimates are handed back as numpy arrays and the
        # bootstrap estimates as plain floats, matching the declared return type.
        for name, estimate in value_estimates.items():
            value_estimates[name] = estimate.detach().cpu().numpy()
            next_value_estimate[name] = float(next_value_estimate[name])

        # Zero the bootstrap value for reward signals that respect terminal states.
        if done:
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0

        return value_estimates, next_value_estimate
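

# A hypothetical call site in a trainer, for illustration only (variable names are
# not taken from this file):
#
#     value_estimates, value_next = optimizer.get_trajectory_value_estimates(
#         agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached
#     )
#
# `value_estimates` feeds the per-step return/advantage computation, while
# `value_next` supplies the bootstrap value when the trajectory was cut off
# before the episode ended.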