浏览代码

Merge branch 'develop-fix-lstms' into develop-critic-op-lstm

/develop/critic-op-lstm-currentmem
Ervin Teng 4 年前
当前提交
219e773b
共有 21 个文件被更改,包括 271 次插入136 次删除
  1. 1
      .yamato/com.unity.ml-agents-performance.yml
  2. 4
      .yamato/com.unity.ml-agents-test.yml
  3. 4
      .yamato/compressed-sensor-test.yml
  4. 4
      .yamato/gym-interface-test.yml
  5. 4
      .yamato/python-ll-api-test.yml
  6. 15
      .yamato/test_versions.metafile
  7. 2
      Project/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs
  8. 29
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs
  9. 3
      com.unity.ml-agents/CHANGELOG.md
  10. 49
      com.unity.ml-agents/Runtime/Academy.cs
  11. 5
      com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs
  12. 147
      com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
  13. 14
      com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs
  14. 19
      com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs
  15. 2
      docs/Training-Configuration-File.md
  16. 2
      ml-agents/mlagents/trainers/agent_processor.py
  17. 40
      ml-agents/mlagents/trainers/ghost/trainer.py
  18. 10
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  19. 15
      ml-agents/mlagents/trainers/policy/policy.py
  20. 9
      ml-agents/mlagents/trainers/tests/torch/test_ghost.py
  21. 29
      ml-agents/tests/yamato/training_int_tests.py

1
.yamato/com.unity.ml-agents-performance.yml


test_editors:
- version: 2019.4
- version: 2020.1
- version: 2020.2
---
{% for editor in test_editors %}

4
.yamato/com.unity.ml-agents-test.yml


enableCodeCoverage: !!bool true
testProject: DevProject
enableNoDefaultPackages: !!bool true
- version: 2020.1
enableCodeCoverage: !!bool true
testProject: DevProject
enableNoDefaultPackages: !!bool true
- version: 2020.2
enableCodeCoverage: !!bool true
testProject: DevProject

4
.yamato/compressed-sensor-test.yml


- .yamato/standalone-build-test.yml#test_linux_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
{% if editor.extra_test == "sensor" %}
expression: |
(pull_request.target eq "master" OR
pull_request.target match "release.+") AND

pull_request.changes.any match "Project/**" OR
pull_request.changes.any match "ml-agents/**" OR
pull_request.changes.any match "ml-agents/tests/yamato/**" OR
{% endif %}
{% endfor %}

4
.yamato/gym-interface-test.yml


- .yamato/standalone-build-test.yml#test_linux_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
{% if editor.extra_test == "gym" %}
expression: |
(pull_request.target eq "master" OR
pull_request.target match "release.+") AND

pull_request.changes.any match "ml-agents/**" OR
pull_request.changes.any match "ml-agents/tests/yamato/**" OR
{% endif %}
{% endfor %}

4
.yamato/python-ll-api-test.yml


- .yamato/standalone-build-test.yml#test_linux_standalone_{{ editor.version }}
triggers:
cancel_old_ci: true
{% if editor.extra_test == "llapi" %}
expression: |
(pull_request.target eq "master" OR
pull_request.target match "release.+") AND

pull_request.changes.any match "ml-agents/**" OR
pull_request.changes.any match "ml-agents/tests/yamato/**" OR
{% endif %}
{% endfor %}

15
.yamato/test_versions.metafile


# List of editor versions for standalone-build-test and its dependencies.
# csharp_backcompat_version is used in training-int-tests to determine the
# older package version to run the backwards compat tests against.
# We always run training-int-tests for all versions of the editor
# For each "other" test, we only run it against a single version of the
# editor to reduce the number of yamato jobs
csharp_backcompat_version: 1.0.0
extra_test: llapi
csharp_backcompat_version: 1.0.0
- version: 2020.1
csharp_backcompat_version: 1.0.0
extra_test: gym
# 2020.2 moved the AssetImporters namespace
# but we didn't handle this until 1.2.0
csharp_backcompat_version: 1.2.0
extra_test: sensor

