|
|
|
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents_envs.timers import timed |
|
|
|
|
|
|
|
LOG_STD_MAX = 2 |
|
|
|
LOG_STD_MIN = -20 |
|
|
|
DISCRETE_TARGET_ENTROPY_SCALE = 0.2 # Roughly equal to e-greedy 0.05 |
|
|
|
CONTINUOUS_TARGET_ENTROPY_SCALE = 1.0 # TODO: Make these an optional hyperparam. |
|
|
|
|
|
|
|
BURN_IN_RATIO = 0.0 |
|
|
|
|
|
|
|
|
|
|
|
class SACOptimizer(TFOptimizer): |
|
|
|
|
|
|
trainer_params.get("vis_encode_type", "simple") |
|
|
|
) |
|
|
|
self.tau = trainer_params.get("tau", 0.005) |
|
|
|
self.burn_in_ratio = float(trainer_params.get("burn_in_ratio", 0.0)) |
|
|
|
|
|
|
|
# Non-exposed SAC parameters |
|
|
|
self.discrete_target_entropy_scale = ( |
|
|
|
0.2 |
|
|
|
) # Roughly equal to e-greedy 0.05 |
|
|
|
self.continuous_target_entropy_scale = 1.0 |
|
|
|
stream_names = self.reward_signals.keys() |
|
|
|
stream_names = list(self.reward_signals.keys()) |
|
|
|
# Use to reduce "survivor bonus" when using Curiosity or GAIL. |
|
|
|
self.gammas = [ |
|
|
|
_val["gamma"] for _val in trainer_params["reward_signals"].values() |
|
|
|
|
|
|
|
|
|
|
if discrete: |
|
|
|
self.target_entropy = [ |
|
|
|
DISCRETE_TARGET_ENTROPY_SCALE * np.log(i).astype(np.float32) |
|
|
|
self.discrete_target_entropy_scale * np.log(i).astype(np.float32) |
|
|
|
for i in self.act_size |
|
|
|
] |
|
|
|
discrete_action_probs = tf.exp(self.policy.all_log_probs) |
|
|
|
|
|
|
-1 |
|
|
|
* CONTINUOUS_TARGET_ENTROPY_SCALE |
|
|
|
* self.continuous_target_entropy_scale |
|
|
|
* np.prod(self.act_size[0]).astype(np.float32) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
:param num_sequences: Number of LSTM sequences in batch. |
|
|
|
""" |
|
|
|
# Do an optional burn-in for memories |
|
|
|
num_burn_in = int(BURN_IN_RATIO * self.policy.sequence_length) |
|
|
|
num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length) |
|
|
|
burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32) |
|
|
|
burn_in_mask[range(0, num_burn_in)] = 0 |
|
|
|
burn_in_mask = np.tile(burn_in_mask, num_sequences) |
|
|
|