import numpy as np
from mlagents.torch_utils import torch
from abc import ABC, abstractmethod
from typing import Dict

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.settings import RewardSignalSettings
from mlagents_envs.base_env import BehaviorSpec


class BaseRewardProvider(ABC):
    def __init__(self, specs: BehaviorSpec, settings: RewardSignalSettings) -> None:
        self._policy_specs = specs
        self._gamma = settings.gamma
        self._strength = settings.strength
        self._ignore_done = False

    @property
    def gamma(self) -> float:
        """
        The discount factor for the reward signal
        """
        return self._gamma

    @property
    def strength(self) -> float:
        """
        The strength multiplier of the reward provider
        """
        return self._strength

    @property
    def name(self) -> str:
        """
        The name of the reward provider. Is used for reporting and identification
        """
        class_name = self.__class__.__name__
        return class_name.replace("RewardProvider", "")

    @property
    def ignore_done(self) -> bool:
        """
        If true, when the agent is done, the rewards of the next episode must be
        used to calculate the return of the current episode.
        Is used to mitigate the positive bias in rewards with no natural end.
        """
        return self._ignore_done

    @abstractmethod
    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        """
        Evaluates the reward for the data present in the Dict mini_batch. Use
        this when evaluating a reward function drawn straight from a Buffer.
        :param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
            when drawing from the update buffer.
        :return: a np.ndarray of rewards generated by the reward provider
        """
        raise NotImplementedError(
            "The reward provider's evaluate method has not been implemented"
        )

    @abstractmethod
    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        """
        Update the reward for the data present in the Dict mini_batch. Use this
        when updating a reward function drawn straight from a Buffer.
        :param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
            when drawing from the update buffer.
        :return: A dictionary from string to stats values
        """
        raise NotImplementedError(
            "The reward provider's update method has not been implemented"
        )

    def get_modules(self) -> Dict[str, torch.nn.Module]:
        """
        Returns a dictionary of string identifiers to the torch.nn.Modules used by
        the reward providers. This method is used for loading and saving the weights
        of the reward providers.
        """
        return {}
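

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the ML-Agents API): a minimal sketch of a
# concrete provider, assuming it lives in the same module as BaseRewardProvider
# above so the class and imports are already in scope. The buffer key
# "environment_rewards" is a hypothetical name used only for illustration;
# a real provider reads whatever keys its trainer actually writes to the buffer.
# ---------------------------------------------------------------------------
class PassthroughRewardProvider(BaseRewardProvider):
    """
    Hypothetical provider that simply forwards a reward already stored in the
    buffer, showing the minimal surface a subclass must implement.
    """

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        # Read the stored per-step rewards and return them as a float32 array.
        return np.array(mini_batch["environment_rewards"], dtype=np.float32)

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        # Nothing is trained here; report the mean reward as a stat for logging.
        rewards = np.array(mini_batch["environment_rewards"], dtype=np.float32)
        return {"Passthrough Reward Mean": np.asarray(np.mean(rewards))}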