Compare commits

...
This merge request contains changes that conflict with the target branch.
/ml-agents/mlagents/trainers/buffer.py
/ml-agents/mlagents/trainers/torch/components/bc/module.py
/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
/ml-agents/mlagents/trainers/trajectory.py

3 commits

10 files changed, with 408 insertions and 32 deletions
  1. ml-agents/mlagents/trainers/trajectory.py (20 changed lines)
  2. ml-agents/mlagents/trainers/buffer.py (4 changed lines)
  3. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (50 changed lines)
  4. ml-agents/mlagents/trainers/torch/components/bc/module.py (25 changed lines)
  5. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (27 changed lines)
  6. ml-agents/mlagents/trainers/demonstrations/__init__.py (0 changed lines)
  7. ml-agents/mlagents/trainers/demonstrations/demonstration_proto_utils.py (96 changed lines)
  8. ml-agents/mlagents/trainers/demonstrations/demonstration_provider.py (72 changed lines)
  9. ml-agents/mlagents/trainers/demonstrations/local_demonstration_provider.py (146 changed lines)

ml-agents/mlagents/trainers/trajectory.py (20 changed lines)


-from typing import List, NamedTuple
+from typing import List, NamedTuple, Optional
 import numpy as np
 from mlagents.trainers.buffer import (

     reward: float
     done: bool
     action: ActionTuple
-    action_probs: LogProbsTuple
+    action_probs: Optional[LogProbsTuple]  # TODO rename to action_log_probs
-    memory: np.ndarray
+    memory: Optional[np.ndarray]

 class ObsUtil:

             agent_buffer_trajectory[BufferKey.DISCRETE_ACTION].append(
                 exp.action.discrete
             )
-            agent_buffer_trajectory[BufferKey.CONTINUOUS_LOG_PROBS].append(
-                exp.action_probs.continuous
-            )
-            agent_buffer_trajectory[BufferKey.DISCRETE_LOG_PROBS].append(
-                exp.action_probs.discrete
-            )
+            if exp.action_probs is not None:
+                agent_buffer_trajectory[BufferKey.CONTINUOUS_LOG_PROBS].append(
+                    exp.action_probs.continuous
+                )
+                agent_buffer_trajectory[BufferKey.DISCRETE_LOG_PROBS].append(
+                    exp.action_probs.discrete
+                )
             # Store action masks if necessary. Note that 1 means active, while
             # in AgentExperience False means active.
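The fields become Optional presumably because experiences built from demonstrations carry no log-probabilities or memories, so to_agentbuffer only writes them when present. A minimal standalone sketch of the guarded-append pattern (plain Python, not the actual ml-agents types):

from typing import List, NamedTuple, Optional


class FakeExperience(NamedTuple):
    reward: float
    action_probs: Optional[List[float]]  # None for demonstration-derived data


def append_to_buffer(buffer: dict, exp: FakeExperience) -> None:
    buffer.setdefault("rewards", []).append(exp.reward)
    # Log-probs are only recorded when the experience came from the policy.
    if exp.action_probs is not None:
        buffer.setdefault("log_probs", []).append(exp.action_probs)


append_to_buffer({}, FakeExperience(reward=1.0, action_probs=None))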

ml-agents/mlagents/trainers/buffer.py (4 changed lines)


         if key_list is None:
             key_list = list(self.keys())
         if not self.check_length(key_list):
+            lengths = {k: len(self._fields[k]) for k in key_list}
+            lengths_str = "\n\t".join(str((k, v)) for k, v in lengths.items())
-                f"The length of the fields {key_list} were not of same length"
+                f"The length of the fields were not of same length: {lengths_str}"
             )
         for field_key in key_list:
             target_buffer[field_key].extend(
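The new exception text lists each field's length rather than only the field names, which makes a mismatch easier to pin down. A standalone sketch of the message formatting (the field names and lengths are illustrative):

lengths = {"continuous_action": 128, "environment_rewards": 127}
lengths_str = "\n\t".join(str((k, v)) for k, v in lengths.items())
print(f"The length of the fields were not of same length: {lengths_str}")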

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (50 changed lines)


-import os
-import pytest
-from unittest.mock import patch
+import pytest
+from unittest.mock import patch
+from mlagents_envs.base_env import BehaviorSpec, ActionSpec
+import os
+from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.demonstrations.demonstration_provider import (
+    DemonstrationProvider,
+)
-from mlagents_envs.base_env import BehaviorSpec, ActionSpec
 from mlagents.trainers.settings import GAILSettings, RewardSignalType
 from mlagents.trainers.tests.torch.test_reward_providers.utils import (
     create_agent_buffer,

 ACTIONSPEC_DISCRETE = ActionSpec.create_discrete((20,))

+class MockDemonstrationProvider(DemonstrationProvider):
+    def __init__(self, behavior_spec, buffer):
+        self._behavior_spec = behavior_spec
+        self._buffer = buffer
+
+    def get_behavior_spec(self) -> BehaviorSpec:
+        return self._behavior_spec
+
+    def pop_trajectories(self):
+        raise NotImplementedError()
+
+    def to_agentbuffer(self, training_length: int) -> AgentBuffer:
+        return self._buffer

 @pytest.mark.parametrize(
     "behavior_spec",
     [BehaviorSpec(create_observation_specs_with_shapes([(8,)]), ACTIONSPEC_CONTINUOUS)],

     ],
 )
 @pytest.mark.parametrize("use_actions", [False, True])
-@patch(
-    "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
-)
+@patch.object(GAILRewardProvider, "_get_demonstration_provider")
-    demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
+    mock_demo_provider: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
-    demo_to_buffer.return_value = None, buffer_expert
+    mock_demo_provider.return_value = MockDemonstrationProvider(
+        behavior_spec, buffer_expert
+    )
     gail_settings = GAILSettings(
         demo_path="", learning_rate=0.005, use_vail=False, use_actions=use_actions
     )

     ],
 )
 @pytest.mark.parametrize("use_actions", [False, True])
-@patch(
-    "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
-)
+@patch.object(GAILRewardProvider, "_get_demonstration_provider")
-    demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
+    mock_demo_provider: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
-    demo_to_buffer.return_value = None, buffer_expert
+    mock_demo_provider.return_value = MockDemonstrationProvider(
+        behavior_spec, buffer_expert
+    )
     gail_settings = GAILSettings(
         demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions
     )
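Patching GAILRewardProvider._get_demonstration_provider rather than demo_to_buffer lets the test hand the reward provider an in-memory expert buffer without reading a .demo file from disk. A self-contained illustration of patching a factory method with unittest.mock (the Widget class is invented for this example):

from unittest.mock import patch


class Widget:
    def _get_source(self):
        return "real source"

    def describe(self):
        return self._get_source()


# Every Widget created inside the context sees the replaced factory method,
# just as every GAILRewardProvider in the test receives the mock provider.
with patch.object(Widget, "_get_source", return_value="mock source"):
    assert Widget().describe() == "mock source"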

ml-agents/mlagents/trainers/torch/components/bc/module.py (25 changed lines)


 from mlagents.torch_utils import torch
 from mlagents.trainers.policy.torch_policy import TorchPolicy
-from mlagents.trainers.demo_loader import demo_to_buffer
+from mlagents.trainers.demonstrations.demonstration_provider import (
+    DemonstrationProvider,
+)
+from mlagents.trainers.demonstrations.local_demonstration_provider import (
+    LocalDemonstrationProvider,
+)
 from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType
 from mlagents.trainers.torch.agent_action import AgentAction
 from mlagents.trainers.torch.action_log_probs import ActionLogProbs

         )
         params = self.policy.actor_critic.parameters()
         self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
-        _, self.demonstration_buffer = demo_to_buffer(
-            settings.demo_path, policy.sequence_length, policy.behavior_spec
+        demo_provider = self._get_demonstration_provider(settings)
+        # TODO check policy.behavior_spec == demo_provider_spec
+        self.demonstration_buffer = demo_provider.to_agentbuffer(
+            training_length=policy.sequence_length
         )
         self.batch_size = (
             settings.batch_size if settings.batch_size else default_batch_size
         )

         self.has_updated = False
         self.use_recurrent = self.policy.use_recurrent
         self.samples_per_update = settings.samples_per_update

+    def _get_demonstration_provider(
+        self, settings: BehavioralCloningSettings
+    ) -> DemonstrationProvider:
+        """
+        Get the DemonstrationProvider as determined by the BehavioralCloningSettings.
+        This is currently always a LocalDemonstrationProvider but could change in the future,
+        based on the settings.
+        """
+        return LocalDemonstrationProvider(settings.demo_path)
+
     def update(self) -> Dict[str, np.ndarray]:
         """

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (27 changed lines)


 )
 from mlagents.trainers.settings import GAILSettings
 from mlagents_envs.base_env import BehaviorSpec
+from mlagents.trainers.demonstrations.demonstration_provider import (
+    DemonstrationProvider,
+)
+from mlagents.trainers.demonstrations.local_demonstration_provider import (
+    LocalDemonstrationProvider,
+)
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.agent_action import AgentAction
 from mlagents.trainers.torch.action_flattener import ActionFlattener

-from mlagents.trainers.demo_loader import demo_to_buffer
 from mlagents.trainers.trajectory import ObsUtil

         self._ignore_done = True
         self._discriminator_network = DiscriminatorNetwork(specs, settings)
         self._discriminator_network.to(default_device())
-        _, self._demo_buffer = demo_to_buffer(
-            settings.demo_path, 1, specs
-        )  # This is supposed to be the sequence length but we do not have access here
+        demo_provider = self._get_demonstration_provider(settings)
+        # TODO check spec == demo_provider_spec
+        self._demo_buffer = demo_provider.to_agentbuffer(training_length=1)
+
+    def _get_demonstration_provider(
+        self, settings: GAILSettings
+    ) -> DemonstrationProvider:
+        """
+        Get the DemonstrationProvider as determined by the GAILSettings.
+        This is currently always a LocalDemonstrationProvider but could change in the future,
+        based on the settings.
+        """
+        return LocalDemonstrationProvider(settings.demo_path)

     def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
         with torch.no_grad():

     def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
         expert_batch = self._demo_buffer.sample_mini_batch(
-            mini_batch.num_experiences, 1
+            mini_batch.num_experiences, sequence_length=1
         )
         loss, stats_dict = self._discriminator_network.compute_loss(
             mini_batch, expert_batch

ml-agents/mlagents/trainers/demonstrations/__init__.py (0 changed lines, new empty file)

ml-agents/mlagents/trainers/demonstrations/demonstration_proto_utils.py (96 changed lines, new file)


import os
from typing import List, Tuple

import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
    AgentInfoActionPairProto,
)
from mlagents.trainers.trajectory import ObsUtil
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
    DemonstrationMetaProto,
)
from mlagents_envs.timers import timed, hierarchical_timer
from google.protobuf.internal.decoder import _DecodeVarint32  # type: ignore
from google.protobuf.internal.encoder import _EncodeVarint  # type: ignore

INITIAL_POS = 33
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1])


@timed
def load_demonstration(
    file_paths: List[str],
) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto]]:
    """
    Loads and parses a demonstration file.
    :param file_path: Location of demonstration file (.demo).
    :return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data.
    """
    # First 32 bytes of file dedicated to meta-data.
    behavior_spec = None
    brain_param_proto = None
    info_action_pairs = []
    total_expected = 0
    for _file_path in file_paths:
        with open(_file_path, "rb") as fp:
            with hierarchical_timer("read_file"):
                data = fp.read()
            next_pos, pos, obs_decoded = 0, 0, 0
            while pos < len(data):
                next_pos, pos = _DecodeVarint32(data, pos)
                if obs_decoded == 0:
                    meta_data_proto = DemonstrationMetaProto()
                    meta_data_proto.ParseFromString(data[pos : pos + next_pos])
                    if (
                        meta_data_proto.api_version
                        not in SUPPORTED_DEMONSTRATION_VERSIONS
                    ):
                        raise RuntimeError(
                            f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
                        )
                    total_expected += meta_data_proto.number_steps
                    pos = INITIAL_POS
                if obs_decoded == 1:
                    brain_param_proto = BrainParametersProto()
                    brain_param_proto.ParseFromString(data[pos : pos + next_pos])
                    pos += next_pos
                if obs_decoded > 1:
                    agent_info_action = AgentInfoActionPairProto()
                    agent_info_action.ParseFromString(data[pos : pos + next_pos])
                    if behavior_spec is None:
                        behavior_spec = behavior_spec_from_proto(
                            brain_param_proto, agent_info_action.agent_info
                        )
                    info_action_pairs.append(agent_info_action)
                    if len(info_action_pairs) == total_expected:
                        break
                    pos += next_pos
                obs_decoded += 1
    if not behavior_spec:
        raise RuntimeError(
            f"No BrainParameters found in demonstration file(s) at {file_paths}."
        )
    return behavior_spec, info_action_pairs


def write_delimited(f, message):
    msg_string = message.SerializeToString()
    msg_size = len(msg_string)
    _EncodeVarint(f.write, msg_size)
    f.write(msg_string)


def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
    with open(demo_path, "wb") as f:
        # write metadata
        write_delimited(f, meta_data_proto)
        f.seek(INITIAL_POS)
        write_delimited(f, brain_param_proto)
        for agent in agent_info_protos:
            write_delimited(f, agent)
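load_demonstration takes a list of .demo paths and returns the parsed BehaviorSpec together with the raw AgentInfoActionPairProto steps. A hedged usage sketch (the path is illustrative):

from mlagents.trainers.demonstrations.demonstration_proto_utils import load_demonstration

# "demos/expert.demo" is a placeholder for a recorded demonstration file.
behavior_spec, info_action_pairs = load_demonstration(["demos/expert.demo"])
print(behavior_spec.action_spec, len(info_action_pairs))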

ml-agents/mlagents/trainers/demonstrations/demonstration_provider.py (72 changed lines, new file)


import abc
import numpy as np
from typing import List, NamedTuple

from mlagents_envs.base_env import ActionTuple, BehaviorSpec
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.trajectory import ObsUtil


class DemonstrationExperience(NamedTuple):
    obs: List[np.ndarray]
    reward: float
    done: bool
    action: ActionTuple
    prev_action: np.ndarray
    interrupted: bool


class DemonstrationTrajectory(NamedTuple):
    experiences: List[DemonstrationExperience]

    def to_agentbuffer(self) -> AgentBuffer:
        """
        Converts a Trajectory to an AgentBuffer
        :param trajectory: A Trajectory
        :returns: AgentBuffer. Note that the length of the AgentBuffer will be one
        less than the trajectory, as the next observation need to be populated from the last
        step of the trajectory.
        """
        agent_buffer_trajectory = AgentBuffer()
        for exp in self.experiences:
            for i, obs in enumerate(exp.obs):
                agent_buffer_trajectory[ObsUtil.get_name_at(i)].append(obs)
            # TODO Not in demo_loader
            agent_buffer_trajectory[BufferKey.MASKS].append(1.0)
            agent_buffer_trajectory[BufferKey.DONE].append(exp.done)
            agent_buffer_trajectory[BufferKey.CONTINUOUS_ACTION].append(
                exp.action.continuous
            )
            agent_buffer_trajectory[BufferKey.DISCRETE_ACTION].append(
                exp.action.discrete
            )
            agent_buffer_trajectory[BufferKey.PREV_ACTION].append(exp.prev_action)
            agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS].append(exp.reward)
        return agent_buffer_trajectory


class DemonstrationProvider(abc.ABC):
    @abc.abstractmethod
    def get_behavior_spec(self) -> BehaviorSpec:
        pass

    @abc.abstractmethod
    def pop_trajectories(self) -> List[DemonstrationTrajectory]:
        pass

    def to_agentbuffer(self, training_length: int) -> AgentBuffer:
        buffer_out = AgentBuffer()
        trajectories = self.pop_trajectories()
        for trajectory in trajectories:
            temp_buffer = trajectory.to_agentbuffer()
            temp_buffer.resequence_and_append(
                buffer_out, batch_size=None, training_length=training_length
            )
        return buffer_out
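A DemonstrationProvider subclass only has to expose a BehaviorSpec and a list of DemonstrationTrajectory objects; the base class turns them into an AgentBuffer. A minimal in-memory subclass, similar in spirit to the MockDemonstrationProvider in the GAIL tests (the class name here is invented):

from typing import List

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.demonstrations.demonstration_provider import (
    DemonstrationProvider,
    DemonstrationTrajectory,
)


class InMemoryDemonstrationProvider(DemonstrationProvider):
    """Serves trajectories that were built elsewhere, e.g. generated procedurally."""

    def __init__(
        self, behavior_spec: BehaviorSpec, trajectories: List[DemonstrationTrajectory]
    ):
        self._behavior_spec = behavior_spec
        self._trajectories = trajectories

    def get_behavior_spec(self) -> BehaviorSpec:
        return self._behavior_spec

    def pop_trajectories(self) -> List[DemonstrationTrajectory]:
        # Hand the stored trajectories over exactly once.
        out, self._trajectories = self._trajectories, []
        return out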

ml-agents/mlagents/trainers/demonstrations/local_demonstration_provider.py (146 changed lines, new file)


import os
from typing import List

import numpy as np
from mlagents_envs.base_env import ActionTuple, BehaviorSpec, ActionSpec
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
    AgentInfoActionPairProto,
)
from mlagents_envs.rpc_utils import steps_from_proto
from mlagents.trainers.demonstrations.demonstration_provider import (
    DemonstrationProvider,
    DemonstrationExperience,
    DemonstrationTrajectory,
)
from mlagents.trainers.demonstrations.demonstration_proto_utils import (
    load_demonstration,
)


class LocalDemonstrationProvider(DemonstrationProvider):
    def __init__(self, file_path: str):
        super().__init__()
        demo_paths = self._get_demo_files(file_path)
        behavior_spec, info_action_pairs, = load_demonstration(demo_paths)
        self._behavior_spec = behavior_spec
        self._info_action_pairs = info_action_pairs

    def get_behavior_spec(self) -> BehaviorSpec:
        return self._behavior_spec

    def pop_trajectories(self) -> List[DemonstrationTrajectory]:
        trajectories = LocalDemonstrationProvider._info_action_pairs_to_trajectories(
            self._behavior_spec, self._info_action_pairs
        )
        self._info_action_pairs = []
        return trajectories

    @staticmethod
    def _get_demo_files(path: str) -> List[str]:
        """
        Retrieves the demonstration file(s) from a path.
        :param path: Path of demonstration file or directory.
        :return: List of demonstration files
        Raises errors if |path| is invalid.
        """
        if os.path.isfile(path):
            if not path.endswith(".demo"):
                raise ValueError("The path provided is not a '.demo' file.")
            return [path]
        elif os.path.isdir(path):
            paths = [
                os.path.join(path, name)
                for name in os.listdir(path)
                if name.endswith(".demo")
            ]
            if not paths:
                raise ValueError(
                    "There are no '.demo' files in the provided directory."
                )
            return paths
        else:
            raise FileNotFoundError(
                f"The demonstration file or directory {path} does not exist."
            )

    @staticmethod
    def _info_action_pairs_to_trajectories(
        behavior_spec: BehaviorSpec, info_action_pairs: List[AgentInfoActionPairProto]
    ) -> List[DemonstrationTrajectory]:
        trajectories_out: List[DemonstrationTrajectory] = []
        current_experiences = []
        previous_action = np.zeros(
            behavior_spec.action_spec.continuous_size, dtype=np.float32
        )  # TODO or discrete?
        for pair_index, pair in enumerate(info_action_pairs):
            # Extract the observations from the decision/terminal steps
            current_decision_step, current_terminal_step = steps_from_proto(
                [pair.agent_info], behavior_spec
            )
            if len(current_terminal_step) == 1:
                obs = list(current_terminal_step.values())[0].obs
            else:
                obs = list(current_decision_step.values())[0].obs

            action_tuple = LocalDemonstrationProvider._get_action_tuple(
                pair, behavior_spec.action_spec
            )

            exp = DemonstrationExperience(
                obs=obs,
                reward=pair.agent_info.reward,  # TODO next step's reward?
                done=pair.agent_info.done,
                action=action_tuple,
                prev_action=previous_action,
                interrupted=pair.agent_info.max_step_reached,
            )
            current_experiences.append(exp)
            previous_action = np.array(
                pair.action_info.vector_actions_deprecated, dtype=np.float32
            )
            if pair.agent_info.done or pair_index == len(info_action_pairs) - 1:
                trajectories_out.append(
                    DemonstrationTrajectory(experiences=current_experiences)
                )
                current_experiences = []

        return trajectories_out

    @staticmethod
    def _get_action_tuple(
        pair: AgentInfoActionPairProto, action_spec: ActionSpec
    ) -> ActionTuple:
        continuous_actions = None
        discrete_actions = None

        if (
            len(pair.action_info.continuous_actions) == 0
            and len(pair.action_info.discrete_actions) == 0
        ):
            if action_spec.continuous_size > 0:
                continuous_actions = pair.action_info.vector_actions_deprecated
            else:
                discrete_actions = pair.action_info.vector_actions_deprecated
        else:
            if action_spec.continuous_size > 0:
                continuous_actions = pair.action_info.continuous_actions
            if action_spec.discrete_size > 0:
                discrete_actions = pair.action_info.discrete_actions

        # TODO 2D?
        continuous_np = (
            np.array(continuous_actions, dtype=np.float32)
            if continuous_actions
            else None
        )
        discrete_np = (
            np.array(discrete_actions, dtype=np.float32) if discrete_actions else None
        )
        return ActionTuple(continuous_np, discrete_np)
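End to end, the trainer-facing usage is a demonstration path in and an AgentBuffer out. A hedged sketch (the .demo path is illustrative):

from mlagents.trainers.demonstrations.local_demonstration_provider import (
    LocalDemonstrationProvider,
)

provider = LocalDemonstrationProvider("demos/expert.demo")  # placeholder path
spec = provider.get_behavior_spec()
# training_length would be policy.sequence_length for BC, or 1 for GAIL.
demo_buffer = provider.to_agentbuffer(training_length=1)
print(spec.action_spec, demo_buffer.num_experiences)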