|
|
|
|
|
|
from mlagents.envs.communicator_objects.demonstration_meta_pb2 import ( |
|
|
|
DemonstrationMetaProto, |
|
|
|
) |
|
|
|
from mlagents.envs.timers import timed, hierarchical_timer |
|
|
|
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@timed |
|
|
|
def make_demo_buffer( |
|
|
|
pair_infos: List[AgentInfoActionPairProto], |
|
|
|
brain_params: BrainParameters, |
|
|
|
|
|
|
return demo_buffer |
|
|
|
|
|
|
|
|
|
|
|
@timed |
|
|
|
def demo_to_buffer( |
|
|
|
file_path: str, sequence_length: int |
|
|
|
) -> Tuple[BrainParameters, Buffer]: |
|
|
|
|
|
|
return brain_params, demo_buffer |
|
|
|
|
|
|
|
|
|
|
|
@timed
def load_demonstration(
    file_path: str
) -> Tuple[BrainParameters, List[AgentInfoActionPairProto], int]:
    """
    Decode the length-prefixed protobuf records stored in demonstration files.

    Every file is read fully into memory and walked as a sequence of
    varint-length-prefixed messages: record 0 is parsed as a
    DemonstrationMetaProto, record 1 as a BrainParametersProto, and every
    later record as an AgentInfoActionPairProto.

    :param file_path: Location of the demonstration data.
        NOTE(review): the statements that turn this argument into
        ``file_paths`` are not visible in this excerpt - confirm against the
        full file.
    :return: Tuple of (brain parameters built from the first decoded pair, or
        None when no pair was decoded; list of decoded pairs; sum of
        ``number_steps`` over the metadata records that were read).
    """
    # NOTE(review): `file_paths` and `INITIAL_POS` are read below but are not
    # assigned anywhere in the visible excerpt; they must be provided by code
    # outside this view - verify before relying on this function.

    # Initialise explicitly so the `is None` check below never hits an
    # unbound local.
    brain_params = None
    info_action_pairs = []
    total_expected = 0

    for _file_path in file_paths:
        # Read each file exactly once, through a context manager so the handle
        # is always closed. A second, unmanaged `open(...).read()` pass over
        # the same file would append every pair twice and double-count
        # `total_expected`.
        with open(_file_path, "rb") as fp:
            with hierarchical_timer("read_file"):
                data = fp.read()

        pos, obs_decoded = 0, 0
        while pos < len(data):
            # _DecodeVarint32 yields (decoded value, position after the
            # varint); the value is used as the byte length of the next
            # message.
            msg_len, pos = _DecodeVarint32(data, pos)
            payload = data[pos : pos + msg_len]

            if obs_decoded == 0:
                meta_data_proto = DemonstrationMetaProto()
                meta_data_proto.ParseFromString(payload)
                total_expected += meta_data_proto.number_steps
                # Jump to a fixed offset rather than advancing by msg_len;
                # presumably the metadata sits in a fixed-size header -
                # TODO confirm where INITIAL_POS is defined.
                pos = INITIAL_POS
            elif obs_decoded == 1:
                brain_param_proto = BrainParametersProto()
                brain_param_proto.ParseFromString(payload)
                pos += msg_len
            else:
                agent_info_action = AgentInfoActionPairProto()
                agent_info_action.ParseFromString(payload)
                if brain_params is None:
                    # Built once, from the first pair seen across all files.
                    brain_params = BrainParameters.from_proto(
                        brain_param_proto, agent_info_action.agent_info
                    )
                info_action_pairs.append(agent_info_action)
                # Stop parsing this file once the running pair count matches
                # the running expected total from the metadata records.
                if len(info_action_pairs) == total_expected:
                    break
                pos += msg_len
            obs_decoded += 1

    return brain_params, info_action_pairs, total_expected