浏览代码
Fix flake8 import warnings (#2584)
Fix flake8 import warnings (#2584)
We have been ignoring unused imports and star imports via flake8. These are both bad practice and grow over time without automated checking. This commit attempts to fix all existing import errors and add back the corresponding flake8 checks./develop-gpu-test
GitHub
5 年前
当前提交
67d754c5
共有 70 个文件被更改,包括 668 次插入 和 774 次删除
-
1gym-unity/gym_unity/__init__.py
-
372gym-unity/gym_unity/envs/__init__.py
-
6gym-unity/gym_unity/tests/test_gym.py
-
5ml-agents-envs/mlagents/envs/__init__.py
-
2ml-agents-envs/mlagents/envs/base_unity_environment.py
-
3ml-agents-envs/mlagents/envs/communicator.py
-
22ml-agents-envs/mlagents/envs/communicator_objects/__init__.py
-
4ml-agents-envs/mlagents/envs/env_manager.py
-
20ml-agents-envs/mlagents/envs/environment.py
-
14ml-agents-envs/mlagents/envs/mock_communicator.py
-
4ml-agents-envs/mlagents/envs/policy.py
-
6ml-agents-envs/mlagents/envs/rpc_communicator.py
-
4ml-agents-envs/mlagents/envs/sampler_class.py
-
3ml-agents-envs/mlagents/envs/simple_env_manager.py
-
4ml-agents-envs/mlagents/envs/socket_communicator.py
-
7ml-agents-envs/mlagents/envs/subprocess_env_manager.py
-
27ml-agents-envs/mlagents/envs/tests/test_envs.py
-
4ml-agents-envs/mlagents/envs/tests/test_rpc_communicator.py
-
1ml-agents-envs/mlagents/envs/tests/test_sampler_class.py
-
3ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py
-
20ml-agents/mlagents/trainers/__init__.py
-
4ml-agents/mlagents/trainers/bc/__init__.py
-
4ml-agents/mlagents/trainers/bc/online_trainer.py
-
5ml-agents/mlagents/trainers/bc/trainer.py
-
3ml-agents/mlagents/trainers/buffer.py
-
1ml-agents/mlagents/trainers/components/bc/__init__.py
-
1ml-agents/mlagents/trainers/components/bc/model.py
-
110ml-agents/mlagents/trainers/components/reward_signals/__init__.py
-
1ml-agents/mlagents/trainers/components/reward_signals/curiosity/__init__.py
-
1ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
-
1ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py
-
1ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
-
1ml-agents/mlagents/trainers/components/reward_signals/gail/__init__.py
-
1ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
-
2ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py
-
6ml-agents/mlagents/trainers/demo_loader.py
-
4ml-agents/mlagents/trainers/learn.py
-
3ml-agents/mlagents/trainers/ppo/__init__.py
-
2ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
-
2ml-agents/mlagents/trainers/ppo/policy.py
-
7ml-agents/mlagents/trainers/ppo/trainer.py
-
12ml-agents/mlagents/trainers/rl_trainer.py
-
3ml-agents/mlagents/trainers/sac/__init__.py
-
8ml-agents/mlagents/trainers/sac/policy.py
-
11ml-agents/mlagents/trainers/sac/trainer.py
-
1ml-agents/mlagents/trainers/tests/mock_brain.py
-
2ml-agents/mlagents/trainers/tests/test_barracuda_converter.py
-
22ml-agents/mlagents/trainers/tests/test_bc.py
-
10ml-agents/mlagents/trainers/tests/test_bcmodule.py
-
3ml-agents/mlagents/trainers/tests/test_curriculum.py
-
2ml-agents/mlagents/trainers/tests/test_demo_loader.py
-
5ml-agents/mlagents/trainers/tests/test_learn.py
-
26ml-agents/mlagents/trainers/tests/test_meta_curriculum.py
-
6ml-agents/mlagents/trainers/tests/test_multigpu.py
-
5ml-agents/mlagents/trainers/tests/test_policy.py
-
35ml-agents/mlagents/trainers/tests/test_ppo.py
-
27ml-agents/mlagents/trainers/tests/test_reward_signals.py
-
45ml-agents/mlagents/trainers/tests/test_sac.py
-
4ml-agents/mlagents/trainers/tests/test_simple_rl.py
-
2ml-agents/mlagents/trainers/tests/test_trainer_controller.py
-
2ml-agents/mlagents/trainers/tests/test_trainer_metrics.py
-
8ml-agents/mlagents/trainers/tests/test_trainer_util.py
-
7ml-agents/mlagents/trainers/tf_policy.py
-
10ml-agents/mlagents/trainers/trainer.py
-
4ml-agents/mlagents/trainers/trainer_controller.py
-
4ml-agents/mlagents/trainers/trainer_util.py
-
4setup.cfg
-
1utils/validate_meta_files.py
-
371gym-unity/gym_unity/envs/unity_env.py
-
110ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py
|
|||
from gym.envs.registration import register |
|
|||
from gym_unity.envs.unity_env import UnityEnv, UnityGymException |
|||
import logging |
|||
import itertools |
|||
import gym |
|||
import numpy as np |
|||
from mlagents.envs.environment import UnityEnvironment |
|||
from gym import error, spaces |
|||
|
|||
|
|||
class UnityGymException(error.Error): |
|||
""" |
|||
Any error related to the gym wrapper of ml-agents. |
|||
""" |
|||
|
|||
pass |
|||
|
|||
|
|||
logging.basicConfig(level=logging.INFO) |
|||
logger = logging.getLogger("gym_unity") |
|||
|
|||
|
|||
class UnityEnv(gym.Env): |
|||
""" |
|||
Provides Gym wrapper for Unity Learning Environments. |
|||
Multi-agent environments use lists for object types, as done here: |
|||
https://github.com/openai/multiagent-particle-envs |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
environment_filename: str, |
|||
worker_id: int = 0, |
|||
use_visual: bool = False, |
|||
uint8_visual: bool = False, |
|||
multiagent: bool = False, |
|||
flatten_branched: bool = False, |
|||
no_graphics: bool = False, |
|||
allow_multiple_visual_obs: bool = False, |
|||
): |
|||
""" |
|||
Environment initialization |
|||
:param environment_filename: The UnityEnvironment path or file to be wrapped in the gym. |
|||
:param worker_id: Worker number for environment. |
|||
:param use_visual: Whether to use visual observation or vector observation. |
|||
:param uint8_visual: Return visual observations as uint8 (0-255) matrices instead of float (0.0-1.0). |
|||
:param multiagent: Whether to run in multi-agent mode (lists of obs, reward, done). |
|||
:param flatten_branched: If True, turn branched discrete action spaces into a Discrete space rather than |
|||
MultiDiscrete. |
|||
:param no_graphics: Whether to run the Unity simulator in no-graphics mode |
|||
:param allow_multiple_visual_obs: If True, return a list of visual observations instead of only one. |
|||
""" |
|||
self._env = UnityEnvironment( |
|||
environment_filename, worker_id, no_graphics=no_graphics |
|||
) |
|||
self.name = self._env.academy_name |
|||
self.visual_obs = None |
|||
self._current_state = None |
|||
self._n_agents = None |
|||
self._multiagent = multiagent |
|||
self._flattener = None |
|||
self.game_over = ( |
|||
False |
|||
) # Hidden flag used by Atari environments to determine if the game is over |
|||
self._allow_multiple_visual_obs = allow_multiple_visual_obs |
|||
|
|||
# Check brain configuration |
|||
if len(self._env.brains) != 1: |
|||
raise UnityGymException( |
|||
"There can only be one brain in a UnityEnvironment " |
|||
"if it is wrapped in a gym." |
|||
) |
|||
if len(self._env.external_brain_names) <= 0: |
|||
raise UnityGymException( |
|||
"There are not any external brain in the UnityEnvironment" |
|||
) |
|||
|
|||
self.brain_name = self._env.external_brain_names[0] |
|||
brain = self._env.brains[self.brain_name] |
|||
|
|||
if use_visual and brain.number_visual_observations == 0: |
|||
raise UnityGymException( |
|||
"`use_visual` was set to True, however there are no" |
|||
" visual observations as part of this environment." |
|||
) |
|||
self.use_visual = brain.number_visual_observations >= 1 and use_visual |
|||
|
|||
if not use_visual and uint8_visual: |
|||
logger.warning( |
|||
"`uint8_visual was set to true, but visual observations are not in use. " |
|||
"This setting will not have any effect." |
|||
) |
|||
else: |
|||
self.uint8_visual = uint8_visual |
|||
|
|||
if brain.number_visual_observations > 1 and not self._allow_multiple_visual_obs: |
|||
logger.warning( |
|||
"The environment contains more than one visual observation. " |
|||
"You must define allow_multiple_visual_obs=True to received them all. " |
|||
"Otherwise, please note that only the first will be provided in the observation." |
|||
) |
|||
|
|||
if brain.num_stacked_vector_observations != 1: |
|||
raise UnityGymException( |
|||
"There can only be one stacked vector observation in a UnityEnvironment " |
|||
"if it is wrapped in a gym." |
|||
) |
|||
|
|||
# Check for number of agents in scene. |
|||
initial_info = self._env.reset()[self.brain_name] |
|||
self._check_agents(len(initial_info.agents)) |
|||
|
|||
# Set observation and action spaces |
|||
if brain.vector_action_space_type == "discrete": |
|||
if len(brain.vector_action_space_size) == 1: |
|||
self._action_space = spaces.Discrete(brain.vector_action_space_size[0]) |
|||
else: |
|||
if flatten_branched: |
|||
self._flattener = ActionFlattener(brain.vector_action_space_size) |
|||
self._action_space = self._flattener.action_space |
|||
else: |
|||
self._action_space = spaces.MultiDiscrete( |
|||
brain.vector_action_space_size |
|||
) |
|||
|
|||
else: |
|||
if flatten_branched: |
|||
logger.warning( |
|||
"The environment has a non-discrete action space. It will " |
|||
"not be flattened." |
|||
) |
|||
high = np.array([1] * brain.vector_action_space_size[0]) |
|||
self._action_space = spaces.Box(-high, high, dtype=np.float32) |
|||
high = np.array([np.inf] * brain.vector_observation_space_size) |
|||
self.action_meanings = brain.vector_action_descriptions |
|||
if self.use_visual: |
|||
if brain.camera_resolutions[0]["blackAndWhite"]: |
|||
depth = 1 |
|||
else: |
|||
depth = 3 |
|||
self._observation_space = spaces.Box( |
|||
0, |
|||
1, |
|||
dtype=np.float32, |
|||
shape=( |
|||
brain.camera_resolutions[0]["height"], |
|||
brain.camera_resolutions[0]["width"], |
|||
depth, |
|||
), |
|||
) |
|||
else: |
|||
self._observation_space = spaces.Box(-high, high, dtype=np.float32) |
|||
|
|||
def reset(self): |
|||
"""Resets the state of the environment and returns an initial observation. |
|||
In the case of multi-agent environments, this is a list. |
|||
Returns: observation (object/list): the initial observation of the |
|||
space. |
|||
""" |
|||
info = self._env.reset()[self.brain_name] |
|||
n_agents = len(info.agents) |
|||
self._check_agents(n_agents) |
|||
self.game_over = False |
|||
|
|||
if not self._multiagent: |
|||
obs, reward, done, info = self._single_step(info) |
|||
else: |
|||
obs, reward, done, info = self._multi_step(info) |
|||
return obs |
|||
|
|||
def step(self, action): |
|||
"""Run one timestep of the environment's dynamics. When end of |
|||
episode is reached, you are responsible for calling `reset()` |
|||
to reset this environment's state. |
|||
Accepts an action and returns a tuple (observation, reward, done, info). |
|||
In the case of multi-agent environments, these are lists. |
|||
Args: |
|||
action (object/list): an action provided by the environment |
|||
Returns: |
|||
observation (object/list): agent's observation of the current environment |
|||
reward (float/list) : amount of reward returned after previous action |
|||
done (boolean/list): whether the episode has ended. |
|||
info (dict): contains auxiliary diagnostic information, including BrainInfo. |
|||
""" |
|||
|
|||
# Use random actions for all other agents in environment. |
|||
if self._multiagent: |
|||
if not isinstance(action, list): |
|||
raise UnityGymException( |
|||
"The environment was expecting `action` to be a list." |
|||
) |
|||
if len(action) != self._n_agents: |
|||
raise UnityGymException( |
|||
"The environment was expecting a list of {} actions.".format( |
|||
self._n_agents |
|||
) |
|||
) |
|||
else: |
|||
if self._flattener is not None: |
|||
# Action space is discrete and flattened - we expect a list of scalars |
|||
action = [self._flattener.lookup_action(_act) for _act in action] |
|||
action = np.array(action) |
|||
else: |
|||
if self._flattener is not None: |
|||
# Translate action into list |
|||
action = self._flattener.lookup_action(action) |
|||
|
|||
info = self._env.step(action)[self.brain_name] |
|||
n_agents = len(info.agents) |
|||
self._check_agents(n_agents) |
|||
self._current_state = info |
|||
|
|||
if not self._multiagent: |
|||
obs, reward, done, info = self._single_step(info) |
|||
self.game_over = done |
|||
else: |
|||
obs, reward, done, info = self._multi_step(info) |
|||
self.game_over = all(done) |
|||
return obs, reward, done, info |
|||
|
|||
def _single_step(self, info): |
|||
if self.use_visual: |
|||
visual_obs = info.visual_observations |
|||
|
|||
if self._allow_multiple_visual_obs: |
|||
visual_obs_list = [] |
|||
for obs in visual_obs: |
|||
visual_obs_list.append(self._preprocess_single(obs[0])) |
|||
self.visual_obs = visual_obs_list |
|||
else: |
|||
self.visual_obs = self._preprocess_single(visual_obs[0][0]) |
|||
|
|||
default_observation = self.visual_obs |
|||
else: |
|||
default_observation = info.vector_observations[0, :] |
|||
|
|||
return ( |
|||
default_observation, |
|||
info.rewards[0], |
|||
info.local_done[0], |
|||
{"text_observation": info.text_observations[0], "brain_info": info}, |
|||
) |
|||
|
|||
def _preprocess_single(self, single_visual_obs): |
|||
if self.uint8_visual: |
|||
return (255.0 * single_visual_obs).astype(np.uint8) |
|||
else: |
|||
return single_visual_obs |
|||
|
|||
def _multi_step(self, info): |
|||
if self.use_visual: |
|||
self.visual_obs = self._preprocess_multi(info.visual_observations) |
|||
default_observation = self.visual_obs |
|||
else: |
|||
default_observation = info.vector_observations |
|||
return ( |
|||
list(default_observation), |
|||
info.rewards, |
|||
info.local_done, |
|||
{"text_observation": info.text_observations, "brain_info": info}, |
|||
) |
|||
|
|||
def _preprocess_multi(self, multiple_visual_obs): |
|||
if self.uint8_visual: |
|||
return [ |
|||
(255.0 * _visual_obs).astype(np.uint8) |
|||
for _visual_obs in multiple_visual_obs |
|||
] |
|||
else: |
|||
return multiple_visual_obs |
|||
|
|||
def render(self, mode="rgb_array"): |
|||
return self.visual_obs |
|||
|
|||
def close(self): |
|||
"""Override _close in your subclass to perform any necessary cleanup. |
|||
Environments will automatically close() themselves when |
|||
garbage collected or when the program exits. |
|||
""" |
|||
self._env.close() |
|||
|
|||
def get_action_meanings(self): |
|||
return self.action_meanings |
|||
|
|||
def seed(self, seed=None): |
|||
"""Sets the seed for this env's random number generator(s). |
|||
Currently not implemented. |
|||
""" |
|||
logger.warn("Could not seed environment %s", self.name) |
|||
return |
|||
|
|||
def _check_agents(self, n_agents): |
|||
if not self._multiagent and n_agents > 1: |
|||
raise UnityGymException( |
|||
"The environment was launched as a single-agent environment, however" |
|||
"there is more than one agent in the scene." |
|||
) |
|||
elif self._multiagent and n_agents <= 1: |
|||
raise UnityGymException( |
|||
"The environment was launched as a mutli-agent environment, however" |
|||
"there is only one agent in the scene." |
|||
) |
|||
if self._n_agents is None: |
|||
self._n_agents = n_agents |
|||
logger.info("{} agents within environment.".format(n_agents)) |
|||
elif self._n_agents != n_agents: |
|||
raise UnityGymException( |
|||
"The number of agents in the environment has changed since " |
|||
"initialization. This is not supported." |
|||
) |
|||
|
|||
@property |
|||
def metadata(self): |
|||
return {"render.modes": ["rgb_array"]} |
|||
|
|||
@property |
|||
def reward_range(self): |
|||
return -float("inf"), float("inf") |
|||
|
|||
@property |
|||
def spec(self): |
|||
return None |
|||
|
|||
@property |
|||
def action_space(self): |
|||
return self._action_space |
|||
|
|||
@property |
|||
def observation_space(self): |
|||
return self._observation_space |
|||
|
|||
@property |
|||
def number_agents(self): |
|||
return self._n_agents |
|||
|
|||
|
|||
class ActionFlattener: |
|||
""" |
|||
Flattens branched discrete action spaces into single-branch discrete action spaces. |
|||
""" |
|||
|
|||
def __init__(self, branched_action_space): |
|||
""" |
|||
Initialize the flattener. |
|||
:param branched_action_space: A List containing the sizes of each branch of the action |
|||
space, e.g. [2,3,3] for three branches with size 2, 3, and 3 respectively. |
|||
""" |
|||
self._action_shape = branched_action_space |
|||
self.action_lookup = self._create_lookup(self._action_shape) |
|||
self.action_space = spaces.Discrete(len(self.action_lookup)) |
|||
|
|||
@classmethod |
|||
def _create_lookup(self, branched_action_space): |
|||
""" |
|||
Creates a Dict that maps discrete actions (scalars) to branched actions (lists). |
|||
Each key in the Dict maps to one unique set of branched actions, and each value |
|||
contains the List of branched actions. |
|||
""" |
|||
possible_vals = [range(_num) for _num in branched_action_space] |
|||
all_actions = [list(_action) for _action in itertools.product(*possible_vals)] |
|||
# Dict should be faster than List for large action spaces |
|||
action_lookup = { |
|||
_scalar: _action for (_scalar, _action) in enumerate(all_actions) |
|||
} |
|||
return action_lookup |
|||
|
|||
def lookup_action(self, action): |
|||
""" |
|||
Convert a scalar discrete action into a unique set of branched actions. |
|||
:param: action: A scalar value representing one of the discrete actions. |
|||
:return: The List containing the branched actions. |
|||
""" |
|||
return self.action_lookup[action] |
|
|||
from .brain import AllBrainInfo, BrainInfo, BrainParameters |
|||
from .action_info import ActionInfo, ActionInfoOutputs |
|||
from .policy import Policy |
|||
from .environment import * |
|||
from .exception import * |
|
|||
from .agent_action_proto_pb2 import * |
|||
from .agent_info_proto_pb2 import * |
|||
from .brain_parameters_proto_pb2 import * |
|||
from .command_proto_pb2 import * |
|||
from .custom_action_pb2 import * |
|||
from .custom_observation_pb2 import * |
|||
from .custom_reset_parameters_pb2 import * |
|||
from .demonstration_meta_proto_pb2 import * |
|||
from .engine_configuration_proto_pb2 import * |
|||
from .environment_parameters_proto_pb2 import * |
|||
from .header_pb2 import * |
|||
from .resolution_proto_pb2 import * |
|||
from .space_type_proto_pb2 import * |
|||
from .unity_input_pb2 import * |
|||
from .unity_message_pb2 import * |
|||
from .unity_output_pb2 import * |
|||
from .unity_rl_initialization_input_pb2 import * |
|||
from .unity_rl_initialization_output_pb2 import * |
|||
from .unity_rl_input_pb2 import * |
|||
from .unity_rl_output_pb2 import * |
|||
from .unity_to_external_pb2 import * |
|||
from .unity_to_external_pb2_grpc import * |
|
|||
from .buffer import * |
|||
from .curriculum import * |
|||
from .meta_curriculum import * |
|||
from .models import * |
|||
from .trainer_metrics import * |
|||
from .trainer import * |
|||
from .tf_policy import * |
|||
from .trainer_controller import * |
|||
from .bc.models import * |
|||
from .bc.offline_trainer import * |
|||
from .bc.online_trainer import * |
|||
from .bc.policy import * |
|||
from .ppo.models import * |
|||
from .ppo.trainer import * |
|||
from .ppo.policy import * |
|||
from .sac.models import * |
|||
from .sac.trainer import * |
|||
from .sac.policy import * |
|||
from .exception import * |
|||
from .demo_loader import * |
|
|||
from .models import * |
|||
from .online_trainer import * |
|||
from .offline_trainer import * |
|||
from .policy import * |
|
|||
from .module import BCModule |
|
|||
from .reward_signal import * |
|||
import logging |
|||
from typing import Any, Dict, List |
|||
from collections import namedtuple |
|||
import numpy as np |
|||
import abc |
|||
|
|||
import tensorflow as tf |
|||
|
|||
from mlagents.envs.brain import BrainInfo |
|||
from mlagents.trainers.trainer import UnityTrainerException |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
from mlagents.trainers.models import LearningModel |
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
RewardSignalResult = namedtuple( |
|||
"RewardSignalResult", ["scaled_reward", "unscaled_reward"] |
|||
) |
|||
|
|||
|
|||
class RewardSignal(abc.ABC): |
|||
def __init__( |
|||
self, |
|||
policy: TFPolicy, |
|||
policy_model: LearningModel, |
|||
strength: float, |
|||
gamma: float, |
|||
): |
|||
""" |
|||
Initializes a reward signal. At minimum, you must pass in the policy it is being applied to, |
|||
the reward strength, and the gamma (discount factor.) |
|||
:param policy: The Policy object (e.g. PPOPolicy) that this Reward Signal will apply to. |
|||
:param strength: The strength of the reward. The reward's raw value will be multiplied by this value. |
|||
:param gamma: The time discounting factor used for this reward. |
|||
:return: A RewardSignal object. |
|||
""" |
|||
class_name = self.__class__.__name__ |
|||
short_name = class_name.replace("RewardSignal", "") |
|||
self.stat_name = f"Policy/{short_name} Reward" |
|||
self.value_name = f"Policy/{short_name} Value Estimate" |
|||
# Terminate discounted reward computation at Done. Can disable to mitigate positive bias in rewards with |
|||
# no natural end, e.g. GAIL or Curiosity |
|||
self.use_terminal_states = True |
|||
self.update_dict: Dict[str, tf.Tensor] = {} |
|||
self.gamma = gamma |
|||
self.policy = policy |
|||
self.policy_model = policy_model |
|||
self.strength = strength |
|||
self.stats_name_to_update_name: Dict[str, str] = {} |
|||
|
|||
def evaluate( |
|||
self, current_info: BrainInfo, next_info: BrainInfo |
|||
) -> RewardSignalResult: |
|||
""" |
|||
Evaluates the reward for the agents present in current_info given the next_info |
|||
:param current_info: The current BrainInfo. |
|||
:param next_info: The BrainInfo from the next timestep. |
|||
:return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator |
|||
""" |
|||
return RewardSignalResult( |
|||
self.strength * np.zeros(len(current_info.agents)), |
|||
np.zeros(len(current_info.agents)), |
|||
) |
|||
|
|||
def evaluate_batch(self, mini_batch: Dict[str, np.array]) -> RewardSignalResult: |
|||
""" |
|||
Evaluates the reward for the data present in the Dict mini_batch. Note the distiction between |
|||
evaluate(), which takes in two BrainInfos. This reflects the different data formats (i.e. from the Buffer |
|||
vs. before being placed into the Buffer. Use this when evaluating a reward function drawn straight from a |
|||
Buffer. |
|||
:param mini_batch: A Dict of numpy arrays (the format used by our Buffer) |
|||
when drawing from the update buffer. |
|||
:return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator |
|||
""" |
|||
mini_batch_len = len(next(iter(mini_batch.values()))) |
|||
return RewardSignalResult( |
|||
self.strength * np.zeros(mini_batch_len), np.zeros(mini_batch_len) |
|||
) |
|||
|
|||
def prepare_update( |
|||
self, |
|||
policy_model: LearningModel, |
|||
mini_batch: Dict[str, np.ndarray], |
|||
num_sequences: int, |
|||
) -> Dict[tf.Tensor, Any]: |
|||
""" |
|||
If the reward signal has an internal model (e.g. GAIL or Curiosity), get the feed_dict |
|||
needed to update the buffer.. |
|||
:param update_buffer: An AgentBuffer that contains the live data from which to update. |
|||
:param n_sequences: The number of sequences in the training buffer. |
|||
:return: A dict that corresponds to the feed_dict needed for the update. |
|||
""" |
|||
return {} |
|||
|
|||
@classmethod |
|||
def check_config( |
|||
cls, config_dict: Dict[str, Any], param_keys: List[str] = None |
|||
) -> None: |
|||
""" |
|||
Check the config dict, and throw an error if there are missing hyperparameters. |
|||
""" |
|||
param_keys = param_keys or [] |
|||
for k in param_keys: |
|||
if k not in config_dict: |
|||
raise UnityTrainerException( |
|||
"The hyper-parameter {0} could not be found for {1}.".format( |
|||
k, cls.__name__ |
|||
) |
|||
) |
|
|||
from .signal import CuriosityRewardSignal |
|
|||
from .signal import ExtrinsicRewardSignal |
|
|||
from .signal import GAILRewardSignal |
|
|||
from .models import * |
|||
from .trainer import * |
|||
from .policy import * |
|
|||
from .models import * |
|||
from .trainer import * |
|||
from .policy import * |