浏览代码

Added Batch setting to active learning

/active-variablespeed
Scott Jordan 4 年前
当前提交
cab9d77e
共有 3 个文件被更改,包括 10 次插入和 5 次删除
  1. 7
      ml-agents/mlagents/trainers/active_learning.py
  2. 1
      ml-agents/mlagents/trainers/settings.py
  3. 7
      ml-agents/mlagents/trainers/task_manager.py

7
ml-agents/mlagents/trainers/active_learning.py


return MultivariateNormal(mean_x, covar_x)
class ActiveLearningTaskSampler(object):
def __init__(self,ranges, warmup_steps=30, capacity=600, num_mc=500, beta=1.96, raw_samples=128, num_restarts=1):
def __init__(self,ranges, warmup_steps:int=30, capacity:int=600, num_mc:int=500, beta:float=1.96, raw_samples:int=128, num_restarts:int=1, num_batch:int=16):
self.ranges = ranges
self.warmup_steps = warmup_steps
self.capacity = capacity

self.num_batch = num_batch
self.num_restarts = num_restarts
self.xdim = ranges.shape[0] + 1
self.model = None

self.model.set_train_data(self.X, self.Y)
# self.model = self.model.condition_on_observations(new_X, new_Y) # TODO: might be faster than setting the data; needs testing
def get_design_points(self, num_points:int=1, time=None):
def get_design_points(self, num_points:int=1, time=None, get_batch=True):
if get_batch:
num_points = min(num_points, self.num_batch)
if not self.model or time < self.warmup_steps:
return sample_random_points(self.bounds, num_points)

1
ml-agents/mlagents/trainers/settings.py


beta:float=1.96
raw_samples:int=128
num_restarts:int=1
num_batch:int=16
@attr.s(auto_attribs=True)
class TaskParameterSettings:

7
ml-agents/mlagents/trainers/task_manager.py


from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.active_learning import ActiveLearningTaskSampler
from mlagents.trainers.active_learning import ActiveLearningTaskSampler, sample_random_points
logger = get_logger(__name__)

self._taskSamplers[behavior_name] = ActiveLearningTaskSampler(task_ranges,
warmup_steps=active_hyps.warmup_steps, capacity=active_hyps.capacity,
num_mc=active_hyps.num_mc, beta=active_hyps.beta,
raw_samples=active_hyps.raw_samples, num_restarts=active_hyps.num_restarts
raw_samples=active_hyps.raw_samples, num_restarts=active_hyps.num_restarts,
num_batch=active_hyps.num_batch
self._taskSamplers[behavior_name] = lambda n: uniform_sample(task_ranges, n)
self._taskSamplers[behavior_name] = lambda n: sample_random_points(task_ranges.T, n)
self.t = {name: 0.0 for name in self.behavior_names}
def _make_task(self, behavior_name, tau):

正在加载...
取消
保存