2
Project/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs


scenes,
outputPath,
buildTarget,
BuildOptions.None
BuildOptions.Development
);
var isOk = buildResult.summary.result == BuildResult.Succeeded;
var error = "";

29
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs


const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory";
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension";
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes";
const string k_CommandLineQuitAfterSeconds = "--mlagents-quit-after-seconds";
const string k_CommandLineQuitOnLoadFailure = "--mlagents-quit-on-load-failure";
// The attached Agent

// Max episodes to run. Only used if > 0
// Will default to 1 if override models are specified, otherwise 0.
int m_MaxEpisodes;
// Deadline - exit if the time exceeds this
DateTime m_Deadline = DateTime.MaxValue;
int m_NumSteps;
int m_PreviousNumSteps;

void GetAssetPathFromCommandLine()
{
var maxEpisodes = 0;
var timeoutSeconds = 0;
string[] commandLineArgsOverride = null;
if (!string.IsNullOrEmpty(debugCommandLineOverride) && Application.isEditor)
{

{
Int32.TryParse(args[i + 1], out maxEpisodes);
}
else if (args[i] == k_CommandLineQuitAfterSeconds && i < args.Length - 1)
{
Int32.TryParse(args[i + 1], out timeoutSeconds);
}
else if (args[i] == k_CommandLineQuitOnLoadFailure)
{
m_QuitOnLoadFailure = true;

m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1;
Debug.Log($"setting m_MaxEpisodes to {maxEpisodes}");
}
if (timeoutSeconds > 0)
{
m_Deadline = DateTime.Now + TimeSpan.FromSeconds(timeoutSeconds);
Debug.Log($"setting deadline to {timeoutSeconds} from now.");
}
}
void OnEnable()

EditorApplication.isPlaying = false;
#endif
}
else if (DateTime.Now >= m_Deadline)
{
Debug.Log(
$"Deadline exceeded. " +
$"{TotalCompletedEpisodes}/{m_MaxEpisodes} episodes and " +
$"{TotalNumSteps}/{m_MaxEpisodes * m_Agent.MaxStep} steps completed. Exiting.");
Application.Quit(0);
#if UNITY_EDITOR
EditorApplication.isPlaying = false;
#endif
}
m_NumSteps++;
}

3
com.unity.ml-agents/CHANGELOG.md


