
Added histogram recorder, fixed active learning bug

Added a histogram recorder for task samples. Fixed a bug that prevented active learning from being used.
/active-variablespeed
Scott Jordan, 4 years ago
Commit 87969325
5 files changed: 148 insertions, 22 deletions
  1. ml-agents/mlagents/trainers/active_learning.py (11 changes)
  2. ml-agents/mlagents/trainers/agent_processor.py (5 changes)
  3. ml-agents/mlagents/trainers/stats.py (137 changes)
  4. ml-agents/mlagents/trainers/task_manager.py (7 changes)
  5. ml-agents/mlagents/trainers/trainer_controller.py (10 changes)

ml-agents/mlagents/trainers/active_learning.py (11 changes)


        self.xdim = ranges.shape[0] + 1
        self.model = None
        self.mll = None
        self.Xdata = None
        self.Ydata = None
        self.X = None
        self.Y = None
        self.bounds = torch.tensor(ranges)
        self.bounds = torch.cat([self.bounds, torch.tensor([[0.0, 1.0]])]).T

    def update_model(self, new_X, new_Y, refit=False):
        if self.X is not None:
            # Warm-start from the previous hyperparameters if a model already exists.
            if self.model is not None:
                state_dict = self.model.state_dict()
            else:
                state_dict = None
        else:
            self.X = new_X.float()
            self.Y = new_Y.float()
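For context, here is a minimal sketch of how the captured state_dict can be used to rebuild and warm-start a surrogate model on the accumulated data. It assumes the BoTorch/GPyTorch stack (SingleTaskGP, ExactMarginalLogLikelihood, fit_gpytorch_model) that this module appears to build on; the helper name rebuild_model is hypothetical and not part of the diff.

import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood

def rebuild_model(X, Y, state_dict=None, refit=False):
    # Sketch only: construct a fresh GP on all data seen so far.
    # X is an (n, d) tensor, Y an (n, 1) tensor of task performances.
    model = SingleTaskGP(X, Y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    if state_dict is not None:
        # Reload kernel/likelihood hyperparameters from the previous model.
        model.load_state_dict(state_dict)
    if refit or state_dict is None:
        # Refit the marginal log likelihood when requested or when starting cold.
        fit_gpytorch_model(mll)
    return model, mll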

ml-agents/mlagents/trainers/agent_processor.py (5 changes)


        if global_id in self.last_step_result:  # Don't store if agent just reset
            self.last_take_action_outputs[global_id] = take_action_outputs
        if global_id not in self.episode_tasks.keys():
            # print("gid: {0} not in episodes tasks - prev actions".format(global_id))
            self._assign_task(worker_id, global_id, local_id)
        # Iterate over all the terminal steps

                terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
            )
            if global_id not in self.episode_tasks.keys():
                # print("gid: {0} not in episodes tasks - terminal".format(global_id))
                self._assign_task(worker_id, global_id, local_id)
        # Iterate over all the decision steps
        for ongoing_step in decision_steps.values():

                ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
            )
            if global_id not in self.episode_tasks.keys():
                # print("gid: {0} not in episodes tasks - decision".format(global_id))
                self._assign_task(worker_id, global_id, local_id)
        for _gid in action_global_agent_ids:

            task = self.task_queue.pop(0)
            self.episode_tasks[global_id] = task
            self.set_task_params_fn(worker_id, local_id, task)
            # print("assigned gid {0} to task {1}".format(global_id, task))
            # print("gid {0} has not already requested a task, requesting new task".format(global_id))
            self.tasks_needed[global_id] = (worker_id, local_id)
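The last fragments above come from different branches of the task-assignment helper. A hedged reconstruction of how they likely fit together (the surrounding if/else is an assumption; it is not shown in the hunk):

    def _assign_task(self, worker_id, global_id, local_id):
        # Hand out a queued task if one is available; otherwise remember that
        # this agent is waiting so the trainer controller can generate one.
        if len(self.task_queue) > 0:
            task = self.task_queue.pop(0)
            self.episode_tasks[global_id] = task
            self.set_task_params_fn(worker_id, local_id, task)
        else:
            self.tasks_needed[global_id] = (worker_id, local_id)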

ml-agents/mlagents/trainers/stats.py (137 changes)


from typing import List, Dict, NamedTuple, Any, Optional
import numpy as np
import abc
import csv
import os
import time
from threading import RLock

class StatsPropertyType(Enum):
    HYPERPARAMETERS = "hyperparameters"
    SELF_PLAY = "selfplay"
    SALIENCY = "saliency"


class StatsWriter(abc.ABC):

"""
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step.
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
:param category: The category that the property belongs to.
:param type: The type of property.
:param value: The property itself.

class GaugeWriter(StatsWriter):
    """
    Write all stats that we receive to the timer gauges, so we can track them offline easily
    """

    @staticmethod

    ) -> None:
        is_training = "Not Training."
        if "Is Training" in values:
            stats_summary = values["Is Training"]
        elapsed_time = time.time() - self.training_start_time
        log_info: List[str] = [category]
        log_info.append(f"Step: {step}")
        log_info.append(f"Time Elapsed: {elapsed_time:0.3f} s")
        log_info.append(f"Mean Reward: {stats_summary.mean:0.3f}")
        log_info.append(f"Std of Reward: {stats_summary.std:0.3f}")
        log_info.append(is_training)

        log_info.append(f"ELO: {elo_stats.mean:0.3f}")

        log_info.append("No episode was completed since last summary")
        log_info.append(is_training)
        logger.info(". ".join(log_info))
    def add_property(
        self, category: str, property_type: StatsPropertyType, value: Any

        self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
        self.base_dir: str = base_dir
        self._clear_past_data = clear_past_data
        # Counter used as the x-axis index for saliency histograms.
        self.trajectories = 0
    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int

            if summary is not None:
                self.summary_writers[category].add_summary(summary, 0)
        elif property_type == StatsPropertyType.SALIENCY:
            self._maybe_create_summary_writer(category)

            # adapted from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
            def create_summary(label, values):
                # Build a TensorBoard histogram summary from a list of sampled values.
                values = np.array(values)
                counts, bin_edges = np.histogram(values, bins=len(values))
                hist = tf.HistogramProto()
                hist.min = float(np.min(values))
                hist.max = float(np.max(values))
                hist.num = int(np.prod(values.shape))
                hist.sum = float(np.sum(values))
                hist.sum_squares = float(np.sum(np.square(values)))
                # np.histogram returns one more edge than counts; drop the leftmost
                # edge so bucket_limit and bucket line up one-to-one.
                bin_edges = bin_edges[1:]
                for edge in bin_edges:
                    hist.bucket_limit.append(edge)
                for c in counts:
                    hist.bucket.append(c)
                return tf.Summary.Value(tag=label, histo=hist)

            if isinstance(value, dict):
                svals = [create_summary(k, v) for k, v in value.items()]
            else:
                svals = [create_summary("Saliency", value)]
            # Create and write the Summary, indexed by trajectory count rather than step
            summary = tf.Summary(value=svals)
            self.summary_writers[category].add_summary(summary, self.trajectories)
            self.summary_writers[category].flush()
            self.trajectories += 1
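Hypothetical usage of the saliency path above: each key of the dict becomes its own histogram tag in TensorBoard, indexed by the writer's trajectory counter rather than the training step. The category and tag names here are made up for illustration; the constructor arguments match the fields shown in the hunk above.

writer = TensorboardWriter(base_dir="results/run-0", clear_past_data=True)
saliency = {
    "task/target_speed": [0.2, 0.4, 0.4, 0.9],
    "task/body_height": [1.1, 0.8, 1.3, 1.0],
}
# One histogram per task parameter is written for the "Walker" category.
writer.add_property("Walker", StatsPropertyType.SALIENCY, saliency)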
    def _dict_to_tensorboard(
        self, name: str, input_dict: Dict[str, Any]
    ) -> Optional[bytes]:

        return None
class CSVWriter(StatsWriter):
    def __init__(self, base_dir: str, required_fields: List[str] = None):
        """
        A StatsWriter that writes to a CSV file.
        :param base_dir: The directory within which to place the CSV file, which will be {base_dir}/{category}.csv.
        :param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for
        them.
        """
        # We need to keep track of the fields in the CSV, as all rows need the same fields.
        self.csv_fields: Dict[str, List[str]] = {}
        self.required_fields = required_fields if required_fields else []
        self.base_dir: str = base_dir

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        if self._maybe_create_csv_file(category, list(values.keys())):
            row = [str(step)]
            # Only record the stats that showed up in the first valid row
            for key in self.csv_fields[category]:
                _val = values.get(key, None)
                row.append(str(_val.mean) if _val else "None")
            with open(self._get_filepath(category), "a") as file:
                writer = csv.writer(file)
                writer.writerow(row)

    def _maybe_create_csv_file(self, category: str, keys: List[str]) -> bool:
        """
        If no CSV file exists and the keys have the required values,
        make the CSV file and write the title row.
        Returns True if there is now (or already is) a valid CSV file.
        """
        if category not in self.csv_fields:
            summary_dir = self.base_dir
            os.makedirs(summary_dir, exist_ok=True)
            # Only store if the row contains the required fields
            if all(item in keys for item in self.required_fields):
                self.csv_fields[category] = keys
                with open(self._get_filepath(category), "w") as file:
                    title_row = ["Steps"]
                    title_row.extend(keys)
                    writer = csv.writer(file)
                    writer.writerow(title_row)
                return True
            return False
        return True

    def _get_filepath(self, category: str) -> str:
        file_dir = os.path.join(self.base_dir, category + ".csv")
        return file_dir
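Hypothetical usage of CSVWriter, assuming StatsSummary is the (mean, std, num) NamedTuple used elsewhere in this module: nothing is written until the required field appears, and later rows keep the column order of the first valid row.

csv_writer = CSVWriter("results/run-0", required_fields=["Environment/Cumulative Reward"])
# Writes results/run-0/Walker.csv with a "Steps" column plus one column per stat key.
csv_writer.write_stats(
    "Walker",
    {"Environment/Cumulative Reward": StatsSummary(mean=1.5, std=0.3, num=100)},
    step=5000,
)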
class StatsReporter:
    writers: List[StatsWriter] = []
    stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))

        """
        Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
        a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
        with all types of properties. For instance, a TB writer doesn't need a max step, nor should
        we write hyperparameters to the CSV.
        :param key: The type of property.
        :param value: The property itself.
        """

ml-agents/mlagents/trainers/task_manager.py (7 changes)


        self.behavior_names = list(self._dict_settings.keys())
        self.param_names = {
            name: list(self._dict_settings[name].parameters.keys())
            for name in self.behavior_names
        }
        self._taskSamplers = {}
        self.report_buffer = []
        for behavior_name in self.behavior_names:
            lows = []

            taus = (
                self._taskSamplers[behavior_name]
                .get_design_points(num_points=num_samples, time=current_time)
                .data.numpy()
                .tolist()
            )
        else:
            taus = self._taskSamplers[behavior_name](num_samples).tolist()
        # print("sampled taus", current_time, taus)
        self.report_buffer.extend(tasks)
        return tasks

    def update(self, behavior_name: str, task_perfs: List[Tuple[Dict, float]]

            taus.append(tau)
        X = torch.stack(taus, dim=0)
        # Reshape performances into a column vector to match the model's (n, 1) targets.
        Y = torch.tensor(perfs).float().reshape(-1, 1)


def uniform_sample(ranges, num_samples):
    low = ranges[:, 0]
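The uniform_sample fallback is cut off in the hunk. A hedged completion of it, assuming ranges is a NumPy array of shape (d, 2) holding per-dimension [low, high] bounds (the high and return lines are assumptions):

import numpy as np

def uniform_sample(ranges, num_samples):
    # Draw num_samples task vectors uniformly inside the per-dimension bounds.
    low = ranges[:, 0]
    high = ranges[:, 1]
    return np.random.uniform(low, high, size=(num_samples, ranges.shape[0]))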

ml-agents/mlagents/trainers/trainer_controller.py (10 changes)


from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
from mlagents.trainers.stats import StatsPropertyType


class TrainerController:
    def __init__(

                self.task_manager.update(behavior_name, task_perf)
            K = manager.get_num_tasks_needed()
            if K > 0:
                # print("num tasks needed: ", K)

            if len(self.task_manager.report_buffer) >= 16:
                # Flush the buffered task samples as one saliency histogram per task parameter.
                d = defaultdict(list)
                for task in self.task_manager.report_buffer:
                    for k, v in task.items():
                        d[k].append(v)
                manager.stats_reporter.add_property(StatsPropertyType.SALIENCY, d)
                self.task_manager.report_buffer = []
            for trainer in self.trainers.values():
                if not trainer.threaded:
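A small illustration of the dict-of-lists reshaping in the hunk above: a buffer of per-episode task dicts becomes one list of sampled values per task parameter, which is exactly the shape the SALIENCY histogram writer expects. The parameter names are made up for illustration.

from collections import defaultdict

report_buffer = [{"speed": 0.3, "height": 1.2}, {"speed": 0.7, "height": 0.9}]
d = defaultdict(list)
for task in report_buffer:
    for k, v in task.items():
        d[k].append(v)
# d == {"speed": [0.3, 0.7], "height": [1.2, 0.9]}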
