浏览代码

Refactor file logic in demo_loader and add unit tests. (#3241)

/asymm-envs
Chris Elion 5 年前
当前提交
45e6e53c
共有 2 个文件被更改,包括 64 次插入23 次删除
  1. 50
      ml-agents/mlagents/trainers/demo_loader.py
  2. 37
      ml-agents/mlagents/trainers/tests/test_demo_loader.py

50
ml-agents/mlagents/trainers/demo_loader.py


import pathlib
import logging
import os
from typing import List, Tuple

return brain_params, demo_buffer
def get_demo_files(path: str) -> List[str]:
"""
Retrieves the demonstration file(s) from a path.
:param path: Path of demonstration file or directory.
:return: List of demonstration files
Raises errors if |path| is invalid.
"""
if os.path.isfile(path):
if not path.endswith(".demo"):
raise ValueError("The path provided is not a '.demo' file.")
return [path]
elif os.path.isdir(path):
paths = [
os.path.join(path, name)
for name in os.listdir(path)
if name.endswith(".demo")
]
if not paths:
raise ValueError("There are no '.demo' files in the provided directory.")
return paths
else:
raise FileNotFoundError(
f"The demonstration file or directory {path} does not exist."
)
@timed
def load_demonstration(
file_path: str

# First 32 bytes of file dedicated to meta-data.
INITIAL_POS = 33
file_paths = []
if os.path.isdir(file_path):
all_files = os.listdir(file_path)
for _file in all_files:
if _file.endswith(".demo"):
file_paths.append(os.path.join(file_path, _file))
if not all_files:
raise ValueError("There are no '.demo' files in the provided directory.")
elif os.path.isfile(file_path):
file_paths.append(file_path)
file_extension = pathlib.Path(file_path).suffix
if file_extension != ".demo":
raise ValueError(
"The file is not a '.demo' file. Please provide a file with the "
"correct extension."
)
else:
raise FileNotFoundError(
"The demonstration file or directory {} does not exist.".format(file_path)
)
file_paths = get_demo_files(file_path)
group_spec = None
brain_param_proto = None
info_action_pairs = []

37
ml-agents/mlagents/trainers/tests/test_demo_loader.py


import os
import numpy as np
import pytest
import tempfile
from mlagents.trainers.demo_loader import load_demonstration, demo_to_buffer
from mlagents.trainers.demo_loader import (
load_demonstration,
demo_to_buffer,
get_demo_files,
)
def test_load_demo():

_, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1)
assert len(demo_buffer["actions"]) == total_expected - 1
def test_edge_cases():
path_prefix = os.path.dirname(os.path.abspath(__file__))
# nonexistent file and directory
with pytest.raises(FileNotFoundError):
get_demo_files(os.path.join(path_prefix, "nonexistent_file.demo"))
with pytest.raises(FileNotFoundError):
get_demo_files(os.path.join(path_prefix, "nonexistent_directory"))
with tempfile.TemporaryDirectory() as tmpdirname:
# empty directory
with pytest.raises(ValueError):
get_demo_files(tmpdirname)
# invalid file
invalid_fname = os.path.join(tmpdirname, "mydemo.notademo")
with open(invalid_fname, "w") as f:
f.write("I'm not a demo")
with pytest.raises(ValueError):
get_demo_files(invalid_fname)
# invalid directory
with pytest.raises(ValueError):
get_demo_files(tmpdirname)
# valid file
valid_fname = os.path.join(tmpdirname, "mydemo.demo")
with open(valid_fname, "w") as f:
f.write("I'm a demo file")
assert get_demo_files(valid_fname) == [valid_fname]
# valid directory
assert get_demo_files(tmpdirname) == [valid_fname]
正在加载...
取消
保存