浏览代码

add timers and fix filehandle leak (#2989)

/develop/tanhsquash
GitHub 5 年前
当前提交
ece9733c
共有 2 个文件被更改,包括 32 次插入和 25 次删除
  1. ml-agents-envs/mlagents/envs/brain.py(1 处更改)
  2. ml-agents/mlagents/trainers/demo_loader.py(56 处更改)

1
ml-agents-envs/mlagents/envs/brain.py


return s
@staticmethod
@timed
def from_agent_proto(
worker_id: int,
agent_info_list: List[AgentInfoProto],

56
ml-agents/mlagents/trainers/demo_loader.py


from mlagents.envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents.envs.timers import timed, hierarchical_timer
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore

@timed
def make_demo_buffer(
pair_infos: List[AgentInfoActionPairProto],
brain_params: BrainParameters,

return demo_buffer
@timed
def demo_to_buffer(
file_path: str, sequence_length: int
) -> Tuple[BrainParameters, Buffer]:

return brain_params, demo_buffer
@timed
def load_demonstration(
file_path: str
) -> Tuple[BrainParameters, List[AgentInfoActionPairProto], int]:

info_action_pairs = []
total_expected = 0
for _file_path in file_paths:
data = open(_file_path, "rb").read()
next_pos, pos, obs_decoded = 0, 0, 0
while pos < len(data):
next_pos, pos = _DecodeVarint32(data, pos)
if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
total_expected += meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
pos += next_pos
if obs_decoded > 1:
agent_info_action = AgentInfoActionPairProto()
agent_info_action.ParseFromString(data[pos : pos + next_pos])
if brain_params is None:
brain_params = BrainParameters.from_proto(
brain_param_proto, agent_info_action.agent_info
)
info_action_pairs.append(agent_info_action)
if len(info_action_pairs) == total_expected:
break
pos += next_pos
obs_decoded += 1
with open(_file_path, "rb") as fp:
with hierarchical_timer("read_file"):
data = fp.read()
next_pos, pos, obs_decoded = 0, 0, 0
while pos < len(data):
next_pos, pos = _DecodeVarint32(data, pos)
if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
total_expected += meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
pos += next_pos
if obs_decoded > 1:
agent_info_action = AgentInfoActionPairProto()
agent_info_action.ParseFromString(data[pos : pos + next_pos])
if brain_params is None:
brain_params = BrainParameters.from_proto(
brain_param_proto, agent_info_action.agent_info
)
info_action_pairs.append(agent_info_action)
if len(info_action_pairs) == total_expected:
break
pos += next_pos
obs_decoded += 1
return brain_params, info_action_pairs, total_expected
正在加载...
取消
保存