
Use steps_per_update to determine SAC train interval

/develop/sac-apex
Ervin Teng, 5 years ago
Current commit
293579dd
1 changed file with 18 additions and 14 deletions
  1. ml-agents/mlagents/trainers/sac/trainer.py (32 changed lines)
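The change, in a sentence: instead of gating training on a fixed train_interval and then running a fixed num_update batch of gradient steps, the trainer now tracks how many updates it has performed (update_steps, reward_signal_update_steps) and keeps the ratio of environment steps to updates close to steps_per_update, running as many catch-up updates as needed. Below is a minimal standalone sketch of that scheduling idea; it is not the trainer's code, the values are hypothetical, and the counter starts at 1 here only so the ratio is defined on the first pass (the diff initializes it to 0).

# Standalone sketch of ratio-based update scheduling; names mirror the fields
# added in this commit, values are hypothetical.
step = 0              # environment steps collected so far
update_steps = 1      # SAC updates performed so far (1 here so the division is defined)
steps_per_update = 4  # aim for roughly one update per 4 environment steps

updates_at = []
for _ in range(12):
    step += 1  # pretend one environment step was just collected
    # Catch up: run updates until step / update_steps falls back to steps_per_update.
    while step / update_steps > steps_per_update:
        updates_at.append(step)
        update_steps += 1

print(updates_at)  # [5, 9] -> updates land roughly every steps_per_update steps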

ml-agents/mlagents/trainers/sac/trainer.py (+18, -14)


  self.optimizer: SACOptimizer = None  # type: ignore
  self.step = 0
- self.train_interval = (
-     trainer_parameters["train_interval"]
-     if "train_interval" in trainer_parameters
+ self.update_steps = 0
+ self.reward_signal_update_steps = 0
+ self.steps_per_update = (
+     trainer_parameters["steps_per_update"]
+     if "steps_per_update" in trainer_parameters
- self.reward_signal_updates_per_train = (
-     trainer_parameters["reward_signals"]["reward_signal_num_update"]
-     if "reward_signal_num_update" in trainer_parameters["reward_signals"]
-     else trainer_parameters["num_update"]
+ self.reward_signal_steps_per_update = (
+     trainer_parameters["reward_signals"]["reward_signal_steps_per_update"]
+     if "reward_signal_steps_per_update" in trainer_parameters["reward_signals"]
+     else self.steps_per_update
  )
  self.checkpoint_replay_buffer = (
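For context on where these keys come from, here is a hypothetical trainer_parameters fragment containing the entries this __init__ excerpt reads; the values are illustrative only, and the defaults used when steps_per_update is absent are not visible in this excerpt.

# Hypothetical trainer_parameters fragment (illustrative values only).
trainer_parameters = {
    "batch_size": 128,
    "buffer_size": 50000,
    "steps_per_update": 1,
    "reward_signals": {
        "reward_signal_steps_per_update": 4,
        # per-signal settings also live under "reward_signals"
    },
}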

  If train_interval is met, update the SAC policy given the current reward signals.
  If reward_signal_train_interval is met, update the reward signals from the buffer.
  """
- if self.step % self.train_interval == 0:
-     self.update_sac_policy()
-     self.update_reward_signals()
+ self.update_sac_policy()
+ self.update_reward_signals()
  def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
      policy = NNPolicy(

      int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
  )
- num_updates = self.trainer_parameters["num_update"]
- for _ in range(num_updates):
+ while self.step / self.update_steps > self.steps_per_update:
      logger.debug("Updating SAC policy at step {}".format(self.step))
      buffer = self.update_buffer
      if (

      for stat_name, value in update_stats.items():
          batch_update_stats[stat_name].append(value)
+     self.update_steps += 1
  # Truncate update buffer if necessary. Truncate more than we need to, to avoid
  # truncating a large buffer at each update.
  if self.update_buffer.num_experiences > self.trainer_parameters["buffer_size"]:
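Reading the new loop condition (my interpretation, not text from the commit): the trainer keeps running SAC updates until step / update_steps is no longer greater than steps_per_update, so if data collection has outpaced training it performs several updates in a row. A small worked example with hypothetical numbers:

# Hypothetical numbers illustrating the while-condition above.
step, update_steps, steps_per_update = 100, 20, 4
catch_up = 0
while step / update_steps > steps_per_update:  # 100 / 20 = 5 > 4, so training is behind
    update_steps += 1  # stands in for one SAC update
    catch_up += 1
print(catch_up, update_steps)  # 5 25 -> five catch-up updates bring the ratio back to 4.0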

  and policy are updated in parallel.
  """
  buffer = self.update_buffer
- num_updates = self.reward_signal_updates_per_train
- for _ in range(num_updates):
+ while self.step / self.reward_signal_update_steps > self.steps_per_update:
      # Get minibatches for reward signal update if needed
      reward_signal_minibatches = {}
      for name, signal in self.optimizer.reward_signals.items():

      )
      for stat_name, value in update_stats.items():
          batch_update_stats[stat_name].append(value)
+     self.reward_signal_update_steps += 1
  for stat, stat_list in batch_update_stats.items():
      self._stats_reporter.add_stat(stat, np.mean(stat_list))
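Reward-signal updates are throttled the same way but with their own counter, so their cadence can diverge from the policy's. The sketch below assumes the reward-signal loop is meant to compare against the reward_signal_steps_per_update field added in __init__; note the excerpt above, as shown, compares against steps_per_update. Values are hypothetical.

# Hypothetical: policy updated every 2 steps, reward signals every 8 steps.
step, update_steps, reward_signal_update_steps = 0, 1, 1
steps_per_update, reward_signal_steps_per_update = 2, 8
for _ in range(32):
    step += 1
    while step / update_steps > steps_per_update:
        update_steps += 1                 # stands in for one SAC policy/value update
    while step / reward_signal_update_steps > reward_signal_steps_per_update:
        reward_signal_update_steps += 1   # stands in for one reward-signal update
print(update_steps - 1, reward_signal_update_steps - 1)  # 15 3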
