浏览代码

Added num repeat parameter for tasks

/active-variablespeed
Scott Jordan 4 年前
当前提交
cf2a06ad
共有 2 个文件被更改,包括 17 次插入和 2 次删除
  1. 1
      ml-agents/mlagents/trainers/settings.py
  2. 18
      ml-agents/mlagents/trainers/task_manager.py

1
ml-agents/mlagents/trainers/settings.py


parameters: Dict[str, UniformSettings]
active_learning: Optional[ActiveLearnerSettings] = None
repeat:int=8
@staticmethod
def structure(d: Mapping, t: type) -> Dict[str, "TaskParameterSettings"]:

18
ml-agents/mlagents/trainers/task_manager.py


self.param_names = {name: list(self._dict_settings[name].parameters.keys()) for name in self.behavior_names}
self._taskSamplers = {}
self.report_buffer = []
self.num_repeat = {name: 1 for name in self.behavior_names}
for behavior_name in self.behavior_names:
lows = []

lows.append(low)
highs.append(high)
task_ranges = torch.tensor([lows, highs]).float().T
num_repeat = self._dict_settings[behavior_name].num_repeat
self.num_repeat[behavior_name] = num_repeat
active_hyps = self._dict_settings[behavior_name].active_learning
if active_hyps:
self._taskSamplers[behavior_name] = ActiveLearningTaskSampler(task_ranges,

else:
self._taskSamplers[behavior_name] = lambda n: sample_random_points(task_ranges.T, n)
self.t = {name: 0.0 for name in self.behavior_names}
self.counter = {name: 0 for name in self.behavior_names}
def _make_task(self, behavior_name, tau):
task = {}

# print("sampled taus", current_time, taus)
tasks = [self._make_task(behavior_name, tau) for tau in taus]
self.report_buffer.extend(tasks)
return tasks
tasks_repeated = []
for i in range(self.num_repeat[behavior_name]):
tasks_repeated.extend(tasks)
return tasks_repeated
def update(self, behavior_name: str, task_perfs: List[Tuple[Dict, float]]
) -> Tuple[bool, bool]:

X = torch.stack(taus, dim=0)
Y = torch.tensor(perfs).float().reshape(-1, 1)
self._taskSamplers[behavior_name].update_model(X, Y, refit=True)
N = len(task_perfs)
self.counter[behavior_name] += N
if self.counter[behavior_name] >= self.num_repeat:
refit = True
self.counter[behavior_name] = 0
else:
refit = False
self._taskSamplers[behavior_name].update_model(X, Y, refit=refit)
return updated, must_reset

正在加载...
取消
保存