浏览代码

WIP

/MLA-1734-demo-provider
Chris Elion 4 年前
当前提交
d6b9bf25
共有 5 个文件被更改,包括 185 次插入和 0 次删除
  1. 0
      ml-agents/mlagents/trainers/demonstrations/__init__.py
  2. 95
      ml-agents/mlagents/trainers/demonstrations/demonstration_proto_utils.py
  3. 17
      ml-agents/mlagents/trainers/demonstrations/demonstration_provider.py
  4. 73
      ml-agents/mlagents/trainers/demonstrations/local_demonstration_provider.py

0
ml-agents/mlagents/trainers/demonstrations/__init__.py

95
ml-agents/mlagents/trainers/demonstrations/demonstration_proto_utils.py


import os
from typing import List, Tuple
import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
AgentInfoActionPairProto,
)
from mlagents.trainers.trajectory import ObsUtil
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents_envs.timers import timed, hierarchical_timer
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore
from google.protobuf.internal.encoder import _EncodeVarint # type: ignore
# Byte offset used as the start of the record that follows the metadata:
# load_demonstration jumps to it after parsing the metadata message, and
# write_demo seeks to it before writing the brain parameters.
INITIAL_POS = 33
# Values of DemonstrationMetaProto.api_version that load_demonstration accepts;
# any other value makes it raise a RuntimeError.
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1])
@timed
def load_demonstration(
    file_paths: List[str],
) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto], int]:
    """
    Loads and parses one or more demonstration files.

    Each file is decoded as a sequence of varint-length-prefixed protobuf
    messages: a DemonstrationMetaProto first, then (from byte offset
    INITIAL_POS) a BrainParametersProto, then AgentInfoActionPairProto records.

    :param file_paths: Locations of the demonstration files (.demo) to load.
    :return: Tuple of (BehaviorSpec built from the brain parameters and the
        first decoded agent info, list of every AgentInfoActionPairProto read
        across all files, total number of steps declared in the files' metadata).
    :raises RuntimeError: If a file declares an api_version that is not in
        SUPPORTED_DEMONSTRATION_VERSIONS, or if no BehaviorSpec could be built
        (for example when file_paths is empty or no pair was decoded).
    """
    behavior_spec = None
    brain_param_proto = None
    info_action_pairs = []
    total_expected = 0
    for _file_path in file_paths:
        with open(_file_path, "rb") as fp:
            with hierarchical_timer("read_file"):
                data = fp.read()
        # The whole file is in memory now, so the handle is released before parsing.
        next_pos, pos, obs_decoded = 0, 0, 0
        while pos < len(data):
            # Every record is prefixed by its size; after this call next_pos holds
            # the record size and pos points at the first byte of the record.
            next_pos, pos = _DecodeVarint32(data, pos)
            if obs_decoded == 0:
                # First record: demonstration metadata.
                meta_data_proto = DemonstrationMetaProto()
                meta_data_proto.ParseFromString(data[pos : pos + next_pos])
                if (
                    meta_data_proto.api_version
                    not in SUPPORTED_DEMONSTRATION_VERSIONS
                ):
                    raise RuntimeError(
                        f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
                    )
                total_expected += meta_data_proto.number_steps
                # The next record starts at a fixed offset, not right after the
                # metadata bytes.
                pos = INITIAL_POS
            elif obs_decoded == 1:
                # Second record: brain parameters.
                brain_param_proto = BrainParametersProto()
                brain_param_proto.ParseFromString(data[pos : pos + next_pos])
                pos += next_pos
            else:
                # Every remaining record is an (agent info, action) pair.
                agent_info_action = AgentInfoActionPairProto()
                agent_info_action.ParseFromString(data[pos : pos + next_pos])
                if behavior_spec is None:
                    # The spec is built once, from the first pair seen.
                    behavior_spec = behavior_spec_from_proto(
                        brain_param_proto, agent_info_action.agent_info
                    )
                info_action_pairs.append(agent_info_action)
                # Stop once the number of steps declared so far has been read.
                if len(info_action_pairs) == total_expected:
                    break
                pos += next_pos
            obs_decoded += 1
    # None is the "never built" sentinel; test it explicitly rather than by truthiness.
    if behavior_spec is None:
        raise RuntimeError(
            f"No BrainParameters found in demonstration file(s) at {file_paths}."
        )
    return behavior_spec, info_action_pairs, total_expected
def write_delimited(f, message):
    """
    Writes a message to a stream, preceded by its serialized size as a varint.
    :param f: Writable binary stream; only its write method is used.
    :param message: Object whose SerializeToString() output is written.
    """
    payload = message.SerializeToString()
    # The size prefix goes first so a reader knows how many bytes follow.
    _EncodeVarint(f.write, len(payload))
    f.write(payload)
