Refactor Trainer and Model (#2360)

- Move common functions to trainer.py, model.py from ppo/trainer.py, ppo/policy.py and ppo/model.py
- Introduce RLTrainer class and move most of add_experiences and some common reward signal code there. PPO and SAC will inherit from this, not so much BC Trainer.
- Add methods to Buffer to enable sampling, truncating, and save/loading.
- Add scoping to create encoders in model.py

develop-gpu-test
GitHub
5 years ago
Current commit
7b69bd14
16 files changed, with 836 additions and 653 deletions
- 2 ml-agents/mlagents/trainers/bc/models.py
- 11 ml-agents/mlagents/trainers/bc/offline_trainer.py
- 11 ml-agents/mlagents/trainers/bc/online_trainer.py
- 30 ml-agents/mlagents/trainers/bc/trainer.py
- 61 ml-agents/mlagents/trainers/buffer.py
- 5 ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
- 260 ml-agents/mlagents/trainers/models.py
- 229 ml-agents/mlagents/trainers/ppo/models.py
- 26 ml-agents/mlagents/trainers/ppo/policy.py
- 301 ml-agents/mlagents/trainers/ppo/trainer.py
- 4 ml-agents/mlagents/trainers/tests/mock_brain.py
- 43 ml-agents/mlagents/trainers/tests/test_buffer.py
- 171 ml-agents/mlagents/trainers/trainer.py
- 1 ml-agents/setup.py
- 253 ml-agents/mlagents/trainers/rl_trainer.py
- 81 ml-agents/mlagents/trainers/tests/test_rl_trainer.py
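The buffer.py changes themselves are not reproduced on this page. As a rough, standalone illustration of what "sampling, truncating, and save/loading" can mean for a keyed experience buffer, here is a minimal sketch; the class and method names (MiniBuffer, sample, truncate, save, load) are hypothetical and do not reflect the actual mlagents.trainers.buffer API.

# Hypothetical sketch only -- not the real mlagents.trainers.buffer.Buffer API.
import pickle
from typing import Dict, List

import numpy as np


class MiniBuffer:
    def __init__(self) -> None:
        # One list of entries per field name, e.g. "vector_obs", "actions".
        self.fields: Dict[str, List[np.ndarray]] = {}

    def append(self, key: str, value) -> None:
        self.fields.setdefault(key, []).append(value)

    def sample(self, batch_size: int) -> Dict[str, np.ndarray]:
        # Draw the same random indices from every field so rows stay aligned.
        length = len(next(iter(self.fields.values())))
        idx = np.random.randint(0, length, size=batch_size)
        return {k: np.array(v)[idx] for k, v in self.fields.items()}

    def truncate(self, max_length: int) -> None:
        # Keep only the most recent max_length entries of each field.
        for key, values in self.fields.items():
            self.fields[key] = values[-max_length:]

    def save(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(self.fields, f)

    def load(self, path: str) -> None:
        with open(path, "rb") as f:
            self.fields = pickle.load(f)

Save/load support of this kind is what lets an off-policy trainer such as SAC persist and resume its replay data between runs; the real implementation lives in the buffer.py diff listed above.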
ml-agents/mlagents/trainers/rl_trainer.py

# # Unity ML-Agents Toolkit
import logging
from typing import Dict, List, Deque, Any
import os
import tensorflow as tf
import numpy as np
from collections import deque, defaultdict

from mlagents.envs import UnityException, AllBrainInfo, ActionInfoOutputs, BrainInfo
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.tf_policy import Policy
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.envs import BrainParameters

LOGGER = logging.getLogger("mlagents.trainers")


class RLTrainer(Trainer):
    """
    This class is the base class for trainers that use Reward Signals.
    Contains methods for adding BrainInfos to the Buffer.
    """

    def __init__(self, *args, **kwargs):
        super(RLTrainer, self).__init__(*args, **kwargs)
        self.step = 0
        # Make sure we have at least one reward_signal
        if not self.trainer_parameters["reward_signals"]:
            raise UnityTrainerException(
                "No reward signals were defined. At least one must be used with {}.".format(
                    self.__class__.__name__
                )
            )
        # collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
        # used for reporting only. We always want to report the environment reward to Tensorboard, regardless
        # of what reward signals are actually present.
        self.collected_rewards = {"environment": {}}
        self.training_buffer = Buffer()
        self.episode_steps = {}

    def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
        """
        Constructs a BrainInfo which contains the most recent previous experiences for all agents
        which correspond to the agents in a provided next_info.
        :BrainInfo next_info: A t+1 BrainInfo.
        :return: curr_info: Reconstructed BrainInfo to match agents of next_info.
        """
        visual_observations: List[List[Any]] = [
            []
        ]  # TODO add types to brain.py methods
        vector_observations = []
        text_observations = []
        memories = []
        rewards = []
        local_dones = []
        max_reacheds = []
        agents = []
        prev_vector_actions = []
        prev_text_actions = []
        action_masks = []
        for agent_id in next_info.agents:
            agent_brain_info = self.training_buffer[agent_id].last_brain_info
            if agent_brain_info is None:
                agent_brain_info = next_info
            agent_index = agent_brain_info.agents.index(agent_id)
            for i in range(len(next_info.visual_observations)):
                visual_observations[i].append(
                    agent_brain_info.visual_observations[i][agent_index]
                )
            vector_observations.append(
                agent_brain_info.vector_observations[agent_index]
            )
            text_observations.append(agent_brain_info.text_observations[agent_index])
            if self.policy.use_recurrent:
                if len(agent_brain_info.memories) > 0:
                    memories.append(agent_brain_info.memories[agent_index])
                else:
                    memories.append(self.policy.make_empty_memory(1))
            rewards.append(agent_brain_info.rewards[agent_index])
            local_dones.append(agent_brain_info.local_done[agent_index])
            max_reacheds.append(agent_brain_info.max_reached[agent_index])
            agents.append(agent_brain_info.agents[agent_index])
            prev_vector_actions.append(
                agent_brain_info.previous_vector_actions[agent_index]
            )
            prev_text_actions.append(
                agent_brain_info.previous_text_actions[agent_index]
            )
            action_masks.append(agent_brain_info.action_masks[agent_index])
        if self.policy.use_recurrent:
            memories = np.vstack(memories)
        curr_info = BrainInfo(
            visual_observations,
            vector_observations,
            text_observations,
            memories,
            rewards,
            agents,
            local_dones,
            prev_vector_actions,
            prev_text_actions,
            max_reacheds,
            action_masks,
        )
        return curr_info

    def add_experiences(
        self,
        curr_all_info: AllBrainInfo,
        next_all_info: AllBrainInfo,
        take_action_outputs: ActionInfoOutputs,
    ) -> None:
        """
        Adds experiences to each agent's experience history.
        :param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
        :param next_all_info: Dictionary of all next brains and corresponding BrainInfo.
        :param take_action_outputs: The outputs of the Policy's get_action method.
        """
        self.trainer_metrics.start_experience_collection_timer()
        if take_action_outputs:
            self.stats["Policy/Entropy"].append(take_action_outputs["entropy"].mean())
            self.stats["Policy/Learning Rate"].append(
                take_action_outputs["learning_rate"]
            )
            for name, signal in self.policy.reward_signals.items():
                self.stats[signal.value_name].append(
                    np.mean(take_action_outputs["value_heads"][name])
                )

        curr_info = curr_all_info[self.brain_name]
        next_info = next_all_info[self.brain_name]

        for agent_id in curr_info.agents:
            self.training_buffer[agent_id].last_brain_info = curr_info
            self.training_buffer[
                agent_id
            ].last_take_action_outputs = take_action_outputs

        # If the set of agents changed between steps, rebuild a current BrainInfo
        # that lines up with the agents present in next_info.
        if curr_info.agents != next_info.agents:
            curr_to_use = self.construct_curr_info(next_info)
        else:
            curr_to_use = curr_info

        tmp_rewards_dict = {}
        for name, signal in self.policy.reward_signals.items():
            tmp_rewards_dict[name] = signal.evaluate(curr_to_use, next_info)

        for agent_id in next_info.agents:
            stored_info = self.training_buffer[agent_id].last_brain_info
            stored_take_action_outputs = self.training_buffer[
                agent_id
            ].last_take_action_outputs
            if stored_info is not None:
                idx = stored_info.agents.index(agent_id)
                next_idx = next_info.agents.index(agent_id)
                if not stored_info.local_done[idx]:
                    for i, _ in enumerate(stored_info.visual_observations):
                        self.training_buffer[agent_id]["visual_obs%d" % i].append(
                            stored_info.visual_observations[i][idx]
                        )
                        self.training_buffer[agent_id]["next_visual_obs%d" % i].append(
                            next_info.visual_observations[i][next_idx]
                        )
                    if self.policy.use_vec_obs:
                        self.training_buffer[agent_id]["vector_obs"].append(
                            stored_info.vector_observations[idx]
                        )
                        self.training_buffer[agent_id]["next_vector_in"].append(
                            next_info.vector_observations[next_idx]
                        )
                    if self.policy.use_recurrent:
                        if stored_info.memories.shape[1] == 0:
                            stored_info.memories = np.zeros(
                                (len(stored_info.agents), self.policy.m_size)
                            )
                        self.training_buffer[agent_id]["memory"].append(
                            stored_info.memories[idx]
                        )

                    self.training_buffer[agent_id]["masks"].append(1.0)
                    self.training_buffer[agent_id]["done"].append(
                        next_info.local_done[next_idx]
                    )
                    # Add the outputs of the last eval
                    self.add_policy_outputs(stored_take_action_outputs, agent_id, idx)
                    # Store action masks if necessary
                    if not self.policy.use_continuous_act:
                        self.training_buffer[agent_id]["action_mask"].append(
                            stored_info.action_masks[idx], padding_value=1
                        )
                    self.training_buffer[agent_id]["prev_action"].append(
                        stored_info.previous_vector_actions[idx]
                    )

                    values = stored_take_action_outputs["value_heads"]
                    # Add the value outputs if needed
                    self.add_rewards_outputs(
                        values, tmp_rewards_dict, agent_id, idx, next_idx
                    )

                for name, rewards in self.collected_rewards.items():
                    if agent_id not in rewards:
                        rewards[agent_id] = 0
                    if name == "environment":
                        # Report the reward from the environment
                        rewards[agent_id] += np.array(next_info.rewards)[next_idx]
                    else:
                        # Report the reward signals
                        rewards[agent_id] += tmp_rewards_dict[name].scaled_reward[
                            next_idx
                        ]
                if not next_info.local_done[next_idx]:
                    if agent_id not in self.episode_steps:
                        self.episode_steps[agent_id] = 0
                    self.episode_steps[agent_id] += 1
        self.trainer_metrics.end_experience_collection_timer()

    def add_policy_outputs(
        self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
    ) -> None:
        """
        Takes the output of the last action and stores it into the training buffer.
        We break this out from add_experiences since it is very highly dependent
        on the type of trainer.
        :param take_action_outputs: The outputs of the Policy's get_action method.
        :param agent_id: the Agent we're adding to.
        :param agent_idx: the index of the Agent agent_id
        """
        raise UnityTrainerException(
            "The add_policy_outputs method was not implemented."
        )

    def add_rewards_outputs(
        self,
        value: Dict[str, Any],
        rewards_dict: Dict[str, float],
        agent_id: str,
        agent_idx: int,
        agent_next_idx: int,
    ) -> None:
        """
        Takes the value and evaluated rewards output of the last action and stores them
        into the training buffer. We break this out from add_experiences since it is very
        highly dependent on the type of trainer.
        :param value: The value-head outputs from the Policy's get_action method.
        :param rewards_dict: Dict of rewards after evaluation
        :param agent_id: the Agent we're adding to.
        :param agent_idx: the index of the Agent agent_id in the current brain info
        :param agent_next_idx: the index of the Agent agent_id in the next brain info
        """
        raise UnityTrainerException(
            "The add_rewards_outputs method was not implemented."
        )
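add_policy_outputs and add_rewards_outputs are the two hooks a concrete trainer (PPO, and later SAC) is expected to override. Below is a minimal, hypothetical sketch of such an override, reusing the imports and the training_buffer layout shown in rl_trainer.py above; the buffer keys ("actions", "{name}_rewards", "{name}_value_estimates") and the indexing into the value heads are illustrative assumptions, not the exact layout the real PPO/SAC trainers use.

# Hypothetical subclass sketch -- buffer keys and value-head indexing are assumptions.
class ExampleRLTrainer(RLTrainer):
    def add_policy_outputs(
        self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
    ) -> None:
        # Store the action the policy took for this agent at the last step.
        actions = take_action_outputs["action"]
        self.training_buffer[agent_id]["actions"].append(actions[agent_idx])

    def add_rewards_outputs(
        self,
        value: Dict[str, Any],
        rewards_dict: Dict[str, Any],
        agent_id: str,
        agent_idx: int,
        agent_next_idx: int,
    ) -> None:
        # Store one scaled reward and one value estimate per reward signal,
        # mirroring how add_experiences indexes the current/next BrainInfo.
        for name, reward_result in rewards_dict.items():
            self.training_buffer[agent_id]["{}_rewards".format(name)].append(
                reward_result.scaled_reward[agent_next_idx]
            )
            self.training_buffer[agent_id]["{}_value_estimates".format(name)].append(
                value[name][agent_idx]
            )

Keeping these two methods abstract is what lets the shared add_experiences loop above stay algorithm-agnostic: the base class handles observations, masks, memories, and reward bookkeeping, while each trainer decides what policy outputs and value estimates it actually needs to store.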
ml-agents/mlagents/trainers/tests/test_rl_trainer.py

import unittest.mock as mock
import pytest
import yaml
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
from mlagents.trainers.rl_trainer import RLTrainer


@pytest.fixture
def dummy_config():
    return yaml.safe_load(
        """
        summary_path: "test/"
        reward_signals:
          extrinsic:
            strength: 1.0
            gamma: 0.99
        """
    )


def create_mock_brain():
    mock_brain = mb.create_mock_brainparams(
        vector_action_space_type="continuous",
        vector_action_space_size=[2],
        vector_observation_space_size=8,
        number_visual_observations=1,
    )
    return mock_brain


def create_rl_trainer():
    mock_brainparams = create_mock_brain()
    trainer = RLTrainer(mock_brainparams, dummy_config(), True, 0)
    return trainer


def create_mock_all_brain_info(brain_info):
    return {"MockBrain": brain_info}


def create_mock_policy():
    mock_policy = mock.Mock()
    mock_policy.reward_signals = {}
    return mock_policy


@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_policy_outputs")
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_rewards_outputs")
def test_rl_trainer(add_policy_outputs, add_rewards_outputs):
    trainer = create_rl_trainer()
    trainer.policy = create_mock_policy()
    fake_action_outputs = {
        "action": [0.1, 0.1],
        "value_heads": {},
        "entropy": np.array([1.0]),
        "learning_rate": 1.0,
    }
    mock_braininfo = mb.create_mock_braininfo(
        num_agents=2,
        num_vector_observations=8,
        num_vector_acts=2,
        num_vis_observations=1,
    )
    trainer.add_experiences(
        create_mock_all_brain_info(mock_braininfo),
        create_mock_all_brain_info(mock_braininfo),
        fake_action_outputs,
    )

    # Remove one of the agents
    next_mock_braininfo = mb.create_mock_braininfo(
        num_agents=1,
        num_vector_observations=8,
        num_vector_acts=2,
        num_vis_observations=1,
    )
    brain_info = trainer.construct_curr_info(next_mock_braininfo)

    # assert construct_curr_info worked properly
    assert len(brain_info.agents) == 1