|
|
|
|
|
|
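"""
Training integration test.

Builds a standalone 3DBall player, runs a very short mlagents-learn training session
against it, checks that the expected model files were produced, and (when testing the
latest python trainer and C# package) re-runs the player in inference mode with the
trained model.
"""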
import argparse
import os
import shutil
import subprocess
import sys
import time
from typing import Any

# Shared yamato CI helpers (module name assumed to match the other yamato scripts).
from .yamato_utils import (
    find_executables,
    get_base_path,
    get_base_output_path,
    run_standalone_build,
    override_config_file,
    override_legacy_config_file,
    undo_git_checkout,
)

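
# Returns True if the training run (and the optional inference check) succeeded.
# A python_version / csharp_version of None means "use the latest code".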
def run_training(python_version: str, csharp_version: str) -> bool:
    latest = "latest"
    run_id = int(time.time() * 1000.0)
    print(
        f"Running training with python={python_version or latest} "
        f"and c#={csharp_version or latest}"
    )
    # Older trainer versions wrote models to ./models, current ones to ./results (assumed).
    output_dir = "models" if python_version else "results"
    nn_file_expected = f"./{output_dir}/{run_id}/3DBall.nn"
    onnx_file_expected = f"./{output_dir}/{run_id}/3DBall.onnx"
    frozen_graph_file_expected = f"./{output_dir}/{run_id}/3DBall/frozen_graph_def.pb"

    # Nothing should be left over from a previous run.
    if os.path.exists(nn_file_expected) or os.path.exists(onnx_file_expected):
        print("Artifacts from a previous run found, aborting.")
        return False

    base_path = get_base_path()
    print(f"Running in base path {base_path}")

    build_returncode = run_standalone_build(base_path)

    if build_returncode != 0:
        print(f"Standalone build FAILED! with return code {build_returncode}")
        return False

    # Now rename the newly-built executable, and restore the old one.
    # NOTE: the player names below are assumptions and must match the output of
    # run_standalone_build above.
    standalone_player_path = f"testPlayer_{csharp_version or latest}"
    full_player_path = os.path.join(get_base_output_path(), "testPlayer.app")
    final_player_path = os.path.join(
        get_base_output_path(), standalone_player_path + ".app"
    )
    os.rename(full_player_path, final_player_path)

    # Copy the default training config but override the max_steps parameter,
    # and reduce the batch_size and buffer_size enough to ensure an update step happens.
    yaml_out = "override.yaml"
    if python_version:
        overrides: Any = {"max_steps": 100, "batch_size": 10, "buffer_size": 10}
        override_legacy_config_file(
            python_version, "config/trainer_config.yaml", yaml_out, **overrides
        )
    else:
        # The same overrides, expressed in the current trainer config format.
        overrides = {
            "max_steps": 100,
            "hyperparameters": {"batch_size": 10, "buffer_size": 10},
        }
        override_config_file("config/ppo/3DBall.yaml", yaml_out, overrides)

    env_path = os.path.join(get_base_output_path(), standalone_player_path + ".app")
    mla_learn_cmd = (
        f"mlagents-learn {yaml_out} --force --env={env_path} "
        f"--run-id={run_id} --no-graphics --env-args -logFile -"
    )  # noqa

    res = subprocess.run(mla_learn_cmd, shell=True)

    # Save models as artifacts (only if we're using latest python and C#)
    if csharp_version is None and python_version is None:
        model_artifacts_dir = os.path.join(get_base_output_path(), "models")
        os.makedirs(model_artifacts_dir, exist_ok=True)
        shutil.copy(nn_file_expected, model_artifacts_dir)
        shutil.copy(onnx_file_expected, model_artifacts_dir)
        shutil.copy(frozen_graph_file_expected, model_artifacts_dir)

    if (
        res.returncode != 0
        or not os.path.exists(nn_file_expected)
        or not os.path.exists(onnx_file_expected)
    ):
        return False

    if csharp_version is None and python_version is None:
        # Use abs path so that loading doesn't get confused
        model_path = os.path.abspath(os.path.dirname(nn_file_expected))
        # Onnx loading for overrides not currently supported, but this is
        # where to add it in when it is.
        for extension in ["nn"]:
            inference_ok = run_inference(env_path, model_path, extension)
            if not inference_ok:
                return False

    return True

|
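# Launches the standalone player in batch mode, pointing it at the trained model via the
# --mlagents-override-* command-line options, and quitting after a single episode.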
def run_inference(env_path: str, output_path: str, model_extension: str) -> bool:
    start_time = time.time()
    exes = find_executables(env_path)
    if len(exes) != 1:
        print(f"Can't determine the player executable in {env_path}. Found {exes}.")
        return False

    log_output_path = f"{get_base_output_path()}/inference.{model_extension}.txt"

    exe_path = exes[0]
    args = [
        exe_path,
        "-nographics",
        "-batchmode",
        "-logfile",
        log_output_path,
        "--mlagents-override-model-directory",
        output_path,
        "--mlagents-quit-on-load-failure",
        "--mlagents-quit-after-episodes",
        "1",
        "--mlagents-override-model-extension",
        model_extension,
    ]
    res = subprocess.run(args)
    end_time = time.time()
    if res.returncode != 0:
        print("Error running inference!")
        print("Command line: " + " ".join(args))
        subprocess.run(["cat", log_output_path])
        return False
    else:
        print(f"Inference succeeded! Took {end_time - start_time} seconds")

    return True

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--python", default=None)
    parser.add_argument("--csharp", default=None)
    args = parser.parse_args()

    try:
        ok = run_training(args.python, args.csharp)
        if not ok:
            sys.exit(1)
    finally:
        # Cleanup - this gets executed even if we hit sys.exit()
        undo_git_checkout()
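
if __name__ == "__main__":
    # --python / --csharp select an older trainer / C# package version to test against;
    # both default to the latest code when omitted.
    main()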
|
|
|