reduced the amount of memory allocated by approximately 25%. (#4887)
- Removed several memory allocations that happened during inference with discrete actions. (#4922)
- Properly catch permission errors when writing timer files. (#4921)
- Unexpected exceptions during training initialization and shutdown are now logged. If you see
"noisy" logs, please let us know! (#4930, #4935)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Fixed a bug that would cause an exception when `RunOptions` was deserialized via `pickle`. (#4842)

while waiting for a connection, and raises a better error message if it crashes. (#4880)
- Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
no longer overwritten. (#4880)
- The `load_weights` function was being called unnecessarily often in the Ghost Trainer leading to training slowdowns. (#4934)
## [1.7.2-preview] - 2020-12-22

49
com.unity.ml-agents/Runtime/Academy.cs


{
// We try to exchange the first message with Python. If this fails, it means
// no Python Process is ready to train the environment. In this case, the
//environment must use Inference.
// environment must use Inference.
bool initSuccessful = false;
var communicatorInitParams = new CommunicatorInitParameters
{
unityCommunicationVersion = k_ApiVersion,
unityPackageVersion = k_PackageVersion,
name = "AcademySingleton",
CSharpCapabilities = new UnityRLCapabilities()
};
var unityRlInitParameters = Communicator.Initialize(
new CommunicatorInitParameters
{
unityCommunicationVersion = k_ApiVersion,
unityPackageVersion = k_PackageVersion,
name = "AcademySingleton",
CSharpCapabilities = new UnityRLCapabilities()
});
UnityEngine.Random.InitState(unityRlInitParameters.seed);
// We might have inference-only Agents, so set the seed for them too.
m_InferenceSeed = unityRlInitParameters.seed;
TrainerCapabilities = unityRlInitParameters.TrainerCapabilities;
TrainerCapabilities.WarnOnPythonMissingBaseRLCapabilities();
initSuccessful = Communicator.Initialize(
communicatorInitParams,
out var unityRlInitParameters
);
if (initSuccessful)
{
UnityEngine.Random.InitState(unityRlInitParameters.seed);
// We might have inference-only Agents, so set the seed for them too.
m_InferenceSeed = unityRlInitParameters.seed;
TrainerCapabilities = unityRlInitParameters.TrainerCapabilities;
TrainerCapabilities.WarnOnPythonMissingBaseRLCapabilities();
}
else
{
Debug.Log($"Couldn't connect to trainer on port {port} using API version {k_ApiVersion}. Will perform inference instead.");
Communicator = null;
}
catch
catch (Exception ex)
Debug.Log($"" +
$"Couldn't connect to trainer on port {port} using API version {k_ApiVersion}. " +
"Will perform inference instead."
);
Debug.Log($"Unexpected exception when trying to initialize communication: {ex}\nWill perform inference instead.");
if (Communicator != null)
{
Communicator.QuitCommandReceived += OnQuitCommandReceived;

5
com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs


/// Sends the academy parameters through the Communicator.
/// Is used by the academy to send the AcademyParameters to the communicator.
/// </summary>
/// <returns>The External Initialization Parameters received.</returns>
/// <returns>Whether the connection was successful.</returns>
UnityRLInitParameters Initialize(CommunicatorInitParameters initParameters);
/// <param name="initParametersOut">The External Initialization Parameters received</param>
bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut);
/// <summary>
/// Registers a new Brain to the Communicator.

147
com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs


internal static bool CheckCommunicationVersionsAreCompatible(
string unityCommunicationVersion,
string pythonApiVersion,
string pythonLibraryVersion)
string pythonApiVersion
)
{
var unityVersion = new Version(unityCommunicationVersion);
var pythonVersion = new Version(pythonApiVersion);

/// Sends the initialization parameters through the Communicator.
/// Is used by the academy to send initialization parameters to the communicator.
/// </summary>
/// <returns>The External Initialization Parameters received.</returns>
/// <returns>Whether the connection was successful.</returns>
public UnityRLInitParameters Initialize(CommunicatorInitParameters initParameters)
/// <param name="initParametersOut">The External Initialization Parameters received.</param>
public bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut)
{
var academyParameters = new UnityRLInitializationOutputProto
{

{
RlInitializationOutput = academyParameters
},
out input);
var pythonPackageVersion = initializationInput.RlInitializationInput.PackageVersion;
var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion;
var unityCommunicationVersion = initParameters.unityCommunicationVersion;
TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);
out input
);
}
catch (Exception ex)
{
if (ex is RpcException rpcException)
{
var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(unityCommunicationVersion,
pythonCommunicationVersion,
pythonPackageVersion);
// Initialization succeeded part-way. The most likely cause is a mismatch between the communicator
// API strings, so log an explicit warning if that's the case.
if (initializationInput != null && input == null)
{
if (!communicationIsCompatible)
switch (rpcException.Status.StatusCode)
Debug.LogWarningFormat(
"Communication protocol between python ({0}) and Unity ({1}) have different " +
"versions which make them incompatible. Python library version: {2}.",
pythonCommunicationVersion, initParameters.unityCommunicationVersion,
pythonPackageVersion
);
case StatusCode.Unavailable:
// This is the common case where there's no trainer to connect to.
break;
case StatusCode.DeadlineExceeded:
// We don't currently set a deadline for connection, but likely will in the future.
break;
default:
Debug.Log($"Unexpected gRPC exception when trying to initialize communication: {rpcException}");
break;
else
{
Debug.LogWarningFormat(
"Unknown communication error between Python. Python communication protocol: {0}, " +
"Python library version: {1}.",
pythonCommunicationVersion,
pythonPackageVersion
);
}
throw new UnityAgentsException("ICommunicator.Initialize() failed.");
else
{
Debug.Log($"Unexpected exception when trying to initialize communication: {ex}");
}
initParametersOut = new UnityRLInitParameters();
return false;
catch
var pythonPackageVersion = initializationInput.RlInitializationInput.PackageVersion;
var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion;
TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);
var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(
initParameters.unityCommunicationVersion,
pythonCommunicationVersion
);
// Initialization succeeded part-way. The most likely cause is a mismatch between the communicator
// API strings, so log an explicit warning if that's the case.
if (initializationInput != null && input == null)
var exceptionMessage = "The Communicator was unable to connect. Please make sure the External " +
"process is ready to accept communication with Unity.";
// Check for common error condition and add details to the exception message.
var httpProxy = Environment.GetEnvironmentVariable("HTTP_PROXY");
var httpsProxy = Environment.GetEnvironmentVariable("HTTPS_PROXY");
if (httpProxy != null || httpsProxy != null)
if (!communicationIsCompatible)
{
Debug.LogWarningFormat(
"Communication protocol between python ({0}) and Unity ({1}) have different " +
"versions which make them incompatible. Python library version: {2}.",
pythonCommunicationVersion, initParameters.unityCommunicationVersion,
pythonPackageVersion
);
}
else
exceptionMessage += " Try removing HTTP_PROXY and HTTPS_PROXY from the" +
"environment variables and try again.";
Debug.LogWarningFormat(
"Unknown communication error between Python. Python communication protocol: {0}, " +
"Python library version: {1}.",
pythonCommunicationVersion,
pythonPackageVersion
);
throw new UnityAgentsException(exceptionMessage);
initParametersOut = new UnityRLInitParameters();
return false;
return initializationInput.RlInitializationInput.ToUnityRLInitParameters();
initParametersOut = initializationInput.RlInitializationInput.ToUnityRLInitParameters();
return true;
}
/// <summary>

SendCommandEvent(rlInput.Command);
}
UnityInputProto Initialize(UnityOutputProto unityOutput,
out UnityInputProto unityInput)
UnityInputProto Initialize(UnityOutputProto unityOutput, out UnityInputProto unityInput)
{
#if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
m_IsOpen = true;

}
return result.UnityInput;
#else
throw new UnityAgentsException(
"You cannot perform training on this platform.");
throw new UnityAgentsException("You cannot perform training on this platform.");
#endif
}

{
return null;
}
try
{
var message = m_Client.Exchange(WrapMessage(unityOutput, 200));

QuitCommandReceived?.Invoke();
return message.UnityInput;
}
catch
catch (Exception ex)
if (ex is RpcException rpcException)
{
// Log more verbose errors if they're something the user can possibly do something about.
switch (rpcException.Status.StatusCode)
{
case StatusCode.Unavailable:
// This can happen when python disconnects. Ignore it to avoid noisy logs.
break;
case StatusCode.ResourceExhausted:
// This happens is the message body is too large. There's no way to
// gracefully handle this, but at least we can show the message and the
// user can try to reduce the number of agents or observation sizes.
Debug.LogError($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
break;
default:
// Other unknown errors. Log at INFO level.
Debug.Log($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
break;
}
}
else
{
// Fall-through for other error types
Debug.LogError($"Communication Exception: {ex.Message}. Disconnecting from trainer.");
}
m_IsOpen = false;
QuitCommandReceived?.Invoke();
return null;

14
com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs


using System.Collections.Generic;
using System;
using UnityEngine;
namespace Unity.MLAgents.SideChannels
{

internal void ProcessMessage(byte[] msg)
{
using (var incomingMsg = new IncomingMessage(msg))
try
{
using (var incomingMsg = new IncomingMessage(msg))
{
OnMessageReceived(incomingMsg);
}
}
catch (Exception ex)
OnMessageReceived(incomingMsg);
// Catch all errors in the sidechannel processing, so that a single
// bad SideChannel implementation doesn't take everything down with it.
Debug.LogError($"Error processing SideChannel message: {ex}.\nThe message will be skipped.");
}
}

19
com.unity.ml-agents/Tests/Editor/Communicator/RpcCommunicatorTests.cs


{
var unityVerStr = "1.0.0";
var pythonVerStr = "1.0.0";
var pythonPackageVerStr = "0.16.0";
pythonVerStr,
pythonPackageVerStr));
pythonVerStr));
pythonVerStr,
pythonPackageVerStr));
pythonVerStr));
pythonVerStr,
pythonPackageVerStr));
pythonVerStr));
pythonVerStr,
pythonPackageVerStr));
pythonVerStr));
pythonVerStr,
pythonPackageVerStr));
pythonVerStr));
pythonVerStr,
pythonPackageVerStr));
pythonVerStr));
}
}

