from typing import Any, Dict

import numpy as np

# TFPolicy and TFOptimizer are assumed to be provided by the surrounding
# trainer package; their imports are not part of this excerpt.

POLICY_SCOPE = ""
TARGET_SCOPE = "target_network"

# Fraction of each LSTM sequence that is masked out ("burned in") at the start
# of the sequence when building the update mask; 0.0 disables burn-in.
BURN_IN_RATIO = 0.0


class SACOptimizer(TFOptimizer):
    def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):
        ...  # initialization body not included in this excerpt
    # NOTE: the def line for this method is missing from the excerpt; the name
    # and parameters below are reconstructed from its docstring and body.
    def _construct_feed_dict(self, batch, num_sequences: int):
        """
        Builds the feed dict used to update the model.
        :param batch: Mini-batch to use to update.
        :param num_sequences: Number of LSTM sequences in batch.
        """
        # Do an optional burn-in for memories
        num_burn_in = int(BURN_IN_RATIO * self.policy.sequence_length)
        burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32)
        burn_in_mask[range(0, num_burn_in)] = 0
        burn_in_mask = np.tile(burn_in_mask, num_sequences)
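        # For example (hypothetical values, not from this code): with
        # sequence_length=4, num_burn_in=1 and num_sequences=2, the tiled mask
        # is [0, 1, 1, 1, 0, 1, 1, 1], zeroing the first step of each sequence.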
self.policy.mask_input: batch["masks"], |
|
|
|
self.policy.mask_input: batch["masks"] * burn_in_mask, |
|
|
|
} |
|
|
|
        for name in self.reward_signals:
            feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)]
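            # e.g. a reward signal named "extrinsic" (hypothetical name here)
            # would read its targets from batch["extrinsic_rewards"].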