|
|
|
|
|
|
import logging |
|
|
|
from typing import Dict, Any, Optional |
|
|
|
from typing import Dict, Any, Optional, Mapping |
|
|
|
import numpy as np |
|
|
|
from mlagents.tf_utils import tf |
|
|
|
|
|
|
|
|
|
|
return update_stats |
|
|
|
|
|
|
|
def update_reward_signals( |
|
|
|
self, reward_signal_minibatches: Dict[str, Dict], num_sequences: int |
|
|
|
self, reward_signal_minibatches: Mapping[str, Dict], num_sequences: int |
|
|
|
) -> Dict[str, float]: |
|
|
|
""" |
|
|
|
Only update the reward signals. |
|
|
|
|
|
|
feed_dict: Dict[tf.Tensor, Any], |
|
|
|
update_dict: Dict[str, tf.Tensor], |
|
|
|
stats_needed: Dict[str, str], |
|
|
|
reward_signal_minibatches: Dict[str, Dict], |
|
|
|
reward_signal_minibatches: Mapping[str, Dict], |
|
|
|
num_sequences: int, |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|