from abc import abstractmethod
from typing import Dict, List, Optional

import numpy as np

from mlagents_envs.base_env import DecisionSteps
from mlagents_envs.exception import UnityException
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings, NetworkSettings


class UnityPolicyException(UnityException):
    """
    Related to errors with the Trainer.
    """

    pass


class Policy:
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
    ):
        self.behavior_spec = behavior_spec
        self.action_spec = behavior_spec.action_spec
        self.trainer_settings = trainer_settings
        self.network_settings: NetworkSettings = trainer_settings.network_settings
        self.seed = seed
        self.act_size = (
            list(self.behavior_spec.action_spec.discrete_branches)
            if self.behavior_spec.action_spec.is_discrete()
            else [self.behavior_spec.action_spec.continuous_size]
        )
        self.vec_obs_size = sum(
            shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
        )
        self.vis_obs_size = sum(
            1 for shape in behavior_spec.observation_shapes if len(shape) == 3
        )
        self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
        self.previous_action_dict: Dict[str, Dict[str, np.ndarray]] = {}
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.normalize = trainer_settings.network_settings.normalize
        self.use_recurrent = self.network_settings.memory is not None
        self.h_size = self.network_settings.hidden_units
        num_layers = self.network_settings.num_layers
        if num_layers < 1:
            num_layers = 1
        self.num_layers = num_layers

        self.vis_encode_type = self.network_settings.vis_encode_type
        self.tanh_squash = tanh_squash
        self.reparameterize = reparameterize
        self.condition_sigma_on_obs = condition_sigma_on_obs

        self.m_size = 0
        self.sequence_length = 1
        if self.network_settings.memory is not None:
            self.m_size = self.network_settings.memory.memory_size
            self.sequence_length = self.network_settings.memory.sequence_length

        # Non-exposed parameters; these aren't exposed because they don't have a
        # good explanation and usually shouldn't be touched.
        self.log_std_min = -20
        self.log_std_max = 2

    def make_empty_memory(self, num_agents):
        """
        Creates empty memory for use with RNNs.
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.m_size), dtype=np.float32)

    def save_memories(
        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
    ) -> None:
        """
        Stores one row of the memory matrix per agent id so it can be fed back
        to the recurrent network on the next decision step.
        """
        if memory_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
        """
        Builds a memory matrix for the given agent ids, using zeros for any
        agent that has no stored memory.
        """
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                memory_matrix[index, :] = self.memory_dict[agent_id]
        return memory_matrix

    def remove_memories(self, agent_ids):
        """
        Drops the stored memories for the given agent ids, e.g. at episode end.
        """
        for agent_id in agent_ids:
            if agent_id in self.memory_dict:
                self.memory_dict.pop(agent_id)

    def make_empty_previous_action(self, num_agents: int) -> Dict[str, np.ndarray]:
        """
        Creates empty previous action for use with RNNs and discrete control.
        :param num_agents: Number of agents.
        :return: Dict of action type to np.ndarray.
        """
        act_dict: Dict[str, np.ndarray] = {}
        action_buffer = self.behavior_spec.action_spec.empty_action(num_agents)
        if action_buffer.continuous is not None:
            act_dict["continuous_action"] = action_buffer.continuous
        if action_buffer.discrete is not None:
            act_dict["discrete_action"] = action_buffer.discrete
        return act_dict

    def save_previous_action(
        self, agent_ids: List[str], action_dict: Dict[str, np.ndarray]
    ) -> None:
        """
        Stores the per-agent slice of each action array so the previous action
        can be retrieved on the next decision step.
        """
        if action_dict is None:
            return
        for index, agent_id in enumerate(agent_ids):
            agent_action_dict: Dict[str, np.ndarray] = {}
            for act_type in action_dict:
                agent_action_dict[act_type] = action_dict[act_type][index, :]
            self.previous_action_dict[agent_id] = agent_action_dict

    def retrieve_previous_action(self, agent_ids: List[str]) -> Dict[str, np.ndarray]:
        """
        Builds a dict of previous actions for the given agent ids, using empty
        actions for any agent that has none stored.
        """
        action_dict = self.make_empty_previous_action(len(agent_ids))
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.previous_action_dict:
                for act_type in action_dict:
                    action_dict[act_type][index, :] = self.previous_action_dict[
                        agent_id
                    ][act_type]
        return action_dict

    def remove_previous_action(self, agent_ids):
        """
        Drops the stored previous actions for the given agent ids.
        """
        for agent_id in agent_ids:
            if agent_id in self.previous_action_dict:
                self.previous_action_dict.pop(agent_id)

    def get_action(
        self, decision_requests: DecisionSteps, worker_id: int = 0
    ) -> ActionInfo:
        raise NotImplementedError

    @abstractmethod
    def update_normalization(self, vector_obs: np.ndarray) -> None:
        pass

    @abstractmethod
    def increment_step(self, n_steps):
        pass

    @abstractmethod
    def get_current_step(self):
        pass

    @abstractmethod
    def load_weights(self, values: List[np.ndarray]) -> None:
        pass

    @abstractmethod
    def get_weights(self) -> List[np.ndarray]:
        return []

    @abstractmethod
    def init_load_weights(self) -> None:
        pass
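

# ---------------------------------------------------------------------------
# Minimal usage sketch for the memory bookkeeping helpers above. This is an
# illustration only, not part of the trainer code path: it bypasses __init__
# (which needs a full BehaviorSpec and TrainerSettings) and sets only the
# attributes the memory helpers rely on. The agent ids and memory size are
# made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    policy = Policy.__new__(Policy)  # skip __init__ for this sketch
    policy.m_size = 4
    policy.memory_dict = {}

    agent_ids = ["agent-0", "agent-1"]

    # Before inference: unseen agents get zero-filled memories.
    memories = policy.retrieve_memories(agent_ids)
    assert memories.shape == (2, 4) and memories.sum() == 0.0

    # After inference: persist the new memories so they are fed back next step.
    memories[0, :] = 1.0
    policy.save_memories(agent_ids, memories)
    assert np.array_equal(policy.retrieve_memories(["agent-0"])[0], memories[0])

    # At episode end: drop the per-agent state again.
    policy.remove_memories(agent_ids)
    assert policy.retrieve_memories(agent_ids).sum() == 0.0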