
Add optional burn-in for SAC as well

/develop/nopreviousactions
Ervin Teng, 5 years ago
Current commit
ce110201
2 changed files with 9 additions and 2 deletions
  1. ml-agents/mlagents/trainers/ppo/optimizer.py (2 changes)
  2. ml-agents/mlagents/trainers/sac/optimizer.py (9 changes)

ml-agents/mlagents/trainers/ppo/optimizer.py (2 changes)


      def construct_feed_dict(
          self, mini_batch: AgentBuffer, num_sequences: int
      ) -> Dict[tf.Tensor, Any]:
-         # Do a burn-in for memories
+         # Do an optional burn-in for memories
          num_burn_in = int(BURN_IN_RATIO * self.policy.sequence_length)
          burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32)
          burn_in_mask[range(0, num_burn_in)] = 0

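For context (not part of the commit), here is a minimal NumPy sketch of what the burn-in mask construction above produces; the BURN_IN_RATIO, sequence_length, and num_sequences values are assumed purely for illustration (the commit itself defaults BURN_IN_RATIO to 0.0, i.e. burn-in disabled).

import numpy as np

BURN_IN_RATIO = 0.25       # assumed value; the commit's default is 0.0 (no burn-in)
sequence_length = 8        # assumed LSTM sequence length
num_sequences = 2          # assumed number of sequences in the mini-batch

num_burn_in = int(BURN_IN_RATIO * sequence_length)            # -> 2 steps
burn_in_mask = np.ones((sequence_length,), dtype=np.float32)
burn_in_mask[range(0, num_burn_in)] = 0                       # zero the first 2 steps
burn_in_mask = np.tile(burn_in_mask, num_sequences)           # repeat for each sequence

print(burn_in_mask)
# [0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1.]
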
ml-agents/mlagents/trainers/sac/optimizer.py (9 changes)


POLICY_SCOPE = ""
TARGET_SCOPE = "target_network"
BURN_IN_RATIO = 0.0
class SACOptimizer(TFOptimizer):
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):

          :param batch: Mini-batch to use to update.
          :param num_sequences: Number of LSTM sequences in batch.
          """
+         # Do an optional burn-in for memories
+         num_burn_in = int(BURN_IN_RATIO * self.policy.sequence_length)
+         burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32)
+         burn_in_mask[range(0, num_burn_in)] = 0
+         burn_in_mask = np.tile(burn_in_mask, num_sequences)
-             self.policy.mask_input: batch["masks"],
+             self.policy.mask_input: batch["masks"] * burn_in_mask,
          }
          for name in self.reward_signals:
              feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)]

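As an illustration rather than repository code, the sketch below shows the effect of feeding batch["masks"] * burn_in_mask instead of batch["masks"]: the first num_burn_in steps of every sequence still roll the recurrent state forward, but they are zeroed out of any mask-weighted loss average. All names and values here are hypothetical stand-ins.

import numpy as np

sequence_length, num_sequences = 4, 2                                  # assumed sizes
burn_in_mask = np.tile(np.array([0, 1, 1, 1], dtype=np.float32), num_sequences)
episode_masks = np.ones(sequence_length * num_sequences, np.float32)   # stand-in for batch["masks"]
combined_mask = episode_masks * burn_in_mask                           # what policy.mask_input receives

per_step_loss = np.arange(8, dtype=np.float32)                         # hypothetical per-timestep losses
masked_mean = (per_step_loss * combined_mask).sum() / combined_mask.sum()
print(masked_mean)                                                     # 4.0: only non-burn-in steps counted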