浏览代码

clean up types/comments

/develop/add-fire/bc
Andrew Cohen 4 年前
当前提交
8ced43ee
共有 1 个文件被更改,包括 5 次插入10 次删除
  1. 15
      ml-agents/mlagents/trainers/torch/components/bc/module.py

15
ml-agents/mlagents/trainers/torch/components/bc/module.py


from typing import Dict, Any
from typing import Dict
import numpy as np
import torch

self.has_updated = False
self.use_recurrent = self.policy.use_recurrent
self.samples_per_update = settings.samples_per_update
# self.out_dict = {
# "loss": self.model.loss,
# "update": self.model.update_batch,
# "learning_rate": self.model.annealed_learning_rate,
# }
def update(self) -> Dict[str, Any]:
def update(self) -> Dict[str, np.ndarray]:
"""
Updates model using buffer.
:param max_batches: The maximum number of batches to use per update.

return bc_loss
def _update_batch(
self, mini_batch_demo: Dict[str, Any], n_sequences: int
) -> Dict[str, Any]:
self, mini_batch_demo: Dict[str, np.ndarray], n_sequences: int
) -> Dict[str, float]:
"""
Helper function for update_batch.
"""

else:
vis_obs = []
selected_actions, all_log_probs, entropies, values, memories = self.policy.sample_actions(
selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

正在加载...
取消
保存