|
|
|
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
from typing import Dict |
|
|
|
|
|
|
|
|
|
|
raise NotImplementedError( |
|
|
|
"The reward provider's update method has not been implemented " |
|
|
|
) |
|
|
|
|
|
|
|
def get_modules(self) -> Dict[str, torch.nn.Module]: |
|
|
|
""" |
|
|
|
Returns a dictionary of string identifiers to the torch.nn.Modules used by |
|
|
|
the reward providers. This method is used for loading and saving the weights |
|
|
|
of the reward providers. |
|
|
|
""" |
|
|
|
return {} |