浏览代码

check demonstration version before loading (#3745)

* check demonstration version before loading

* I guess we have version 0 too

* fix mock import (works on my machine)
/develop/add-fire
GitHub 4 年前
当前提交
55b26417
共有 2 个文件被更改,包括 35 次插入3 次删除
  1. 14
      ml-agents/mlagents/trainers/demo_loader.py
  2. 24
      ml-agents/mlagents/trainers/tests/test_demo_loader.py

14
ml-agents/mlagents/trainers/demo_loader.py


from google.protobuf.internal.encoder import _EncodeVarint # type: ignore
INITIAL_POS = 33
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1])
@timed
def make_demo_buffer(
pair_infos: List[AgentInfoActionPairProto],

)
INITIAL_POS = 33
@timed
def load_demonstration(
file_path: str

if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
if (
meta_data_proto.api_version
not in SUPPORTED_DEMONSTRATION_VERSIONS
):
raise RuntimeError(
f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
)
total_expected += meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:

24
ml-agents/mlagents/trainers/tests/test_demo_loader.py


import io
from unittest import mock
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
write_delimited,
)

assert get_demo_files(valid_fname) == [valid_fname]
# valid directory
assert get_demo_files(tmpdirname) == [valid_fname]
@mock.patch("mlagents.trainers.demo_loader.get_demo_files", return_value=["foo.demo"])
def test_unsupported_version_raises_error(mock_get_demo_files):
# Create a metadata proto with an unsupported version
bad_metadata = DemonstrationMetaProto()
bad_metadata.api_version = 1337
# Write the metadata to a temporary buffer, which will get returned by open()
buffer = io.BytesIO()
write_delimited(buffer, bad_metadata)
m = mock.mock_open(read_data=buffer.getvalue())
# Make sure that we get a RuntimeError when trying to load this.
with mock.patch("builtins.open", m):
with pytest.raises(RuntimeError):
load_demonstration("foo")
正在加载...
取消
保存