您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
96 行
3.8 KiB
96 行
3.8 KiB
import os
|
|
from typing import List, Tuple
|
|
import numpy as np
|
|
from mlagents.trainers.buffer import AgentBuffer
|
|
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
|
|
AgentInfoActionPairProto,
|
|
)
|
|
from mlagents.trainers.trajectory import ObsUtil
|
|
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
|
|
from mlagents_envs.base_env import BehaviorSpec
|
|
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
|
|
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
|
|
DemonstrationMetaProto,
|
|
)
|
|
from mlagents_envs.timers import timed, hierarchical_timer
|
|
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore
|
|
from google.protobuf.internal.encoder import _EncodeVarint # type: ignore
|
|
|
|
|
|
# Byte offset of the second record in a .demo file: the reader jumps here
# after parsing the metadata, and the writer seeks here before writing the
# brain parameters.
INITIAL_POS = 33
# Values of DemonstrationMetaProto.api_version that the loader accepts.
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset({0, 1})
|
|
|
|
|
|
@timed
def load_demonstration(
    file_paths: List[str],
) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto]]:
    """
    Loads and parses one or more demonstration files.

    Each file is a sequence of varint-length-delimited protobuf messages: a
    DemonstrationMetaProto at offset 0, a BrainParametersProto starting at
    byte offset INITIAL_POS, then AgentInfoActionPairProto records.

    :param file_paths: Locations of demonstration files (.demo).
    :return: The BehaviorSpec built from the first file that contains agent
        records (using that file's brain parameters), and the
        AgentInfoActionPairProto records of all files concatenated in order.
    :raises RuntimeError: If a file's metadata has an unsupported api_version,
        or if no file yields a BehaviorSpec.
    """
    behavior_spec = None
    info_action_pairs: List[AgentInfoActionPairProto] = []
    for _file_path in file_paths:
        with open(_file_path, "rb") as fp:
            with hierarchical_timer("read_file"):
                data = fp.read()
        brain_param_proto, file_pairs = _parse_demo_file(data)
        if behavior_spec is None and file_pairs:
            # NOTE(review): brain parameters of later files are not checked
            # against this spec; they are assumed to describe the same behavior.
            behavior_spec = behavior_spec_from_proto(
                brain_param_proto, file_pairs[0].agent_info
            )
        info_action_pairs.extend(file_pairs)
    if behavior_spec is None:
        raise RuntimeError(
            f"No BrainParameters found in demonstration file(s) at {file_paths}."
        )
    return behavior_spec, info_action_pairs


def _parse_demo_file(data: bytes):
    """
    Decode the length-delimited messages of a single demonstration file.

    :param data: Raw bytes of one .demo file.
    :return: Tuple of (BrainParametersProto, or None if the file holds no
        brain-parameter record; list of AgentInfoActionPairProto records).
    :raises RuntimeError: If the metadata's api_version is unsupported.
    """
    brain_param_proto = None
    info_action_pairs = []
    expected_steps = 0
    pos, obs_decoded = 0, 0
    while pos < len(data):
        msg_size, pos = _DecodeVarint32(data, pos)
        msg_bytes = data[pos : pos + msg_size]
        if obs_decoded == 0:
            meta_data_proto = DemonstrationMetaProto()
            meta_data_proto.ParseFromString(msg_bytes)
            if (
                meta_data_proto.api_version
                not in SUPPORTED_DEMONSTRATION_VERSIONS
            ):
                raise RuntimeError(
                    f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
                )
            expected_steps = meta_data_proto.number_steps
            # The next record always starts at INITIAL_POS, independent of the
            # metadata's encoded length (write_demo seeks there before writing).
            pos = INITIAL_POS
        elif obs_decoded == 1:
            brain_param_proto = BrainParametersProto()
            brain_param_proto.ParseFromString(msg_bytes)
            pos += msg_size
        else:
            agent_info_action = AgentInfoActionPairProto()
            agent_info_action.ParseFromString(msg_bytes)
            info_action_pairs.append(agent_info_action)
            # Stop at this file's own recorded step count so a short or
            # overlong earlier file cannot change how much of this one is
            # read. A count of 0 never matches, so the file is read to EOF.
            if len(info_action_pairs) == expected_steps:
                break
            pos += msg_size
        obs_decoded += 1
    return brain_param_proto, info_action_pairs
|
|
|
|
|
|
def write_delimited(f, message):
    """
    Write ``message`` to ``f`` as one length-delimited protobuf record.

    The record is the varint-encoded size of the serialized message followed
    by the serialized bytes themselves.

    :param f: Writable binary file-like object (only ``f.write`` is used).
    :param message: Protobuf message to serialize.
    """
    payload = message.SerializeToString()
    # _EncodeVarint emits the length prefix through the supplied write callback.
    _EncodeVarint(f.write, len(payload))
    f.write(payload)
|
|
|
|
|
|
def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
    """
    Write a demonstration file in the layout that load_demonstration reads.

    Layout: the length-delimited ``meta_data_proto`` at offset 0, the
    length-delimited ``brain_param_proto`` starting at byte offset
    ``INITIAL_POS``, then each element of ``agent_info_protos``
    length-delimited, in iteration order.

    :param demo_path: Path of the file to create (truncated if it exists).
    :param meta_data_proto: Metadata message (presumably a
        DemonstrationMetaProto, which is what the loader parses it as).
    :param brain_param_proto: Brain-parameters message.
    :param agent_info_protos: Iterable of per-step messages.
    :raises ValueError: If the delimited metadata is longer than INITIAL_POS
        bytes; the brain parameters would otherwise overwrite its tail and
        corrupt the file.
    """
    import io

    # Serialize the header up front so an oversized metadata record is
    # rejected before any file is created or truncated.
    header = io.BytesIO()
    write_delimited(header, meta_data_proto)
    header_bytes = header.getvalue()
    if len(header_bytes) > INITIAL_POS:
        raise ValueError(
            f"Delimited demonstration metadata is {len(header_bytes)} bytes, "
            f"which exceeds the {INITIAL_POS}-byte header region."
        )

    with open(demo_path, "wb") as f:
        f.write(header_bytes)
        # The brain parameters always start at INITIAL_POS; any bytes between
        # the end of the metadata and INITIAL_POS are not written explicitly.
        f.seek(INITIAL_POS)
        write_delimited(f, brain_param_proto)
        for agent_info_proto in agent_info_protos:
            write_delimited(f, agent_info_proto)
|