Scott Jordan
4 years ago
Current commit
d695c044
7 files changed, 473 insertions and 4 deletions
- Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs (6 lines changed)
- ml-agents/mlagents/trainers/agent_processor.py (31 lines changed)
- ml-agents/mlagents/trainers/settings.py (25 lines changed)
- ml-agents/mlagents/trainers/subprocess_env_manager.py (23 lines changed)
- ml-agents/mlagents/trainers/trainer_controller.py (8 lines changed)
- ml-agents/mlagents/trainers/active_learning.py (224 lines changed)
- ml-agents/mlagents/trainers/active_learning_manager.py (160 lines changed)
ml-agents/mlagents/trainers/active_learning.py

import torch
from torch import Tensor

from botorch import settings
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.objective import ScalarizedObjective, IdentityMCObjective
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model
from botorch.models import SingleTaskGP
from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from botorch.fit import fit_gpytorch_model
from botorch.optim import optimize_acqf_cyclic, optimize_acqf
from botorch.optim.initializers import initialize_q_batch_nonneg

from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.distributions import MultivariateNormal
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import (
    ScaleKernel, RBFKernel, Kernel, ProductKernel, AdditiveKernel,
    GridInterpolationKernel, AdditiveStructureKernel, ProductStructureKernel,
)
from gpytorch.utils.grid import choose_grid_size

from typing import Optional, Union


class qEISP(MCAcquisitionFunction):

    def __init__(
        self,
        model: Model,
        beta: Union[float, Tensor],
        mc_points: Tensor,
        sampler: Optional[MCSampler] = None,
        objective: Optional[ScalarizedObjective] = None,
        X_pending: Optional[Tensor] = None,
        maximize: bool = True,
    ) -> None:
        r"""q-Expected Improvement of Skill Performance.

        Args:
            model: A fitted model.
            beta: Value that trades off between the upper confidence bound and the
                mean of the fantasized performance.
            mc_points: A `batch_shape x N x d` tensor of points to use for
                MC-integrating the posterior variance. Usually, these are qMC
                samples on the whole design space, but biased sampling directly
                allows weighted integration of the posterior variance.
            sampler: The sampler used for drawing fantasy samples. In the basic setting
                of a standard GP (default) this is a dummy, since the variance of the
                model after conditioning does not actually depend on the sampled values.
            objective: A ScalarizedObjective. Required for multi-output models.
            X_pending: A `n' x d`-dim Tensor of `n'` design points that have been
                submitted for function evaluation but have not yet been evaluated.
            maximize: If True, uses the upper confidence bound of performance scaled
                by beta; otherwise uses the lower confidence bound.

        Parts of this docstring, and the comments below, are adapted from BoTorch.
        """
        super().__init__(model=model, objective=objective)
        if sampler is None:
            # If no sampler is provided, we use the following dummy sampler for the
            # fantasize() method in forward. IMPORTANT: This assumes that the posterior
            # variance does not depend on the samples y (only on x), which is true for
            # standard GP models, but not in general (e.g. for other likelihoods or
            # heteroskedastic GPs using a separate noise model fit on data).
            sampler = SobolQMCNormalSampler(
                num_samples=1, resample=False, collapse_batch_dims=True
            )
        if not torch.is_tensor(beta):
            beta = torch.tensor(beta)
        self.register_buffer("beta", beta)
        self.sampler = sampler
        self.X_pending = X_pending
        self.register_buffer("mc_points", mc_points)
        self.maximize = maximize

    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        self.beta = self.beta.to(X)
        with settings.propagate_grads(True):
            posterior = self.model.posterior(X=X)
            batch_shape = X.shape[:-2]
            mean = posterior.mean.view(*batch_shape, X.shape[-2], -1)
            variance = posterior.variance.view(*batch_shape, X.shape[-2], -1)
            delta = self.beta.expand_as(mean) * variance.sqrt()

            # Optimistic (or pessimistic) fantasy of the performance at X.
            if self.maximize:
                Yhat = mean + delta
            else:
                Yhat = mean - delta

            bdims = tuple(1 for _ in X.shape[:-2])
            if self.model.num_outputs > 1:
                # We use q=1 here b/c ScalarizedObjective currently does not fully exploit
                # lazy tensor operations and thus may be slow / overly memory-hungry.
                # TODO (T52818288): Properly use lazy tensors in scalarize_posterior
                mc_points = self.mc_points.view(-1, *bdims, 1, X.size(-1))
            else:
                # While we only need marginal variances, we can evaluate for q>1
                # b/c for GPyTorch models lazy evaluation can make this quite a bit
                # faster than evaluating in t-batch mode with q-batch size of 1.
                mc_points = self.mc_points.view(*bdims, -1, X.size(-1))

            Yhat = Yhat.view(*batch_shape, X.shape[-2], -1)

            # Condition the model on the fantasized observations at X.
            fantasy_model = self.model.condition_on_observations(X=X, Y=Yhat)

            posterior1 = self.model.posterior(mc_points)
            posterior2 = fantasy_model.posterior(mc_points)

            # Transform with the (scalarized) objective.
            posterior1 = self.objective(posterior1.mean)
            posterior2 = self.objective(posterior2.mean)

            # Expected improvement of skill performance, averaged over the MC points.
            improvement = posterior2 - posterior1

        return improvement.mean(dim=-1)
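
# A minimal usage sketch (assumptions: a toy 2-D design space on the unit square,
# a SingleTaskGP surrogate, beta=1.96, and small optimizer settings). It shows
# how qEISP can be constructed and maximized with optimize_acqf.
def _example_qeisp_usage():
    train_X = torch.rand(20, 2)
    train_Y = train_X.sum(dim=-1, keepdim=True) + 0.1 * torch.randn(20, 1)  # toy signal
    gp = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_model(ExactMarginalLogLikelihood(gp.likelihood, gp))

    mc_points = torch.rand(256, 2)  # MC integration points over the design space
    acqf = qEISP(gp, beta=1.96, mc_points=mc_points)

    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
    candidates, _ = optimize_acqf(
        acq_function=acqf, bounds=bounds, q=1, num_restarts=2, raw_samples=64
    )
    return candidates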


class StandardActiveLearningGP(ExactGP, GPyTorchModel):

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y, bounds=None):
        # squeeze output dim before passing train_Y to ExactGP
        super(StandardActiveLearningGP, self).__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        xdims = train_X.shape[-1]
        # Spatial kernel over the task parameters; time kernel over the last column.
        self.Kspatial = ScaleKernel(RBFKernel(active_dims=torch.tensor(list(range(xdims - 1)))))
        self.Ktime = ScaleKernel(RBFKernel(active_dims=torch.tensor([xdims - 1])))
        # Kspatial = ScaleKernel(RBFKernel())
        # Ktime = ScaleKernel(RBFKernel())

        # self.covar_module = ScaleKernel(RBFKernel())  # AdditiveKernel(Kspatial, ProductKernel(Kspatial, Ktime))
        # Additive structure: a purely spatial term plus a spatial-by-time interaction term.
        self.covar_module = AdditiveKernel(self.Kspatial, ProductKernel(self.Kspatial, self.Ktime))
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
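
# A minimal fitting sketch (assumptions: two task dimensions plus a trailing
# "time" column in [0, 1], and a toy noisy performance signal). It shows how
# StandardActiveLearningGP is constructed and fit with the exact marginal
# log-likelihood.
def _example_fit_active_learning_gp():
    X = torch.cat([torch.rand(30, 2), torch.linspace(0, 1, 30).unsqueeze(-1)], dim=-1)
    Y = X[:, :1] + 0.1 * torch.randn(30, 1)  # toy performance signal, shape (30, 1)
    model = StandardActiveLearningGP(X, Y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)
    return model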


class ActiveLearningTaskSampler(object):
    def __init__(self, ranges):
        self.ranges = ranges
        self.xdim = ranges.shape[0] + 1
        self.model = None
        self.mll = None
        self.Xdata = None
        self.Ydata = None

        # Bounds over the task parameters, with an extra trailing "time" dimension
        # in [0, 1]; stored as a (2, xdim) tensor of lows/highs.
        self.bounds = torch.as_tensor(ranges).float()
        self.bounds = torch.cat([self.bounds, torch.tensor([[0.0, 1.0]])]).T

    def update_model(self, new_X, new_Y, refit=False):
        if self.model is not None:
            new_X = new_X.to(self.X)
            new_Y = new_Y.to(self.X)
            self.X = torch.cat([self.X, new_X])
            self.Y = torch.cat([self.Y, new_Y])
            state_dict = self.model.state_dict()
        else:
            self.X = new_X.float()
            self.Y = new_Y.float()
            state_dict = None

        # Keep only the most recent T observations.
        T = 12 * 50
        if self.X.shape[0] >= T:
            self.X = self.X[-T:, :]
            self.Y = self.Y[-T:, :]

        if refit:
            model = StandardActiveLearningGP(self.X, self.Y, bounds=self.bounds)
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            self.model = model
            self.mll = mll
            if state_dict is not None:
                self.model.load_state_dict(state_dict)
            fit_gpytorch_model(mll)
        else:
            # strict=False allows the training set size to change between updates.
            self.model.set_train_data(self.X, self.Y.squeeze(-1), strict=False)
            # self.model = self.model.condition_on_observations(new_X, new_Y)

    def get_design_points(self, num_points: int = 1, time=None):
        # Fall back to random sampling until a model exists and training has
        # progressed far enough (time >= 30).
        if self.model is None or (time is not None and time < 30):
            return sample_random_points(self.bounds, num_points)

        if time is None:
            time = self.X[:, -1].max() + 1

        # Clone so the stored bounds are not modified in place; fix the time
        # dimension to the current time.
        bounds = self.bounds.clone()
        bounds[:, -1] = time
        num_mc = 500
        mc_points = torch.rand(num_mc, bounds.size(1), device=self.X.device, dtype=self.X.dtype)
        mc_points = bounds[0] + (bounds[1] - bounds[0]) * mc_points

        qeisp = qEISP(self.model, mc_points=mc_points, beta=1.96)
        try:
            candidates, acq_value = optimize_acqf(
                acq_function=qeisp,
                bounds=bounds,
                raw_samples=128,
                q=num_points,
                num_restarts=1,
                return_best_only=True,
            )
            return candidates
        except Exception:
            # If acquisition optimization fails, fall back to random sampling.
            return sample_random_points(self.bounds, num_points)
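
# A minimal end-to-end sketch (assumptions: a single task parameter ranging over
# [0, 5], a trailing time column, and a toy noisy quadratic performance signal).
# It shows the intended update_model / get_design_points loop.
def _example_task_sampler_usage():
    ranges = torch.tensor([[0.0, 5.0]])
    task_sampler = ActiveLearningTaskSampler(ranges)

    # Observations: task parameter(s) plus a trailing time column, and an (N, 1)
    # performance tensor.
    X = torch.cat(
        [torch.rand(40, 1) * 5.0, torch.linspace(0, 40, 40).unsqueeze(-1)], dim=-1
    )
    Y = -((X[:, :1] - 2.5) ** 2) + 0.1 * torch.randn(40, 1)
    task_sampler.update_model(X, Y, refit=True)

    # Request the next batch of task parameters for the current time step.
    return task_sampler.get_design_points(num_points=4, time=41.0)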


def sample_random_points(bounds, num_points):
    # Uniformly sample points inside the (2, d) lows/highs bounds tensor.
    points = torch.rand(num_points, bounds.size(1), device=bounds.device, dtype=bounds.dtype)
    points = bounds[0] + (bounds[1] - bounds[0]) * points
    return points


ml-agents/mlagents/trainers/active_learning_manager.py

from typing import Dict, List, Tuple, Optional
from mlagents.trainers.settings import (
    EnvironmentParameterSettings,
    ParameterRandomizationSettings,
    AgentParameterSettings,  # assumed to be added by this commit's changes to settings.py
)
from collections import defaultdict
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType

from mlagents_envs.logging_util import get_logger

from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.active_learning import ActiveLearningTaskSampler

logger = get_logger(__name__)


class ActiveLearningTaskManager(EnvironmentParameterManager):
    def __init__(
        self,
        settings: Optional[Dict[str, AgentParameterSettings]] = None,
        run_seed: int = -1,
        restore: bool = False,
    ):
        """
        The ActiveLearningTaskManager manages all the environment parameters of a
        training session. It determines when parameters should change and gives
        access to the current sampler of each parameter.
        :param settings: A dictionary from environment parameter to
        EnvironmentParameterSettings.
        :param run_seed: When the seed is not provided for an environment parameter,
        this seed will be used instead.
        :param restore: If true, the ActiveLearningTaskManager will use the
        GlobalTrainingStatus to try and reload the lesson status of each environment
        parameter.
        """
        if settings is None:
            settings = {}
        self._dict_settings = settings
        lows = []
        highs = []
        for parameter_name in self._dict_settings.keys():
            # NOTE: the body of this loop was truncated in the original diff; it
            # presumably collects each parameter's lower/upper bounds into `lows`
            # and `highs` for the task sampler constructed below.
            pass

        self._smoothed_values: Dict[str, float] = defaultdict(float)

        for key in self._dict_settings.keys():
            self._smoothed_values[key] = 0.0
        # Update the seeds of the samplers
        self._set_sampler_seeds(run_seed)

        # Presumably derived from `lows`/`highs`; left empty in the original diff.
        task_ranges = []
        self._taskSampler = ActiveLearningTaskSampler(task_ranges)

    def _set_sampler_seeds(self, seed):
        """
        Sets the seeds of the samplers (if no seed was already present), using the
        provided seed plus an incrementing offset.
        """
        offset = 0
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                if lesson.value.seed == -1:
                    lesson.value.seed = seed + offset
                    offset += 1

    def get_minimum_reward_buffer_size(self, behavior_name: str) -> int:
        """
        Calculates the minimum size of the reward buffer a behavior must use. This
        method uses the 'min_lesson_length' sampler_parameter to determine this value.
        :param behavior_name: The name of the behavior the minimum reward buffer
        size corresponds to.
        """
        result = 1
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                if lesson.completion_criteria is not None:
                    if lesson.completion_criteria.behavior == behavior_name:
                        result = max(
                            result, lesson.completion_criteria.min_lesson_length
                        )
        return result

    def get_current_samplers(self) -> Dict[str, ParameterRandomizationSettings]:
        """
        Creates a dictionary from environment parameter name to their corresponding
        ParameterRandomizationSettings. If curriculum is used, the
        ParameterRandomizationSettings corresponds to the sampler of the current lesson.
        """
        samplers: Dict[str, ParameterRandomizationSettings] = {}
        for param_name, settings in self._dict_settings.items():
            lesson_num = GlobalTrainingStatus.get_parameter_state(
                param_name, StatusType.LESSON_NUM
            )
            lesson = settings.curriculum[lesson_num]
            samplers[param_name] = lesson.value
        return samplers

    def get_current_lesson_number(self) -> Dict[str, int]:
        """
        Creates a dictionary from environment parameter to the current lesson number.
        If not using curriculum, this number is always 0 for that environment parameter.
        """
        result: Dict[str, int] = {}
        for parameter_name in self._dict_settings.keys():
            result[parameter_name] = GlobalTrainingStatus.get_parameter_state(
                parameter_name, StatusType.LESSON_NUM
            )
        return result

    def update_lessons(
        self,
        trainer_steps: Dict[str, int],
        trainer_max_steps: Dict[str, int],
        trainer_reward_buffer: Dict[str, List[float]],
    ) -> Tuple[bool, bool]:
        """
        Given progress metrics, calculates if at least one environment parameter is
        in a new lesson and if at least one environment parameter requires the env
        to reset.
        :param trainer_steps: A dictionary from behavior_name to the number of training
        steps this behavior's trainer has performed.
        :param trainer_max_steps: A dictionary from behavior_name to the maximum number
        of training steps allowed for this behavior's trainer.
        :param trainer_reward_buffer: A dictionary from behavior_name to the list of
        the most recent episode returns for this behavior's trainer.
        :returns: A tuple of two booleans: (True if any lesson has changed, True if
        the environment needs to reset)
        """
        must_reset = False
        updated = False
        for param_name, settings in self._dict_settings.items():
            lesson_num = GlobalTrainingStatus.get_parameter_state(
                param_name, StatusType.LESSON_NUM
            )
            lesson = settings.curriculum[lesson_num]
            if (
                lesson.completion_criteria is not None
                and len(settings.curriculum) > lesson_num + 1
            ):
                behavior_to_consider = lesson.completion_criteria.behavior
                if behavior_to_consider in trainer_steps:
                    must_increment, new_smoothing = lesson.completion_criteria.need_increment(
                        float(trainer_steps[behavior_to_consider])
                        / float(trainer_max_steps[behavior_to_consider]),
                        trainer_reward_buffer[behavior_to_consider],
                        self._smoothed_values[param_name],
                    )
                    self._smoothed_values[param_name] = new_smoothing
                    if must_increment:
                        GlobalTrainingStatus.set_parameter_state(
                            param_name, StatusType.LESSON_NUM, lesson_num + 1
                        )
                        new_lesson_name = settings.curriculum[lesson_num + 1].name
                        logger.info(
                            f"Parameter '{param_name}' has changed. Now in lesson '{new_lesson_name}'"
                        )
                        updated = True
                        if lesson.completion_criteria.require_reset:
                            must_reset = True
        return updated, must_reset
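
# A minimal calling sketch (assumptions: a behavior named "Walker" and toy progress
# numbers; `manager` is an already-constructed ActiveLearningTaskManager). It shows
# the update_lessons contract used by the trainer controller.
def _example_update_lessons(manager: ActiveLearningTaskManager) -> Tuple[bool, bool]:
    updated, must_reset = manager.update_lessons(
        trainer_steps={"Walker": 10000},
        trainer_max_steps={"Walker": 1000000},
        trainer_reward_buffer={"Walker": [250.0, 300.0, 275.0]},
    )
    return updated, must_reset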