2
docs/Training-Configuration-File.md


- LSTM does not work well with continuous actions. Please use
discrete actions for better results.
- Since the memories must be sent back and forth between Python and Unity, using
too large `memory_size` will slow down training.
- Adding a recurrent layer increases the complexity of the neural network, it is
recommended to decrease `num_layers` when using recurrent.
- It is required that `memory_size` be divisible by 2.

2
ml-agents/mlagents/trainers/agent_processor.py


if stored_decision_step is not None and stored_take_action_outputs is not None:
obs = stored_decision_step.obs
if self.policy.use_recurrent:
memory = self.policy.retrieve_memories([global_id])[0, :]
memory = self.policy.retrieve_previous_memories([global_id])[0, :]
else:
memory = None
done = terminated # Since this is an ongoing step

40
ml-agents/mlagents/trainers/ghost/trainer.py


next_learning_team = self.controller.get_learning_team
# CASE 1: Current learning team is managed by this GhostTrainer.
# If the learning team changes, the following loop over queues will push the
# new policy into the policy queue for the new learning agent if
# that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
# CASE 2: Current learning team is managed by a different GhostTrainer.
# If the learning team changes to a team managed by this GhostTrainer, this loop
# will push the current_snapshot into the correct queue. Otherwise,
# it will continue skipping and swap_snapshot will continue to handle
# pushing fixed snapshots
# Case 3: No team change. The if statement just continues to push the policy
# Case 1: No team change. The if statement just continues to push the policy
# into the correct queue (or not if not learning team).
for brain_name in self._internal_policy_queues:
internal_policy_queue = self._internal_policy_queues[brain_name]

