|
|
|
|
|
|
import io |
|
|
|
from unittest import mock |
|
|
|
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import ( |
|
|
|
DemonstrationMetaProto, |
|
|
|
) |
|
|
|
|
|
|
|
write_delimited, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert get_demo_files(valid_fname) == [valid_fname] |
|
|
|
# valid directory |
|
|
|
assert get_demo_files(tmpdirname) == [valid_fname] |
|
|
|
|
|
|
|
|
|
|
|
@mock.patch("mlagents.trainers.demo_loader.get_demo_files", return_value=["foo.demo"]) |
|
|
|
def test_unsupported_version_raises_error(mock_get_demo_files): |
|
|
|
# Create a metadata proto with an unsupported version |
|
|
|
bad_metadata = DemonstrationMetaProto() |
|
|
|
bad_metadata.api_version = 1337 |
|
|
|
|
|
|
|
# Write the metadata to a temporary buffer, which will get returned by open() |
|
|
|
buffer = io.BytesIO() |
|
|
|
write_delimited(buffer, bad_metadata) |
|
|
|
m = mock.mock_open(read_data=buffer.getvalue()) |
|
|
|
|
|
|
|
# Make sure that we get a RuntimeError when trying to load this. |
|
|
|
with mock.patch("builtins.open", m): |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
load_demonstration("foo") |