GitHub · 5 years ago
Current commit 2fd305e7
32 files changed, with 1,261 insertions and 837 deletions
Changed files (lines changed per file):

  3  ml-agents/mlagents/trainers/action_info.py
198  ml-agents/mlagents/trainers/agent_processor.py
 29  ml-agents/mlagents/trainers/buffer.py
  1  ml-agents/mlagents/trainers/components/bc/module.py
  2  ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
 13  ml-agents/mlagents/trainers/curriculum.py
 32  ml-agents/mlagents/trainers/demo_loader.py
  6  ml-agents/mlagents/trainers/learn.py
 26  ml-agents/mlagents/trainers/models.py
 45  ml-agents/mlagents/trainers/ppo/policy.py
212  ml-agents/mlagents/trainers/ppo/trainer.py
246  ml-agents/mlagents/trainers/rl_trainer.py
  2  ml-agents/mlagents/trainers/sac/policy.py
140  ml-agents/mlagents/trainers/sac/trainer.py
 45  ml-agents/mlagents/trainers/tests/mock_brain.py
 95  ml-agents/mlagents/trainers/tests/test_buffer.py
  2  ml-agents/mlagents/trainers/tests/test_meta_curriculum.py
174  ml-agents/mlagents/trainers/tests/test_ppo.py
 48  ml-agents/mlagents/trainers/tests/test_rl_trainer.py
 38  ml-agents/mlagents/trainers/tests/test_sac.py
  3  ml-agents/mlagents/trainers/tests/test_simple_rl.py
 22  ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  4  ml-agents/mlagents/trainers/tests/test_trainer_util.py
 62  ml-agents/mlagents/trainers/tf_policy.py
108  ml-agents/mlagents/trainers/trainer.py
 39  ml-agents/mlagents/trainers/trainer_controller.py
  4  ml-agents/mlagents/trainers/trainer_util.py
118  ml-agents/mlagents/trainers/stats.py
 63  ml-agents/mlagents/trainers/tests/test_agent_processor.py
 80  ml-agents/mlagents/trainers/tests/test_stats.py
110  ml-agents/mlagents/trainers/tests/test_trajectory.py
128  ml-agents/mlagents/trainers/trajectory.py
ml-agents/mlagents/trainers/agent_processor.py

Removed (the old ProcessingBuffer class):

from typing import List, Union

from mlagents.trainers.buffer import AgentBuffer, BufferException


class ProcessingBuffer(dict):
    """
    ProcessingBuffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
    Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
    """

    def __str__(self):
        return "local_buffers :\n{0}".format(
            "\n".join(["\tagent {0} :{1}".format(k, str(self[k])) for k in self.keys()])
        )

    def __getitem__(self, key):
        if key not in self.keys():
            self[key] = AgentBuffer()
        return super().__getitem__(key)

    def reset_local_buffers(self) -> None:
        """
        Resets all the local AgentBuffers.
        """
        for buf in self.values():
            buf.reset_agent()

    def append_to_update_buffer(
        self,
        update_buffer: AgentBuffer,
        agent_id: Union[int, str],
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
    ) -> None:
        """
        Appends the buffer of an agent to the update buffer.
        :param update_buffer: A reference to an AgentBuffer to append the agent's buffer to
        :param agent_id: The id of the agent whose data will be appended
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        if key_list is None:
            key_list = self[agent_id].keys()
        if not self[agent_id].check_length(key_list):
            raise BufferException(
                "The length of the fields {0} for agent {1} were not of same length".format(
                    key_list, agent_id
                )
            )
        for field_key in key_list:
            update_buffer[field_key].extend(
                self[agent_id][field_key].get_batch(
                    batch_size=batch_size, training_length=training_length
                )
            )

    def append_all_agent_batch_to_update_buffer(
        self,
        update_buffer: AgentBuffer,
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
    ) -> None:
        """
        Appends the buffer of all agents to the update buffer.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        for agent_id in self.keys():
            self.append_to_update_buffer(
                update_buffer, agent_id, key_list, batch_size, training_length
            )

Added (the new AgentProcessor class):

import sys
from typing import List, Dict
from collections import defaultdict, Counter

from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.brain import BrainInfo
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.action_info import ActionInfoOutputs
from mlagents.trainers.stats import StatsReporter


class AgentProcessor:
    """
    AgentProcessor contains a dictionary of per-agent trajectory buffers. The buffers are indexed by agent_id.
    One AgentProcessor should be created per agent group.
    """

    def __init__(
        self,
        trainer: Trainer,
        policy: TFPolicy,
        stats_reporter: StatsReporter,
        max_trajectory_length: int = sys.maxsize,
    ):
        """
        Create an AgentProcessor.
        :param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory
        when it is finished.
        :param policy: Policy instance associated with this AgentProcessor.
        :param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer.
        :param stats_reporter: The StatsReporter used to write the stats. Usually, this comes from the Trainer.
        """
        self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
        self.last_brain_info: Dict[str, BrainInfo] = {}
        self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
        # Note: this is needed until we switch to AgentExperiences as the data input type.
        # We still need some info from the policy (memories, previous actions)
        # that really should be gathered by the env-manager.
        self.policy = policy
        self.episode_steps: Counter = Counter()
        self.episode_rewards: Dict[str, float] = defaultdict(float)
        self.stats_reporter = stats_reporter
        self.trainer = trainer
        self.max_trajectory_length = max_trajectory_length

    def add_experiences(
        self,
        curr_info: BrainInfo,
        next_info: BrainInfo,
        take_action_outputs: ActionInfoOutputs,
    ) -> None:
        """
        Adds experiences to each agent's experience history.
        :param curr_info: current BrainInfo.
        :param next_info: next BrainInfo.
        :param take_action_outputs: The outputs of the Policy's get_action method.
        """
        if take_action_outputs:
            self.stats_reporter.add_stat(
                "Policy/Entropy", take_action_outputs["entropy"].mean()
            )
            self.stats_reporter.add_stat(
                "Policy/Learning Rate", take_action_outputs["learning_rate"]
            )

        for agent_id in curr_info.agents:
            self.last_brain_info[agent_id] = curr_info
            self.last_take_action_outputs[agent_id] = take_action_outputs

        # Store the environment reward
        tmp_environment_reward = next_info.rewards

        for next_idx, agent_id in enumerate(next_info.agents):
            stored_info = self.last_brain_info.get(agent_id, None)
            if stored_info is not None:
                stored_take_action_outputs = self.last_take_action_outputs[agent_id]
                idx = stored_info.agents.index(agent_id)
                obs = []
                if not stored_info.local_done[idx]:
                    for i, _ in enumerate(stored_info.visual_observations):
                        obs.append(stored_info.visual_observations[i][idx])
                    if self.policy.use_vec_obs:
                        obs.append(stored_info.vector_observations[idx])
                    if self.policy.use_recurrent:
                        memory = self.policy.retrieve_memories([agent_id])[0, :]
                    else:
                        memory = None

                    done = next_info.local_done[next_idx]
                    max_step = next_info.max_reached[next_idx]

                    # Add the outputs of the last eval
                    action = stored_take_action_outputs["action"][idx]
                    if self.policy.use_continuous_act:
                        action_pre = stored_take_action_outputs["pre_action"][idx]
                    else:
                        action_pre = None
                    action_probs = stored_take_action_outputs["log_probs"][idx]
                    action_masks = stored_info.action_masks[idx]
                    prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]

                    experience = AgentExperience(
                        obs=obs,
                        reward=tmp_environment_reward[next_idx],
                        done=done,
                        action=action,
                        action_probs=action_probs,
                        action_pre=action_pre,
                        action_mask=action_masks,
                        prev_action=prev_action,
                        max_step=max_step,
                        memory=memory,
                    )
                    # Add the value outputs if needed
                    self.experience_buffers[agent_id].append(experience)
                    self.episode_rewards[agent_id] += tmp_environment_reward[next_idx]
                if (
                    next_info.local_done[next_idx]
                    or (
                        len(self.experience_buffers[agent_id])
                        >= self.max_trajectory_length
                    )
                ) and len(self.experience_buffers[agent_id]) > 0:
                    # Make next AgentExperience
                    next_obs = []
                    for i, _ in enumerate(next_info.visual_observations):
                        next_obs.append(next_info.visual_observations[i][next_idx])
                    if self.policy.use_vec_obs:
                        next_obs.append(next_info.vector_observations[next_idx])
                    trajectory = Trajectory(
                        steps=self.experience_buffers[agent_id],
                        agent_id=agent_id,
                        next_obs=next_obs,
                    )
                    # This will eventually be replaced with a queue
                    self.trainer.process_trajectory(trajectory)
                    self.experience_buffers[agent_id] = []
                    if next_info.local_done[next_idx]:
                        self.stats_reporter.add_stat(
                            "Environment/Cumulative Reward",
                            self.episode_rewards.get(agent_id, 0),
                        )
                        self.stats_reporter.add_stat(
                            "Environment/Episode Length",
                            self.episode_steps.get(agent_id, 0),
                        )
                        del self.episode_steps[agent_id]
                        del self.episode_rewards[agent_id]
                elif not next_info.local_done[next_idx]:
                    self.episode_steps[agent_id] += 1
        self.policy.save_previous_action(
            curr_info.agents, take_action_outputs["action"]
        )
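For orientation, the consumer side of this handoff is the trainer's process_trajectory method. Below is a minimal sketch (not part of the diff) of what a trainer might do with the Trajectory it receives; SketchTrainer is a hypothetical stub and only uses APIs introduced elsewhere in this commit (to_agentbuffer, done_reached, max_step_reached).

from mlagents.trainers.trajectory import Trajectory


class SketchTrainer:
    """Hypothetical trainer stub illustrating the AgentProcessor -> Trainer handoff."""

    def __init__(self):
        self.trajectory_count = 0

    def process_trajectory(self, trajectory: Trajectory) -> None:
        # The processor delivers one trajectory per agent, either at episode end
        # or once max_trajectory_length experiences have accumulated.
        agent_buffer = trajectory.to_agentbuffer()  # per-field view for the update step
        self.trajectory_count += 1
        if trajectory.done_reached and not trajectory.max_step_reached:
            # Episode ended naturally; a real trainer might reset per-agent state here.
            pass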
ml-agents/mlagents/trainers/stats.py (added):

from collections import defaultdict
from typing import List, Dict, NamedTuple
import numpy as np
import abc
import os

from mlagents.tf_utils import tf


class StatsWriter(abc.ABC):
    """
    A StatsWriter abstract class. A StatsWriter takes in a category, key, scalar value, and step
    and writes it out by some method.
    """

    @abc.abstractmethod
    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
        pass

    @abc.abstractmethod
    def write_text(self, category: str, text: str, step: int) -> None:
        pass


class TensorboardWriter(StatsWriter):
    def __init__(self, base_dir: str):
        self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
        self.base_dir: str = base_dir

    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
        self._maybe_create_summary_writer(category)
        summary = tf.Summary()
        summary.value.add(tag="{}".format(key), simple_value=value)
        self.summary_writers[category].add_summary(summary, step)
        self.summary_writers[category].flush()

    def _maybe_create_summary_writer(self, category: str) -> None:
        if category not in self.summary_writers:
            filewriter_dir = "{basedir}/{category}".format(
                basedir=self.base_dir, category=category
            )
            os.makedirs(filewriter_dir, exist_ok=True)
            self.summary_writers[category] = tf.summary.FileWriter(filewriter_dir)

    def write_text(self, category: str, text: str, step: int) -> None:
        self._maybe_create_summary_writer(category)
        self.summary_writers[category].add_summary(text, step)


class StatsSummary(NamedTuple):
    mean: float
    std: float
    num: int


class StatsReporter:
    writers: List[StatsWriter] = []
    stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))

    def __init__(self, category):
        """
        Generic StatsReporter. A category is the broadest type of storage (would
        correspond to the run name and trainer name, e.g. 3DBalltest_3DBall). A key is the
        type of stat it is (e.g. Environment/Reward). Finally the Value is the float value
        attached to this stat.
        """
        self.category: str = category

    @staticmethod
    def add_writer(writer: StatsWriter) -> None:
        StatsReporter.writers.append(writer)

    def add_stat(self, key: str, value: float) -> None:
        """
        Add a float value stat to the StatsReporter.
        :param key: The type of statistic, e.g. Environment/Reward.
        :param value: the value of the statistic.
        """
        StatsReporter.stats_dict[self.category][key].append(value)

    def write_stats(self, step: int) -> None:
        """
        Write out all stored statistics that fall under this reporter's category.
        The currently stored values will be averaged, written out as a single value,
        and the buffer cleared.
        :param step: Training step which to write these stats as.
        """
        for key in StatsReporter.stats_dict[self.category]:
            if len(StatsReporter.stats_dict[self.category][key]) > 0:
                stat_mean = float(
                    np.mean(StatsReporter.stats_dict[self.category][key])
                )
                for writer in StatsReporter.writers:
                    writer.write_stats(self.category, key, stat_mean, step)
        del StatsReporter.stats_dict[self.category]

    def write_text(self, text: str, step: int) -> None:
        """
        Write out some text.
        :param text: The text to write out.
        :param step: Training step which to write these stats as.
        """
        for writer in StatsReporter.writers:
            writer.write_text(self.category, text, step)

    def get_stats_summaries(self, key: str) -> StatsSummary:
        """
        Get the mean, std, and count of a particular statistic, since last write.
        :param key: The type of statistic, e.g. Environment/Reward.
        :returns: A StatsSummary NamedTuple containing (mean, std, count).
        """
        return StatsSummary(
            mean=np.mean(StatsReporter.stats_dict[self.category][key]),
            std=np.std(StatsReporter.stats_dict[self.category][key]),
            num=len(StatsReporter.stats_dict[self.category][key]),
        )
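As a quick orientation for the new module, here is a minimal usage sketch (not part of the diff; the summaries directory and the category name are illustrative). Writers are class-level, so a single registration serves every StatsReporter instance.

from mlagents.trainers.stats import StatsReporter, TensorboardWriter

# Register one writer for all reporters.
StatsReporter.add_writer(TensorboardWriter("./summaries"))

reporter = StatsReporter("3DBall_run-0")  # category, e.g. run name + behavior name
for reward in (0.5, 1.0, 1.5):
    reporter.add_stat("Environment/Cumulative Reward", reward)

print(reporter.get_stats_summaries("Environment/Cumulative Reward").mean)  # 1.0
reporter.write_stats(step=1000)  # writes the averaged value to every writer and clears the buffer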
ml-agents/mlagents/trainers/tests/test_agent_processor.py (added):

import unittest.mock as mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
from mlagents.trainers.agent_processor import AgentProcessor
from mlagents.trainers.stats import StatsReporter


def create_mock_brain():
    mock_brain = mb.create_mock_brainparams(
        vector_action_space_type="continuous",
        vector_action_space_size=[2],
        vector_observation_space_size=8,
        number_visual_observations=1,
    )
    return mock_brain


def create_mock_policy():
    mock_policy = mock.Mock()
    mock_policy.reward_signals = {}
    mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32)
    mock_policy.retrieve_previous_action.return_value = np.zeros(
        (1, 1), dtype=np.float32
    )
    return mock_policy


@pytest.mark.parametrize("num_vis_obs", [0, 1, 2], ids=["vec", "1 viz", "2 viz"])
def test_agentprocessor(num_vis_obs):
    policy = create_mock_policy()
    trainer = mock.Mock()
    processor = AgentProcessor(
        trainer,
        policy,
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )
    fake_action_outputs = {
        "action": [0.1, 0.1],
        "entropy": np.array([1.0], dtype=np.float32),
        "learning_rate": 1.0,
        "pre_action": [0.1, 0.1],
        "log_probs": [0.1, 0.1],
    }
    mock_braininfo = mb.create_mock_braininfo(
        num_agents=2,
        num_vector_observations=8,
        num_vector_acts=2,
        num_vis_observations=num_vis_obs,
    )
    for i in range(5):
        processor.add_experiences(mock_braininfo, mock_braininfo, fake_action_outputs)

    # Assert that two trajectories have been added to the Trainer
    assert len(trainer.process_trajectory.call_args_list) == 2

    # Assert that the trajectory is of length 5
    trajectory = trainer.process_trajectory.call_args_list[0][0][0]
    assert len(trajectory.steps) == 5

    # Assert that the AgentProcessor is empty
    assert len(processor.experience_buffers[0]) == 0
ml-agents/mlagents/trainers/tests/test_stats.py (added):

import unittest.mock as mock
import os
import pytest
import tempfile

from mlagents.trainers.stats import StatsReporter, TensorboardWriter


def test_stat_reporter_add_summary_write():
    # Test add_writer
    StatsReporter.writers.clear()
    mock_writer1 = mock.Mock()
    mock_writer2 = mock.Mock()
    StatsReporter.add_writer(mock_writer1)
    StatsReporter.add_writer(mock_writer2)
    assert len(StatsReporter.writers) == 2

    # Test add_stats and summaries
    statsreporter1 = StatsReporter("category1")
    statsreporter2 = StatsReporter("category2")
    for i in range(10):
        statsreporter1.add_stat("key1", float(i))
        statsreporter2.add_stat("key2", float(i))

    statssummary1 = statsreporter1.get_stats_summaries("key1")
    statssummary2 = statsreporter2.get_stats_summaries("key2")

    assert statssummary1.num == 10
    assert statssummary2.num == 10
    assert statssummary1.mean == 4.5
    assert statssummary2.mean == 4.5
    assert statssummary1.std == pytest.approx(2.9, abs=0.1)
    assert statssummary2.std == pytest.approx(2.9, abs=0.1)

    # Test write_stats
    step = 10
    statsreporter1.write_stats(step)
    mock_writer1.write_stats.assert_called_once_with("category1", "key1", 4.5, step)
    mock_writer2.write_stats.assert_called_once_with("category1", "key1", 4.5, step)


def test_stat_reporter_text():
    # Test add_writer
    mock_writer = mock.Mock()
    StatsReporter.writers.clear()
    StatsReporter.add_writer(mock_writer)
    assert len(StatsReporter.writers) == 1

    statsreporter1 = StatsReporter("category1")

    # Test write_text
    step = 10
    statsreporter1.write_text("this is a text", step)
    mock_writer.write_text.assert_called_once_with("category1", "this is a text", step)


@mock.patch("mlagents.tf_utils.tf.Summary")
@mock.patch("mlagents.tf_utils.tf.summary.FileWriter")
def test_tensorboard_writer(mock_filewriter, mock_summary):
    # Test write_stats
    category = "category1"
    with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
        tb_writer = TensorboardWriter(base_dir)
        tb_writer.write_stats("category1", "key1", 1.0, 10)

        # Test that the filewriter has been created and the directory has been created.
        filewriter_dir = "{basedir}/{category}".format(
            basedir=base_dir, category=category
        )
        assert os.path.exists(filewriter_dir)
        mock_filewriter.assert_called_once_with(filewriter_dir)

        # Test that the filewriter was written to and the summary was added.
        mock_summary.return_value.value.add.assert_called_once_with(
            tag="key1", simple_value=1.0
        )
        mock_filewriter.return_value.add_summary.assert_called_once_with(
            mock_summary.return_value, 10
        )
        mock_filewriter.return_value.flush.assert_called_once()
ml-agents/mlagents/trainers/tests/test_trajectory.py (added):

import numpy as np
import pytest

from mlagents.trainers.trajectory import AgentExperience, Trajectory, SplitObservations

VEC_OBS_SIZE = 6
ACTION_SIZE = 4


def make_fake_trajectory(
    length: int,
    max_step_complete: bool = False,
    vec_obs_size: int = VEC_OBS_SIZE,
    num_vis_obs: int = 1,
    action_space: int = ACTION_SIZE,
) -> Trajectory:
    """
    Makes a fake trajectory of length length. If max_step_complete,
    the trajectory is terminated by a max step rather than a done.
    """
    steps_list = []
    for i in range(length - 1):
        obs = []
        for i in range(num_vis_obs):
            obs.append(np.ones((84, 84, 3), dtype=np.float32))
        obs.append(np.ones(vec_obs_size, dtype=np.float32))
        reward = 1.0
        done = False
        action = np.zeros(action_space, dtype=np.float32)
        action_probs = np.ones(action_space, dtype=np.float32)
        action_pre = np.zeros(action_space, dtype=np.float32)
        action_mask = np.ones(action_space, dtype=np.float32)
        prev_action = np.ones(action_space, dtype=np.float32)
        max_step = False
        memory = np.ones(10, dtype=np.float32)
        agent_id = "test_agent"
        experience = AgentExperience(
            obs=obs,
            reward=reward,
            done=done,
            action=action,
            action_probs=action_probs,
            action_pre=action_pre,
            action_mask=action_mask,
            prev_action=prev_action,
            max_step=max_step,
            memory=memory,
        )
        steps_list.append(experience)
    last_experience = AgentExperience(
        obs=obs,
        reward=reward,
        done=not max_step_complete,
        action=action,
        action_probs=action_probs,
        action_pre=action_pre,
        action_mask=action_mask,
        prev_action=prev_action,
        max_step=max_step_complete,
        memory=memory,
    )
    steps_list.append(last_experience)
    return Trajectory(steps=steps_list, agent_id=agent_id, next_obs=obs)


@pytest.mark.parametrize("num_visual_obs", [0, 1, 2])
@pytest.mark.parametrize("num_vec_obs", [0, 1])
def test_split_obs(num_visual_obs, num_vec_obs):
    obs = []
    for i in range(num_visual_obs):
        obs.append(np.ones((84, 84, 3), dtype=np.float32))
    for i in range(num_vec_obs):
        obs.append(np.ones(VEC_OBS_SIZE, dtype=np.float32))
    split_observations = SplitObservations.from_observations(obs)

    if num_vec_obs == 1:
        assert len(split_observations.vector_observations) == VEC_OBS_SIZE
    else:
        assert len(split_observations.vector_observations) == 0

    # Assert the number of visual observations.
    assert len(split_observations.visual_observations) == num_visual_obs


def test_trajectory_to_agentbuffer():
    length = 15
    wanted_keys = [
        "next_visual_obs0",
        "visual_obs0",
        "vector_obs",
        "next_vector_in",
        "memory",
        "masks",
        "done",
        "actions_pre",
        "actions",
        "action_probs",
        "action_mask",
        "prev_action",
        "environment_rewards",
    ]
    wanted_keys = set(wanted_keys)
    trajectory = make_fake_trajectory(length=length)
    agentbuffer = trajectory.to_agentbuffer()
    seen_keys = set()
    for key, field in agentbuffer.items():
        assert len(field) == length
        seen_keys.add(key)

    assert seen_keys == wanted_keys
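The parametrized test above exercises SplitObservations.from_observations; as a standalone illustration (not part of the diff), the split is driven purely by array rank: 1-D arrays are treated as vector observations and concatenated, while 3-D arrays are kept as visual observations.

import numpy as np
from mlagents.trainers.trajectory import SplitObservations

split = SplitObservations.from_observations(
    [np.zeros((84, 84, 3), dtype=np.float32), np.zeros(6, dtype=np.float32)]
)
assert split.vector_observations.shape == (6,)  # 1-D obs concatenated together
assert len(split.visual_observations) == 1      # 3-D obs kept as a list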
ml-agents/mlagents/trainers/trajectory.py (added):

from typing import List, NamedTuple
import numpy as np

from mlagents.trainers.buffer import AgentBuffer


class AgentExperience(NamedTuple):
    obs: List[np.ndarray]
    reward: float
    done: bool
    action: np.ndarray
    action_probs: np.ndarray
    action_pre: np.ndarray  # TODO: Remove this
    action_mask: np.ndarray
    prev_action: np.ndarray
    max_step: bool
    memory: np.ndarray


class SplitObservations(NamedTuple):
    vector_observations: np.ndarray
    visual_observations: List[np.ndarray]

    @staticmethod
    def from_observations(obs: List[np.ndarray]) -> "SplitObservations":
        """
        Divides a List of numpy arrays into a SplitObservations NamedTuple.
        This allows you to access the vector and visual observations directly,
        without enumerating the list over and over.
        :param obs: List of numpy arrays (observation)
        :returns: A SplitObservations object.
        """
        vis_obs_list: List[np.ndarray] = []
        vec_obs_list: List[np.ndarray] = []
        for observation in obs:
            if len(observation.shape) == 1:
                vec_obs_list.append(observation)
            if len(observation.shape) == 3:
                vis_obs_list.append(observation)
        vec_obs = (
            np.concatenate(vec_obs_list, axis=0)
            if len(vec_obs_list) > 0
            else np.array([], dtype=np.float32)
        )
        return SplitObservations(
            vector_observations=vec_obs, visual_observations=vis_obs_list
        )


class Trajectory(NamedTuple):
    steps: List[AgentExperience]
    next_obs: List[
        np.ndarray
    ]  # Observation following the trajectory, for bootstrapping
    agent_id: str

    def to_agentbuffer(self) -> AgentBuffer:
        """
        Converts a Trajectory to an AgentBuffer.
        :param trajectory: A Trajectory
        :returns: AgentBuffer. Note that the length of the AgentBuffer will be one
        less than the trajectory, as the next observation needs to be populated from the last
        step of the trajectory.
        """
        agent_buffer_trajectory = AgentBuffer()
        vec_vis_obs = SplitObservations.from_observations(self.steps[0].obs)
        for step, exp in enumerate(self.steps):
            if step < len(self.steps) - 1:
                next_vec_vis_obs = SplitObservations.from_observations(
                    self.steps[step + 1].obs
                )
            else:
                next_vec_vis_obs = SplitObservations.from_observations(self.next_obs)

            for i, _ in enumerate(vec_vis_obs.visual_observations):
                agent_buffer_trajectory["visual_obs%d" % i].append(
                    vec_vis_obs.visual_observations[i]
                )
                agent_buffer_trajectory["next_visual_obs%d" % i].append(
                    next_vec_vis_obs.visual_observations[i]
                )
            agent_buffer_trajectory["vector_obs"].append(
                vec_vis_obs.vector_observations
            )
            agent_buffer_trajectory["next_vector_in"].append(
                next_vec_vis_obs.vector_observations
            )
            if exp.memory is not None:
                agent_buffer_trajectory["memory"].append(exp.memory)

            agent_buffer_trajectory["masks"].append(1.0)
            agent_buffer_trajectory["done"].append(exp.done)
            # Add the outputs of the last eval
            if exp.action_pre is not None:
                actions_pre = exp.action_pre
                agent_buffer_trajectory["actions_pre"].append(actions_pre)

            # value is a dictionary from name of reward to value estimate of the value head
            agent_buffer_trajectory["actions"].append(exp.action)
            agent_buffer_trajectory["action_probs"].append(exp.action_probs)

            # Store action masks if necessary. Eventually these will be
            # None for continuous actions
            if exp.action_mask is not None:
                agent_buffer_trajectory["action_mask"].append(
                    exp.action_mask, padding_value=1
                )

            agent_buffer_trajectory["prev_action"].append(exp.prev_action)
            agent_buffer_trajectory["environment_rewards"].append(exp.reward)

            # Store the next visual obs as the current
            vec_vis_obs = next_vec_vis_obs
        return agent_buffer_trajectory

    @property
    def done_reached(self) -> bool:
        """
        Returns true if trajectory is terminated with a Done.
        """
        return self.steps[-1].done

    @property
    def max_step_reached(self) -> bool:
        """
        Returns true if trajectory was terminated because max steps was reached.
        """
        return self.steps[-1].max_step
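To tie the new pieces together, here is a small hand-built example (not part of the diff; the shapes, action size, and agent id are illustrative) of constructing a Trajectory and flattening it with to_agentbuffer.

import numpy as np
from mlagents.trainers.trajectory import AgentExperience, Trajectory


def make_step(done: bool) -> AgentExperience:
    # Two-action continuous agent with a single 6-dim vector observation.
    return AgentExperience(
        obs=[np.ones(6, dtype=np.float32)],
        reward=1.0,
        done=done,
        action=np.zeros(2, dtype=np.float32),
        action_probs=np.zeros(2, dtype=np.float32),
        action_pre=np.zeros(2, dtype=np.float32),
        action_mask=np.ones(2, dtype=np.float32),
        prev_action=np.zeros(2, dtype=np.float32),
        max_step=False,
        memory=None,
    )


traj = Trajectory(
    steps=[make_step(False), make_step(True)],
    next_obs=[np.ones(6, dtype=np.float32)],
    agent_id="agent_0",
)
buf = traj.to_agentbuffer()
print(len(buf["vector_obs"]), len(buf["environment_rewards"]))  # one entry per step
print(traj.done_reached, traj.max_step_reached)                 # True False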