def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
    """
    Writes a demonstration file made of size-delimited messages.
    :param demo_path: Path of the file to create or overwrite.
    :param meta_data_proto: Metadata message, written at the start of the file.
    :param brain_param_proto: Brain parameters message, written at INITIAL_POS.
    :param agent_info_protos: Iterable of messages written after the brain
        parameters, in iteration order.
    """
    with open(demo_path, "wb") as demo_file:
        # Metadata goes first, at offset 0.
        write_delimited(demo_file, meta_data_proto)
        # Everything else starts at the fixed offset, wherever the metadata ended.
        demo_file.seek(INITIAL_POS)
        write_delimited(demo_file, brain_param_proto)
        for agent_info_proto in agent_info_protos:
            write_delimited(demo_file, agent_info_proto)

17
ml-agents/mlagents/trainers/demonstrations/demonstration_provider.py


import abc
from typing import List
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import Trajectory
class DemonstrationProvider(abc.ABC):
    """
    Abstract interface for objects that supply demonstration data.

    Concrete subclasses must implement both methods below before they can be
    instantiated.
    """

    @abc.abstractmethod
    def get_behavior_spec(self) -> BehaviorSpec:
        """Return the BehaviorSpec associated with the demonstrations."""

    @abc.abstractmethod
    def get_trajectories(self) -> List[Trajectory]:
        """Return the demonstrations as a list of Trajectory objects."""

73
ml-agents/mlagents/trainers/demonstrations/local_demonstration_provider.py


from typing import List
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.demonstrations.demonstration_provider import DemonstrationProvider
from mlagents.trainers.demonstrations.demonstration_proto_utils import load_demonstration
import os
from typing import List, Tuple
import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
AgentInfoActionPairProto,
)
from mlagents.trainers.trajectory import ObsUtil
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents_envs.timers import timed, hierarchical_timer
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore
from google.protobuf.internal.encoder import _EncodeVarint # type: ignore
class LocalDemonstrationProver(DemonstrationProvider):
    """
    DemonstrationProvider that reads demonstrations from the local filesystem.

    The given path may be a single '.demo' file or a directory containing
    '.demo' files; the files are loaded once, at construction time.
    """

    def __init__(self, file_path: str):
        """
        :param file_path: Path of a '.demo' file or of a directory holding them.
        Raises whatever _get_demo_files or load_demonstration raise for a bad path
        or unreadable demonstration data.
        """
        super().__init__()
        self._trajectories: List[Trajectory] = []
        self._load(file_path)

    def get_behavior_spec(self) -> BehaviorSpec:
        """Return the BehaviorSpec read from the demonstration file(s)."""
        return self._behavior_spec

    def get_trajectories(self) -> List[Trajectory]:
        """
        Return the trajectories held by this provider.

        NOTE(review): _load does not populate this list yet, so it is currently
        always empty.
        """
        return self._trajectories

    def _load(self, file_path: str) -> None:
        """
        Resolve the demonstration files under file_path and load them.
        :param file_path: Path of a '.demo' file or of a directory holding them.
        """
        demo_paths = self._get_demo_files(file_path)
        behavior_spec, _info_action_pairs, _ = load_demonstration(demo_paths)
        # Keep the spec so get_behavior_spec can return it.
        self._behavior_spec = behavior_spec
        # TODO: convert the loaded AgentInfoActionPairProto records into
        # Trajectory objects and store them in self._trajectories.

    @staticmethod
    def _get_demo_files(path: str) -> List[str]:
        """
        Retrieves the demonstration file(s) from a path.
        :param path: Path of demonstration file or directory.
        :return: List of demonstration files
        Raises ValueError if |path| is a file without the '.demo' suffix or a
        directory without any '.demo' file, and FileNotFoundError if |path| is
        neither a file nor a directory.
        """
        if os.path.isfile(path):
            if not path.endswith(".demo"):
                raise ValueError("The path provided is not a '.demo' file.")
            return [path]
        if os.path.isdir(path):
            paths = [
                os.path.join(path, name)
                for name in os.listdir(path)
                if name.endswith(".demo")
            ]
            if not paths:
                raise ValueError("There are no '.demo' files in the provided directory.")
            return paths
        raise FileNotFoundError(
            f"The demonstration file or directory {path} does not exist."
        )


# Correctly spelled alias; the original (misspelled) class name is kept so
# existing references keep working.
LocalDemonstrationProvider = LocalDemonstrationProver
正在加载...
取消
保存