Chris Elion
4 年前
当前提交
d6b9bf25
共有 5 个文件被更改,包括 185 次插入 和 0 次删除
-
0ml-agents/mlagents/trainers/demonstrations/__init__.py
-
95ml-agents/mlagents/trainers/demonstrations/demonstration_proto_utils.py
-
17ml-agents/mlagents/trainers/demonstrations/demonstration_provider.py
-
73ml-agents/mlagents/trainers/demonstrations/local_demonstration_provider.py
|
|||
import os |
|||
from typing import List, Tuple |
|||
import numpy as np |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import ( |
|||
AgentInfoActionPairProto, |
|||
) |
|||
from mlagents.trainers.trajectory import ObsUtil |
|||
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto |
|||
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import ( |
|||
DemonstrationMetaProto, |
|||
) |
|||
from mlagents_envs.timers import timed, hierarchical_timer |
|||
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore |
|||
from google.protobuf.internal.encoder import _EncodeVarint # type: ignore |
|||
|
|||
|
|||
INITIAL_POS = 33 |
|||
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1]) |
|||
|
|||
@timed |
|||
def load_demonstration( |
|||
file_paths: List[str], |
|||
) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto], int]: |
|||
""" |
|||
Loads and parses a demonstration file. |
|||
:param file_path: Location of demonstration file (.demo). |
|||
:return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data. |
|||
""" |
|||
|
|||
# First 32 bytes of file dedicated to meta-data. |
|||
behavior_spec = None |
|||
brain_param_proto = None |
|||
info_action_pairs = [] |
|||
total_expected = 0 |
|||
for _file_path in file_paths: |
|||
with open(_file_path, "rb") as fp: |
|||
with hierarchical_timer("read_file"): |
|||
data = fp.read() |
|||
next_pos, pos, obs_decoded = 0, 0, 0 |
|||
while pos < len(data): |
|||
next_pos, pos = _DecodeVarint32(data, pos) |
|||
if obs_decoded == 0: |
|||
meta_data_proto = DemonstrationMetaProto() |
|||
meta_data_proto.ParseFromString(data[pos : pos + next_pos]) |
|||
if ( |
|||
meta_data_proto.api_version |
|||
not in SUPPORTED_DEMONSTRATION_VERSIONS |
|||
): |
|||
raise RuntimeError( |
|||
f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})" |
|||
) |
|||
total_expected += meta_data_proto.number_steps |
|||
pos = INITIAL_POS |
|||
if obs_decoded == 1: |
|||
brain_param_proto = BrainParametersProto() |
|||
brain_param_proto.ParseFromString(data[pos : pos + next_pos]) |
|||
pos += next_pos |
|||
if obs_decoded > 1: |
|||
agent_info_action = AgentInfoActionPairProto() |
|||
agent_info_action.ParseFromString(data[pos : pos + next_pos]) |
|||
if behavior_spec is None: |
|||
behavior_spec = behavior_spec_from_proto( |
|||
brain_param_proto, agent_info_action.agent_info |
|||
) |
|||
info_action_pairs.append(agent_info_action) |
|||
if len(info_action_pairs) == total_expected: |
|||
break |
|||
pos += next_pos |
|||
obs_decoded += 1 |
|||
if not behavior_spec: |
|||
raise RuntimeError( |
|||
f"No BrainParameters found in demonstration file(s) at {file_paths}." |
|||
) |
|||
return behavior_spec, info_action_pairs, total_expected |
|||
|
|||
|
|||
def write_delimited(f, message): |
|||
msg_string = message.SerializeToString() |
|||
msg_size = len(msg_string) |
|||
_EncodeVarint(f.write, msg_size) |
|||
f.write(msg_string) |
|||
|
|||
|
|||
def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos): |
|||
with open(demo_path, "wb") as f: |
|||
# write metadata |
|||
write_delimited(f, meta_data_proto) |
|||
f.seek(INITIAL_POS) |
|||
write_delimited(f, brain_param_proto) |
|||
|
|||
for agent in agent_info_protos: |
|||
write_delimited(f, agent) |
|
|||
import abc |
|||
from typing import List |
|||
|
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.trajectory import Trajectory |
|||
|
|||
|
|||
class DemonstrationProvider(abc.ABC): |
|||
@abc.abstractmethod |
|||
def get_behavior_spec(self) -> BehaviorSpec: |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def get_trajectories(self) -> List[Trajectory]: |
|||
pass |
|
|||
from typing import List |
|||
|
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
|
|||
from mlagents.trainers.trajectory import Trajectory |
|||
from mlagents.trainers.demonstrations.demonstration_provider import DemonstrationProvider |
|||
from mlagents.trainers.demonstrations.demonstration_proto_utils import load_demonstration |
|||
|
|||
|
|||
import os |
|||
from typing import List, Tuple |
|||
import numpy as np |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import ( |
|||
AgentInfoActionPairProto, |
|||
) |
|||
from mlagents.trainers.trajectory import ObsUtil |
|||
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto |
|||
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import ( |
|||
DemonstrationMetaProto, |
|||
) |
|||
from mlagents_envs.timers import timed, hierarchical_timer |
|||
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore |
|||
from google.protobuf.internal.encoder import _EncodeVarint # type: ignore |
|||
|
|||
|
|||
|
|||
class LocalDemonstrationProver(DemonstrationProvider): |
|||
def __init__(self, file_path: str): |
|||
super().__init__() |
|||
self._trajectories: List[Trajectory] = [] |
|||
self._load(file_path) |
|||
|
|||
def get_behavior_spec(self) -> BehaviorSpec: |
|||
pass |
|||
|
|||
def get_trajectories(self) -> List[Trajectory]: |
|||
pass |
|||
|
|||
|
|||
def _load(self, file_path: str) -> None: |
|||
demo_paths = self._get_demo_files(file_path) |
|||
behavior_spec, info_action_pair, _ = load_demonstration(demo_paths) |
|||
|
|||
|
|||
@staticmethod |
|||
def _get_demo_files(path: str) -> List[str]: |
|||
""" |
|||
Retrieves the demonstration file(s) from a path. |
|||
:param path: Path of demonstration file or directory. |
|||
:return: List of demonstration files |
|||
|
|||
Raises errors if |path| is invalid. |
|||
""" |
|||
if os.path.isfile(path): |
|||
if not path.endswith(".demo"): |
|||
raise ValueError("The path provided is not a '.demo' file.") |
|||
return [path] |
|||
elif os.path.isdir(path): |
|||
paths = [ |
|||
os.path.join(path, name) |
|||
for name in os.listdir(path) |
|||
if name.endswith(".demo") |
|||
] |
|||
if not paths: |
|||
raise ValueError("There are no '.demo' files in the provided directory.") |
|||
return paths |
|||
else: |
|||
raise FileNotFoundError( |
|||
f"The demonstration file or directory {path} does not exist." |
|||
) |
撰写
预览
正在加载...
取消
保存
Reference in new issue