|
|
|
|
|
|
self.param_names = {name: list(self._dict_settings[name].parameters.keys()) for name in self.behavior_names} |
|
|
|
self._taskSamplers = {} |
|
|
|
self.report_buffer = [] |
|
|
|
self.num_repeat = {name: 1 for name in self.behavior_names} |
|
|
|
|
|
|
|
for behavior_name in self.behavior_names: |
|
|
|
lows = [] |
|
|
|
|
|
|
lows.append(low) |
|
|
|
highs.append(high) |
|
|
|
task_ranges = torch.tensor([lows, highs]).float().T |
|
|
|
num_repeat = self._dict_settings[behavior_name].num_repeat |
|
|
|
self.num_repeat[behavior_name] = num_repeat |
|
|
|
active_hyps = self._dict_settings[behavior_name].active_learning |
|
|
|
if active_hyps: |
|
|
|
self._taskSamplers[behavior_name] = ActiveLearningTaskSampler(task_ranges, |
|
|
|
|
|
|
else: |
|
|
|
self._taskSamplers[behavior_name] = lambda n: sample_random_points(task_ranges.T, n) |
|
|
|
self.t = {name: 0.0 for name in self.behavior_names} |
|
|
|
self.counter = {name: 0 for name in self.behavior_names} |
|
|
|
|
|
|
|
def _make_task(self, behavior_name, tau): |
|
|
|
task = {} |
|
|
|
|
|
|
# print("sampled taus", current_time, taus) |
|
|
|
tasks = [self._make_task(behavior_name, tau) for tau in taus] |
|
|
|
self.report_buffer.extend(tasks) |
|
|
|
return tasks |
|
|
|
tasks_repeated = [] |
|
|
|
for i in range(self.num_repeat[behavior_name]): |
|
|
|
tasks_repeated.extend(tasks) |
|
|
|
return tasks_repeated |
|
|
|
|
|
|
|
def update(self, behavior_name: str, task_perfs: List[Tuple[Dict, float]] |
|
|
|
) -> Tuple[bool, bool]: |
|
|
|
|
|
|
|
|
|
|
X = torch.stack(taus, dim=0) |
|
|
|
Y = torch.tensor(perfs).float().reshape(-1, 1) |
|
|
|
self._taskSamplers[behavior_name].update_model(X, Y, refit=True) |
|
|
|
N = len(task_perfs) |
|
|
|
self.counter[behavior_name] += N |
|
|
|
if self.counter[behavior_name] >= self.num_repeat: |
|
|
|
refit = True |
|
|
|
self.counter[behavior_name] = 0 |
|
|
|
else: |
|
|
|
refit = False |
|
|
|
self._taskSamplers[behavior_name].update_model(X, Y, refit=refit) |
|
|
|
|
|
|
|
return updated, must_reset |
|
|
|
|
|
|
|