except AgentManagerQueue.Empty:
pass
if next_learning_team in self._team_to_name_to_policy_queue:
continue
if (
self._learning_team == next_learning_team
and next_learning_team in self._team_to_name_to_policy_queue
):
name_to_policy_queue = self._team_to_name_to_policy_queue[
next_learning_team
]

policy = self.get_policy(behavior_id)
policy.load_weights(self.current_policy_snapshot[brain_name])
name_to_policy_queue[brain_name].put(policy)
# CASE 2: Current learning team is managed by this GhostTrainer.
# If the learning team changes, the following loop over queues will push the
# new policy into the policy queue for the new learning agent if
# that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
# CASE 3: Current learning team is managed by a different GhostTrainer.
# If the learning team changes to a team managed by this GhostTrainer, this loop
# will push the current_snapshot into the correct queue. Otherwise,
# it will continue skipping and swap_snapshot will continue to handle
# pushing fixed snapshots
if (
self._learning_team != next_learning_team
and next_learning_team in self._team_to_name_to_policy_queue
):
name_to_policy_queue = self._team_to_name_to_policy_queue[
next_learning_team
]
for brain_name in name_to_policy_queue:
behavior_id = create_name_behavior_id(brain_name, next_learning_team)
policy = self.get_policy(behavior_id)
policy.load_weights(self.current_policy_snapshot[brain_name])
name_to_policy_queue[brain_name].put(policy)
# Note save and swap should be on different step counters.
# We don't want to save unless the policy is learning.

