
add some types to the reward signals (#2215)

* WIP add some types to the reward signals

* fix next_visual_in

* cleanup TODO

* fix bad merge
Branch: develop-generalizationTraining-TrainerController
GitHub · 5 years ago
Current commit: d80d5852
5 files changed, 70 insertions(+) and 24 deletions(-)
  1. ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (14 changes)
  2. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (21 changes)
  3. ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py (13 changes)
  4. ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py (20 changes)
  5. ml-agents/mlagents/trainers/models.py (26 changes)

ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (14 changes)


+from typing import List, Tuple
 import tensorflow as tf
 from mlagents.trainers.models import LearningModel

         """
         self.encoding_size = encoding_size
         self.policy_model = policy_model
+        self.next_visual_in: List[tf.Tensor] = []

-    def create_curiosity_encoders(self):
+    def create_curiosity_encoders(self) -> Tuple[tf.Tensor, tf.Tensor]:
         """
         Creates state encoders for current and future observations.
         Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction

         encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
         return encoded_state, encoded_next_state

-    def create_inverse_model(self, encoded_state, encoded_next_state):
+    def create_inverse_model(
+        self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
+    ) -> None:
         """
         Creates inverse model TensorFlow ops for Curiosity module.
         Predicts action taken given current and future encoded states.

             tf.dynamic_partition(cross_entropy, self.policy_model.mask, 2)[1]
         )

-    def create_forward_model(self, encoded_state, encoded_next_state):
+    def create_forward_model(
+        self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
+    ) -> None:
         """
         Creates forward model TensorFlow ops for Curiosity module.
         Predicts encoded future state based on encoded current state and given action.

             tf.dynamic_partition(squared_difference, self.policy_model.mask, 2)[1]
         )

-    def create_loss(self, learning_rate):
+    def create_loss(self, learning_rate: float) -> None:
         """
         Creates the loss node of the model as well as the update_batch optimizer to update the model.
         :param learning_rate: The learning rate for the optimizer.

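For readers skimming this hunk: the curiosity module's forward model predicts the next encoded observation, and its prediction error is used as the intrinsic reward. A minimal NumPy sketch of that idea (illustrative only; the function name and the 0.5 scale factor are assumptions, not the TensorFlow ops in model.py):

import numpy as np

def curiosity_reward(
    encoded_next_state: np.ndarray,    # (batch, encoding_size) true next-state encoding
    predicted_next_state: np.ndarray,  # (batch, encoding_size) forward-model prediction
) -> np.ndarray:
    # Per-sample squared prediction error of the forward model, used as intrinsic reward.
    return 0.5 * np.mean(np.square(predicted_next_state - encoded_next_state), axis=1)
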
ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (21 changes)


+from typing import Any, Dict, List
+from mlagents.envs.brain import BrainInfo
+from mlagents.trainers.buffer import Buffer
 from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
 from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
+from mlagents.trainers.tf_policy import TFPolicy

         }
         self.has_updated = False

-    def evaluate(self, current_info, next_info):
+    def evaluate(
+        self, current_info: BrainInfo, next_info: BrainInfo
+    ) -> RewardSignalResult:
         """
         Evaluates the reward for the agents present in current_info given the next_info
         :param current_info: The current BrainInfo.

         return RewardSignalResult(scaled_reward, unscaled_reward)

     @classmethod
-    def check_config(cls, config_dict):
+    def check_config(
+        cls, config_dict: Dict[str, Any], param_keys: List[str] = None
+    ) -> None:
         """
         Checks the config and throw an exception if a hyperparameter is missing. Curiosity requires strength,
         gamma, and encoding size at minimum.

-    def update(self, update_buffer, num_sequences):
+    def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
         """
         Updates Curiosity model using training buffer. Divides training buffer into mini batches and performs
         gradient descent.

         """
-        forward_total, inverse_total = [], []
+        forward_total: List[float] = []
+        inverse_total: List[float] = []
         for _ in range(self.num_epoch):
             update_buffer.shuffle()
             buffer = update_buffer

         }
         return update_stats

-    def _update_batch(self, mini_batch, num_sequences):
+    def _update_batch(
+        self, mini_batch: Dict[str, np.ndarray], num_sequences: int
+    ) -> Dict[str, float]:
         """
         Updates model using buffer.
         :param num_sequences: Number of trajectories in batch.

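The typed evaluate now advertises that it returns a RewardSignalResult wrapping (scaled_reward, unscaled_reward). A self-contained sketch of an equivalent named tuple and a hypothetical consumer (the real class lives in the reward_signals package; total_scaled_reward is illustrative and not part of this PR):

from collections import namedtuple
from typing import List
import numpy as np

# Illustrative stand-in: the diff only shows that RewardSignalResult wraps
# (scaled_reward, unscaled_reward) in that order.
RewardSignalResult = namedtuple("RewardSignalResult", ["scaled_reward", "unscaled_reward"])

def total_scaled_reward(results: List[RewardSignalResult]) -> np.ndarray:
    # Hypothetical consumer: sum the per-agent scaled rewards across several signals.
    return np.sum([r.scaled_reward for r in results], axis=0)
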
ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py (13 changes)


+from typing import Any, Dict, List
+from mlagents.envs.brain import BrainInfo
+from mlagents.trainers.buffer import Buffer
 from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
+from mlagents.trainers.tf_policy import TFPolicy

         super().__init__(policy, strength, gamma)

     @classmethod
-    def check_config(cls, config_dict):
+    def check_config(
+        cls, config_dict: Dict[str, Any], param_keys: List[str] = None
+    ) -> None:
         """
         Checks the config and throw an exception if a hyperparameter is missing. Extrinsic requires strength and gamma
         at minimum.

-    def evaluate(self, current_info, next_info):
+    def evaluate(
+        self, current_info: BrainInfo, next_info: BrainInfo
+    ) -> RewardSignalResult:
         """
         Evaluates the reward for the agents present in current_info given the next_info
         :param current_info: The current BrainInfo.

         scaled_reward = self.strength * unscaled_reward
         return RewardSignalResult(scaled_reward, unscaled_reward)

-    def update(self, update_buffer, num_sequences):
+    def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
         """
         This method does nothing, as there is nothing to update.
         """

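As the hunk shows, the extrinsic signal's evaluate just scales the environment reward by the signal's strength. A simplified, framework-free sketch (env_rewards stands in for the per-agent rewards carried by next_info; Result is an illustrative stand-in for RewardSignalResult):

from typing import List, NamedTuple
import numpy as np

class Result(NamedTuple):  # illustrative stand-in for RewardSignalResult
    scaled_reward: np.ndarray
    unscaled_reward: np.ndarray

def extrinsic_evaluate(env_rewards: List[float], strength: float) -> Result:
    # Pass the environment reward through unchanged, plus a copy scaled by `strength`.
    unscaled_reward = np.array(env_rewards, dtype=np.float32)
    return Result(strength * unscaled_reward, unscaled_reward)
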
ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py (20 changes)


 import logging
-from mlagents.trainers.trainer import UnityTrainerException
-from mlagents.trainers.tf_policy import TFPolicy
+from typing import Any, Dict, List
+from mlagents.envs.brain import BrainInfo
+from mlagents.trainers.trainer import UnityTrainerException
+from mlagents.trainers.tf_policy import TFPolicy
+from mlagents.trainers.buffer import Buffer

 logger = logging.getLogger("mlagents.trainers")

         self.policy = policy
         self.strength = strength

-    def evaluate(self, current_info, next_info):
+    def evaluate(
+        self, current_info: BrainInfo, next_info: BrainInfo
+    ) -> RewardSignalResult:
         """
         Evaluates the reward for the agents present in current_info given the next_info
         :param current_info: The current BrainInfo.

-        return (
+        return RewardSignalResult(

-    def update(self, update_buffer, n_sequences):
+    def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
         """
         If the reward signal has an internal model (e.g. GAIL or Curiosity), update that model.
         :param update_buffer: An AgentBuffer that contains the live data from which to update.

         return {}

     @classmethod
-    def check_config(cls, config_dict, param_keys=None):
+    def check_config(
+        cls, config_dict: Dict[str, Any], param_keys: List[str] = None
+    ) -> None:
         """
         Check the config dict, and throw an error if there are missing hyperparameters.
         """

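check_config's new signature takes the config dict plus an optional list of required keys. The body is not shown in this diff; a plausible minimal implementation, with UnityTrainerException stubbed so the snippet is self-contained, could look like:

from typing import Any, Dict, List

class UnityTrainerException(Exception):
    """Stand-in for mlagents.trainers.trainer.UnityTrainerException."""

def check_config(config_dict: Dict[str, Any], param_keys: List[str] = None) -> None:
    # Raise if any required hyperparameter is missing from the reward signal config.
    for key in param_keys or []:
        if key not in config_dict:
            raise UnityTrainerException(
                f"Hyperparameter '{key}' is missing from the reward signal config."
            )

For example, check_config({"strength": 1.0, "gamma": 0.99}, ["strength", "gamma"]) would pass, while omitting "gamma" would raise.
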
ml-agents/mlagents/trainers/models.py (26 changes)


 import logging
+from typing import Any, Callable, Dict
 import numpy as np
 import tensorflow as tf

+ActivationFunction = Callable[[tf.Tensor], tf.Tensor]

 class LearningModel(object):

         return c_layers.variance_scaling_initializer(scale)

     @staticmethod
-    def swish(input_activation):
+    def swish(input_activation: tf.Tensor) -> tf.Tensor:

-    def create_visual_input(camera_parameters, name):
+    def create_visual_input(camera_parameters: Dict[str, Any], name: str) -> tf.Tensor:
         """
         Creates image input op.
         :param camera_parameters: Parameters for visual observation from BrainInfo.

     @staticmethod
     def create_vector_observation_encoder(
-        observation_input, h_size, activation, num_layers, scope, reuse
-    ):
+        observation_input: tf.Tensor,
+        h_size: int,
+        activation: ActivationFunction,
+        num_layers: int,
+        scope: str,
+        reuse: bool,
+    ) -> tf.Tensor:
         """
         Builds a set of hidden state encoders.
         :param reuse: Whether to re-use the weights within the same scope.

         return hidden

     def create_visual_observation_encoder(
-        self, image_input, h_size, activation, num_layers, scope, reuse
-    ):
+        self,
+        image_input: tf.Tensor,
+        h_size: int,
+        activation: ActivationFunction,
+        num_layers: int,
+        scope: str,
+        reuse: bool,
+    ) -> tf.Tensor:
         """
         Builds a set of visual (CNN) encoders.
         :param reuse: Whether to re-use the weights within the same scope.

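The new ActivationFunction alias names the "tensor in, tensor out" signature that activations such as swish satisfy, and is now used to type the encoder builders. A NumPy analogue of how such a Callable alias is used (swish is conventionally x * sigmoid(x); its actual body is not shown in the diff above, so this is an assumed implementation):

from typing import Callable
import numpy as np

# Mirror of the alias in models.py, but over NumPy arrays instead of tf.Tensor.
ActivationFunction = Callable[[np.ndarray], np.ndarray]

def swish(x: np.ndarray) -> np.ndarray:
    # Conventional swish activation: x * sigmoid(x).
    return x * (1.0 / (1.0 + np.exp(-x)))

def apply_activation(x: np.ndarray, activation: ActivationFunction) -> np.ndarray:
    # Any callable matching the alias can be passed where an activation is expected.
    return activation(x)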