浏览代码

[skip ci] updating sac

/release_3_distributed
Anupam Bhatnagar 5 年前
当前提交
f7a3c06e
共有 1 个文件被更改,包括 14 次插入5 次删除
  1. 19
      ml-agents/mlagents/trainers/sac/trainer.py

19
ml-agents/mlagents/trainers/sac/trainer.py


)
)
def should_still_train(self) -> bool:
    """
    Report whether this trainer should keep training.

    Evaluates to ``self.is_training`` when that flag is falsy. Otherwise
    the result is whether ``steps_per_update * update_steps`` is still
    less than or equal to ``get_max_steps``.
    """
    currently_training = self.is_training
    if not currently_training:
        # Mirror ``and`` short-circuiting: hand back the falsy flag itself
        # without touching the step counters.
        return currently_training
    steps_accounted_for = self.steps_per_update * self.update_steps
    return steps_accounted_for <= self.get_max_steps
super()._process_trajectory(trajectory)
# super()._process_trajectory(trajectory)
self._maybe_write_summary(self.get_step + int(self.steps_per_update))
self._maybe_save_model(self.get_step + int(self.steps_per_update))
self._increment_step(self.hyperparameters.buffer_size, self.brain_name)
last_step = trajectory.steps[-1]
agent_id = trajectory.agent_id # All the agents should have the same ID

"""
has_updated = False
self.cumulative_returns_since_policy_update.clear()
self._maybe_write_summary(self.get_step + int(self.steps_per_update))
self._maybe_save_model(self.get_step + int(self.steps_per_update))
self._increment_step(self.hyperparameters.buffer_size, self.brain_name)
n_sequences = max(
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)

self.step - self.hyperparameters.buffer_init_steps
) / self.update_steps > self.steps_per_update:
logger.debug("Updating SAC policy at step {}".format(self.step))
buffer = self.update_buffer
if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:

正在加载...
取消
保存