10
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from mlagents.torch_utils import torch
import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
memory = torch.zeros([1, 1, self.policy.m_size])
memory = (
ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][0])
.unsqueeze(0)
.unsqueeze(0)
if self.policy.use_recurrent
else None
)
next_obs = [obs.unsqueeze(0) for obs in next_obs]

15
ml-agents/mlagents/trainers/policy/policy.py


self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed
self.previous_action_dict: Dict[str, np.ndarray] = {}
self.previous_memory_dict: Dict[str, np.ndarray] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize
self.use_recurrent = self.network_settings.memory is not None

if memory_matrix is None:
return
# Pass old memories into previous_memory_dict
for agent_id in agent_ids:
if agent_id in self.memory_dict:
self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]
for index, agent_id in enumerate(agent_ids):
self.memory_dict[agent_id] = memory_matrix[index, :]

memory_matrix[index, :] = self.memory_dict[agent_id]
return memory_matrix
def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_memory_dict:
memory_matrix[index, :] = self.previous_memory_dict[agent_id]
return memory_matrix
if agent_id in self.previous_memory_dict:
self.previous_memory_dict.pop(agent_id)
def make_empty_previous_action(self, num_agents: int) -> np.ndarray:
"""

9
ml-agents/mlagents/trainers/tests/torch/test_ghost.py


VECTOR_ACTION_SPACE = 1
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 513
BUFFER_INIT_SAMPLES = 10241
NUM_AGENTS = 12

assert policy_queue0.empty() and not policy_queue1.empty()
# clear
policy_queue1.get_nowait()
mock_specs = mb.setup_test_behavior_specs(
False,
False,
vector_action_space=VECTOR_ACTION_SPACE,
vector_obs_space=VECTOR_OBS_SPACE,
)
buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
# Mock out reward signal eval

29
ml-agents/tests/yamato/training_int_tests.py


import argparse
import json
import os
import shutil
import sys

log_output_path = f"{get_base_output_path()}/inference.{model_extension}.txt"
# 10 minutes for inference is more than enough
process_timeout = 10 * 60
# Try to gracefully exit a few seconds before that.
model_override_timeout = process_timeout - 15
exe_path = exes[0]
args = [
exe_path,

"1",
"--mlagents-override-model-extension",
model_extension,
"--mlagents-quit-after-seconds",
str(model_override_timeout),
timeout = 15 * 60 # 15 minutes for inference is more than enough
res = subprocess.run(args, timeout=timeout)
res = subprocess.run(args, timeout=process_timeout)
end_time = time.time()
if res.returncode != 0:
print("Error running inference!")

else:
print(f"Inference succeeded! Took {end_time - start_time} seconds")
print(f"Inference finished! Took {end_time - start_time} seconds")
# Check the artifacts directory for the timers, so we can get the gauges
timer_file = f"{exe_path}_Data/ML-Agents/Timers/3DBall_timers.json"
with open(timer_file) as f:
timer_data = json.load(f)
gauges = timer_data.get("gauges", {})
rewards = gauges.get("Override_3DBall.CumulativeReward", {})
max_reward = rewards.get("max")
if max_reward is None:
print(
"Unable to find rewards in timer file. This usually indicates a problem with Barracuda or inference."
)
return False
# We could check that the rewards are over a threshold, but since we train for so short a time,
# the values could be highly variable. So don't do it for now.
return True

正在加载...
取消
保存