GitHub
4 年前
当前提交
2fb87e4f
共有 94 个文件被更改,包括 2966 次插入 和 328 次删除
-
3Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
-
2Project/ProjectSettings/UnityConnectSettings.asset
-
26README.md
-
8com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
-
9com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
-
9com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
-
29com.unity.ml-agents/CHANGELOG.md
-
7com.unity.ml-agents/Runtime/Academy.cs
-
13com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
-
7com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
-
68com.unity.ml-agents/Runtime/Analytics/Events.cs
-
14com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
-
52com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
-
5com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
-
5com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
-
39com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
-
21com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
-
36com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs
-
46com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
-
7com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
-
7com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
-
14com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
-
8com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
-
15com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
-
24com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
-
10com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
-
9com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
-
9com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
-
7com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
-
11com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
-
24com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
-
26com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs
-
19com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
-
32com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
-
9com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
-
12com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
-
15com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
-
4com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
-
5docs/Migrating.md
-
10docs/Training-ML-Agents.md
-
17ml-agents-envs/mlagents_envs/communicator.py
-
11ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
-
6ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
-
8ml-agents-envs/mlagents_envs/env_utils.py
-
55ml-agents-envs/mlagents_envs/environment.py
-
12ml-agents-envs/mlagents_envs/mock_communicator.py
-
47ml-agents-envs/mlagents_envs/rpc_communicator.py
-
2ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
-
2ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
-
54ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py
-
1ml-agents/mlagents/torch_utils/__init__.py
-
37ml-agents/mlagents/torch_utils/torch.py
-
15ml-agents/mlagents/trainers/cli_utils.py
-
12ml-agents/mlagents/trainers/env_manager.py
-
15ml-agents/mlagents/trainers/learn.py
-
9ml-agents/mlagents/trainers/settings.py
-
81ml-agents/mlagents/trainers/subprocess_env_manager.py
-
2ml-agents/mlagents/trainers/tests/simple_test_envs.py
-
66ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
-
8ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
-
4ml-agents/mlagents/trainers/tests/torch/test_action_model.py
-
10ml-agents/mlagents/trainers/tests/torch/test_distributions.py
-
4ml-agents/mlagents/trainers/tests/torch/test_encoders.py
-
6ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
-
8ml-agents/mlagents/trainers/tests/torch/test_networks.py
-
3ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
-
4ml-agents/mlagents/trainers/torch/encoders.py
-
3ml-agents/mlagents/trainers/trainer_controller.py
-
3protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
-
150utils/make_readme_table.py
-
23.github/workflows/lock.yml
-
24.yamato/pytest-gpu.yml
-
40com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
-
3com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta
-
246com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
-
3com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta
-
850com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs
-
11com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta
-
39com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs
-
3com.unity.ml-agents/Runtime/Sensors/IBuiltInSensor.cs.meta
-
50com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs
-
3com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta
-
42com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs
-
3com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta
-
65com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs
-
3com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta
-
243ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py
-
97ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi
-
41ml-agents/mlagents/trainers/tests/test_torch_utils.py
-
99ml-agents/mlagents/training_analytics_side_channel.py
-
31protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto
-
4pytest.ini
-
38.github/lock.yml
|
|||
from mlagents.torch_utils.torch import torch as torch # noqa |
|||
from mlagents.torch_utils.torch import nn # noqa |
|||
from mlagents.torch_utils.torch import set_torch_config # noqa |
|||
from mlagents.torch_utils.torch import default_device # noqa |
|
|||
name: 'Lock Threads' |
|||
|
|||
on: |
|||
schedule: |
|||
- cron: '0 0/4 * * *' |
|||
|
|||
jobs: |
|||
lock: |
|||
runs-on: ubuntu-latest |
|||
steps: |
|||
- uses: dessant/lock-threads@v2 |
|||
with: |
|||
github-token: ${{ github.token }} |
|||
issue-lock-inactive-days: '30' |
|||
issue-exclude-created-before: '' |
|||
issue-exclude-labels: '' |
|||
issue-lock-labels: '' |
|||
issue-lock-comment: > |
|||
This thread has been automatically locked since there has not been |
|||
any recent activity after it was closed. Please open a new issue for |
|||
related bugs. |
|||
issue-lock-reason: 'resolved' |
|||
process-only: 'issues' |
|
|||
pytest_gpu: |
|||
name: Pytest GPU |
|||
agent: |
|||
type: Unity::VM::GPU |
|||
image: package-ci/ubuntu:stable |
|||
flavor: b1.large |
|||
commands: |
|||
- | |
|||
sudo apt-get update && sudo apt-get install -y python3-venv |
|||
python3 -m venv venv && source venv/bin/activate |
|||
python3 -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple |
|||
python3 -u -m ml-agents.tests.yamato.setup_venv |
|||
python3 -m pip install --progress-bar=off -r test_requirements.txt --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple |
|||
python3 -m pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple |
|||
python3 -m pytest -m "not check_environment_trains" --junitxml=junit/test-results.xml -p no:warnings |
|||
triggers: |
|||
cancel_old_ci: true |
|||
recurring: |
|||
- branch: master |
|||
frequency: daily |
|||
artifacts: |
|||
logs: |
|||
paths: |
|||
- "artifacts/standalone_build.txt" |
|
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Analytics |
|||
{ |
|||
internal static class AnalyticsUtils |
|||
{ |
|||
/// <summary>
|
|||
/// Hash a string to remove PII or secret info before sending to analytics
|
|||
/// </summary>
|
|||
/// <param name="s"></param>
|
|||
/// <returns>A string containing the Hash128 of the input string.</returns>
|
|||
public static string Hash(string s) |
|||
{ |
|||
var behaviorNameHash = Hash128.Compute(s); |
|||
return behaviorNameHash.ToString(); |
|||
} |
|||
|
|||
internal static bool s_SendEditorAnalytics = true; |
|||
|
|||
/// <summary>
|
|||
/// Helper class to temporarily disable sending analytics from unit tests.
|
|||
/// </summary>
|
|||
internal class DisableAnalyticsSending : IDisposable |
|||
{ |
|||
private bool m_PreviousSendEditorAnalytics; |
|||
|
|||
public DisableAnalyticsSending() |
|||
{ |
|||
m_PreviousSendEditorAnalytics = s_SendEditorAnalytics; |
|||
s_SendEditorAnalytics = false; |
|||
} |
|||
|
|||
public void Dispose() |
|||
{ |
|||
s_SendEditorAnalytics = m_PreviousSendEditorAnalytics; |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: af1ef3e70f1242938d7b39284b1a892b |
|||
timeCreated: 1610575760 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Sensors; |
|||
using UnityEngine; |
|||
using UnityEngine.Analytics; |
|||
|
|||
#if UNITY_EDITOR
|
|||
using UnityEditor; |
|||
using UnityEditor.Analytics; |
|||
#endif
|
|||
|
|||
namespace Unity.MLAgents.Analytics |
|||
{ |
|||
internal class TrainingAnalytics |
|||
{ |
|||
const string k_VendorKey = "unity.ml-agents"; |
|||
const string k_TrainingEnvironmentInitializedEventName = "ml_agents_training_environment_initialized"; |
|||
const string k_TrainingBehaviorInitializedEventName = "ml_agents_training_behavior_initialized"; |
|||
const string k_RemotePolicyInitializedEventName = "ml_agents_remote_policy_initialized"; |
|||
|
|||
private static readonly string[] s_EventNames = |
|||
{ |
|||
k_TrainingEnvironmentInitializedEventName, |
|||
k_TrainingBehaviorInitializedEventName, |
|||
k_RemotePolicyInitializedEventName |
|||
}; |
|||
|
|||
/// <summary>
|
|||
/// Whether or not we've registered this particular event yet
|
|||
/// </summary>
|
|||
static bool s_EventsRegistered = false; |
|||
|
|||
/// <summary>
|
|||
/// Hourly limit for this event name
|
|||
/// </summary>
|
|||
const int k_MaxEventsPerHour = 1000; |
|||
|
|||
/// <summary>
|
|||
/// Maximum number of items in this event.
|
|||
/// </summary>
|
|||
const int k_MaxNumberOfElements = 1000; |
|||
|
|||
private static bool s_SentEnvironmentInitialized; |
|||
/// <summary>
|
|||
/// Behaviors that we've already sent events for.
|
|||
/// </summary>
|
|||
private static HashSet<string> s_SentRemotePolicyInitialized; |
|||
private static HashSet<string> s_SentTrainingBehaviorInitialized; |
|||
|
|||
private static Guid s_TrainingSessionGuid; |
|||
|
|||
// These are set when the RpcCommunicator connects
|
|||
private static string s_TrainerPackageVersion = ""; |
|||
private static string s_TrainerCommunicationVersion = ""; |
|||
|
|||
static bool EnableAnalytics() |
|||
{ |
|||
if (s_EventsRegistered) |
|||
{ |
|||
return true; |
|||
} |
|||
|
|||
foreach (var eventName in s_EventNames) |
|||
{ |
|||
#if UNITY_EDITOR
|
|||
AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(eventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey); |
|||
#else
|
|||
AnalyticsResult result = AnalyticsResult.UnsupportedPlatform; |
|||
#endif
|
|||
if (result != AnalyticsResult.Ok) |
|||
{ |
|||
return false; |
|||
} |
|||
} |
|||
s_EventsRegistered = true; |
|||
|
|||
if (s_SentRemotePolicyInitialized == null) |
|||
{ |
|||
s_SentRemotePolicyInitialized = new HashSet<string>(); |
|||
s_SentTrainingBehaviorInitialized = new HashSet<string>(); |
|||
s_TrainingSessionGuid = Guid.NewGuid(); |
|||
} |
|||
|
|||
return s_EventsRegistered; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Cache information about the trainer when it becomes available in the RpcCommunicator.
|
|||
/// </summary>
|
|||
/// <param name="communicationVersion"></param>
|
|||
/// <param name="packageVersion"></param>
|
|||
public static void SetTrainerInformation(string packageVersion, string communicationVersion) |
|||
{ |
|||
s_TrainerPackageVersion = packageVersion; |
|||
s_TrainerCommunicationVersion = communicationVersion; |
|||
} |
|||
|
|||
public static bool IsAnalyticsEnabled() |
|||
{ |
|||
#if UNITY_EDITOR
|
|||
return EditorAnalytics.enabled; |
|||
#else
|
|||
return false; |
|||
#endif
|
|||
} |
|||
|
|||
public static void TrainingEnvironmentInitialized(TrainingEnvironmentInitializedEvent tbiEvent) |
|||
{ |
|||
if (!IsAnalyticsEnabled()) |
|||
return; |
|||
|
|||
if (!EnableAnalytics()) |
|||
return; |
|||
|
|||
if (s_SentEnvironmentInitialized) |
|||
{ |
|||
// We already sent an TrainingEnvironmentInitializedEvent. Exit so we don't resend.
|
|||
return; |
|||
} |
|||
|
|||
s_SentEnvironmentInitialized = true; |
|||
tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString(); |
|||
|
|||
// Note - to debug, use JsonUtility.ToJson on the event.
|
|||
// Debug.Log(
|
|||
// $"Would send event {k_TrainingEnvironmentInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}"
|
|||
// );
|
|||
#if UNITY_EDITOR
|
|||
if (AnalyticsUtils.s_SendEditorAnalytics) |
|||
{ |
|||
EditorAnalytics.SendEventWithLimit(k_TrainingEnvironmentInitializedEventName, tbiEvent); |
|||
} |
|||
#else
|
|||
return; |
|||
#endif
|
|||
} |
|||
|
|||
public static void RemotePolicyInitialized( |
|||
string fullyQualifiedBehaviorName, |
|||
IList<ISensor> sensors, |
|||
ActionSpec actionSpec |
|||
) |
|||
{ |
|||
if (!IsAnalyticsEnabled()) |
|||
return; |
|||
|
|||
if (!EnableAnalytics()) |
|||
return; |
|||
|
|||
// Extract base behavior name (no team ID)
|
|||
var behaviorName = ParseBehaviorName(fullyQualifiedBehaviorName); |
|||
var added = s_SentRemotePolicyInitialized.Add(behaviorName); |
|||
|
|||
if (!added) |
|||
{ |
|||
// We previously added this model. Exit so we don't resend.
|
|||
return; |
|||
} |
|||
|
|||
var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec); |
|||
// Note - to debug, use JsonUtility.ToJson on the event.
|
|||
// Debug.Log(
|
|||
// $"Would send event {k_RemotePolicyInitializedEventName} with body {JsonUtility.ToJson(data, true)}"
|
|||
// );
|
|||
#if UNITY_EDITOR
|
|||
if (AnalyticsUtils.s_SendEditorAnalytics) |
|||
{ |
|||
EditorAnalytics.SendEventWithLimit(k_RemotePolicyInitializedEventName, data); |
|||
} |
|||
#else
|
|||
return; |
|||
#endif
|
|||
} |
|||
|
|||
internal static string ParseBehaviorName(string fullyQualifiedBehaviorName) |
|||
{ |
|||
var lastQuestionIndex = fullyQualifiedBehaviorName.LastIndexOf("?"); |
|||
if (lastQuestionIndex < 0) |
|||
{ |
|||
// Nothing to remove
|
|||
return fullyQualifiedBehaviorName; |
|||
} |
|||
|
|||
return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex); |
|||
} |
|||
|
|||
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent tbiEvent) |
|||
{ |
|||
if (!IsAnalyticsEnabled()) |
|||
return; |
|||
|
|||
if (!EnableAnalytics()) |
|||
return; |
|||
|
|||
var behaviorName = tbiEvent.BehaviorName; |
|||
var added = s_SentTrainingBehaviorInitialized.Add(behaviorName); |
|||
|
|||
if (!added) |
|||
{ |
|||
// We previously added this model. Exit so we don't resend.
|
|||
return; |
|||
} |
|||
|
|||
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
|
|||
tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString(); |
|||
tbiEvent.BehaviorName = AnalyticsUtils.Hash(tbiEvent.BehaviorName); |
|||
|
|||
// Note - to debug, use JsonUtility.ToJson on the event.
|
|||
// Debug.Log(
|
|||
// $"Would send event {k_TrainingBehaviorInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}"
|
|||
// );
|
|||
#if UNITY_EDITOR
|
|||
if (AnalyticsUtils.s_SendEditorAnalytics) |
|||
{ |
|||
EditorAnalytics.SendEventWithLimit(k_TrainingBehaviorInitializedEventName, tbiEvent); |
|||
} |
|||
#else
|
|||
return; |
|||
#endif
|
|||
} |
|||
|
|||
static RemotePolicyInitializedEvent GetEventForRemotePolicy( |
|||
string behaviorName, |
|||
IList<ISensor> sensors, |
|||
ActionSpec actionSpec) |
|||
{ |
|||
var remotePolicyEvent = new RemotePolicyInitializedEvent(); |
|||
|
|||
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
|
|||
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName); |
|||
|
|||
remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString(); |
|||
remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec); |
|||
remotePolicyEvent.ObservationSpecs = new List<EventObservationSpec>(sensors.Count); |
|||
foreach (var sensor in sensors) |
|||
{ |
|||
remotePolicyEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor)); |
|||
} |
|||
|
|||
remotePolicyEvent.MLAgentsEnvsVersion = s_TrainerPackageVersion; |
|||
remotePolicyEvent.TrainerCommunicationVersion = s_TrainerCommunicationVersion; |
|||
return remotePolicyEvent; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 5ad0bc6b45614bb7929d25dd59d5ac38 |
|||
timeCreated: 1608168600 |
|
|||
// <auto-generated>
|
|||
// Generated by the protocol buffer compiler. DO NOT EDIT!
|
|||
// source: mlagents_envs/communicator_objects/training_analytics.proto
|
|||
// </auto-generated>
|
|||
#pragma warning disable 1591, 0612, 3021
|
|||
#region Designer generated code
|
|||
|
|||
using pb = global::Google.Protobuf; |
|||
using pbc = global::Google.Protobuf.Collections; |
|||
using pbr = global::Google.Protobuf.Reflection; |
|||
using scg = global::System.Collections.Generic; |
|||
namespace Unity.MLAgents.CommunicatorObjects { |
|||
|
|||
/// <summary>Holder for reflection information generated from mlagents_envs/communicator_objects/training_analytics.proto</summary>
|
|||
internal static partial class TrainingAnalyticsReflection { |
|||
|
|||
#region Descriptor
|
|||
/// <summary>File descriptor for mlagents_envs/communicator_objects/training_analytics.proto</summary>
|
|||
public static pbr::FileDescriptor Descriptor { |
|||
get { return descriptor; } |
|||
} |
|||
private static pbr::FileDescriptor descriptor; |
|||
|
|||
static TrainingAnalyticsReflection() { |
|||
byte[] descriptorData = global::System.Convert.FromBase64String( |
|||
string.Concat( |
|||
"CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n", |
|||
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy", |
|||
"YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz", |
|||
"aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w", |
|||
"eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK", |
|||
"EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK", |
|||
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu", |
|||
"Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU", |
|||
"Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi", |
|||
"bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy", |
|||
"aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h", |
|||
"YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo", |
|||
"CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl", |
|||
"chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l", |
|||
"dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY", |
|||
"DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1", |
|||
"bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0", |
|||
"b3JPYmplY3RzYgZwcm90bzM=")); |
|||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, |
|||
new pbr::FileDescriptor[] { }, |
|||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { |
|||
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null), |
|||
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null) |
|||
})); |
|||
} |
|||
#endregion
|
|||
|
|||
} |
|||
#region Messages
|
|||
internal sealed partial class TrainingEnvironmentInitialized : pb::IMessage<TrainingEnvironmentInitialized> { |
|||
private static readonly pb::MessageParser<TrainingEnvironmentInitialized> _parser = new pb::MessageParser<TrainingEnvironmentInitialized>(() => new TrainingEnvironmentInitialized()); |
|||
private pb::UnknownFieldSet _unknownFields; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public static pb::MessageParser<TrainingEnvironmentInitialized> Parser { get { return _parser; } } |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public static pbr::MessageDescriptor Descriptor { |
|||
get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[0]; } |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
pbr::MessageDescriptor pb::IMessage.Descriptor { |
|||
get { return Descriptor; } |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public TrainingEnvironmentInitialized() { |
|||
OnConstruction(); |
|||
} |
|||
|
|||
partial void OnConstruction(); |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : this() { |
|||
mlagentsVersion_ = other.mlagentsVersion_; |
|||
mlagentsEnvsVersion_ = other.mlagentsEnvsVersion_; |
|||
pythonVersion_ = other.pythonVersion_; |
|||
torchVersion_ = other.torchVersion_; |
|||
torchDeviceType_ = other.torchDeviceType_; |
|||
numEnvs_ = other.numEnvs_; |
|||
numEnvironmentParameters_ = other.numEnvironmentParameters_; |
|||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public TrainingEnvironmentInitialized Clone() { |
|||
return new TrainingEnvironmentInitialized(this); |
|||
} |
|||
|
|||
/// <summary>Field number for the "mlagents_version" field.</summary>
|
|||
public const int MlagentsVersionFieldNumber = 1; |
|||
private string mlagentsVersion_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string MlagentsVersion { |
|||
get { return mlagentsVersion_; } |
|||
set { |
|||
mlagentsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "mlagents_envs_version" field.</summary>
|
|||
public const int MlagentsEnvsVersionFieldNumber = 2; |
|||
private string mlagentsEnvsVersion_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string MlagentsEnvsVersion { |
|||
get { return mlagentsEnvsVersion_; } |
|||
set { |
|||
mlagentsEnvsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "python_version" field.</summary>
|
|||
public const int PythonVersionFieldNumber = 3; |
|||
private string pythonVersion_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string PythonVersion { |
|||
get { return pythonVersion_; } |
|||
set { |
|||
pythonVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "torch_version" field.</summary>
|
|||
public const int TorchVersionFieldNumber = 4; |
|||
private string torchVersion_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string TorchVersion { |
|||
get { return torchVersion_; } |
|||
set { |
|||
torchVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "torch_device_type" field.</summary>
|
|||
public const int TorchDeviceTypeFieldNumber = 5; |
|||
private string torchDeviceType_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string TorchDeviceType { |
|||
get { return torchDeviceType_; } |
|||
set { |
|||
torchDeviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "num_envs" field.</summary>
|
|||
public const int NumEnvsFieldNumber = 6; |
|||
private int numEnvs_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public int NumEnvs { |
|||
get { return numEnvs_; } |
|||
set { |
|||
numEnvs_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "num_environment_parameters" field.</summary>
|
|||
public const int NumEnvironmentParametersFieldNumber = 7; |
|||
private int numEnvironmentParameters_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public int NumEnvironmentParameters { |
|||
get { return numEnvironmentParameters_; } |
|||
set { |
|||
numEnvironmentParameters_ = value; |
|||
} |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public override bool Equals(object other) { |
|||
return Equals(other as TrainingEnvironmentInitialized); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool Equals(TrainingEnvironmentInitialized other) { |
|||
if (ReferenceEquals(other, null)) { |
|||
return false; |
|||
} |
|||
if (ReferenceEquals(other, this)) { |
|||
return true; |
|||
} |
|||
if (MlagentsVersion != other.MlagentsVersion) return false; |
|||
if (MlagentsEnvsVersion != other.MlagentsEnvsVersion) return false; |
|||
if (PythonVersion != other.PythonVersion) return false; |
|||
if (TorchVersion != other.TorchVersion) return false; |
|||
if (TorchDeviceType != other.TorchDeviceType) return false; |
|||
if (NumEnvs != other.NumEnvs) return false; |
|||
if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false; |
|||
return Equals(_unknownFields, other._unknownFields); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public override int GetHashCode() { |
|||
int hash = 1; |
|||
if (MlagentsVersion.Length != 0) hash ^= MlagentsVersion.GetHashCode(); |
|||
if (MlagentsEnvsVersion.Length != 0) hash ^= MlagentsEnvsVersion.GetHashCode(); |
|||
if (PythonVersion.Length != 0) hash ^= PythonVersion.GetHashCode(); |
|||
if (TorchVersion.Length != 0) hash ^= TorchVersion.GetHashCode(); |
|||
if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode(); |
|||
if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode(); |
|||
if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode(); |
|||
if (_unknownFields != null) { |
|||
hash ^= _unknownFields.GetHashCode(); |
|||
} |
|||
return hash; |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public override string ToString() { |
|||
return pb::JsonFormatter.ToDiagnosticString(this); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public void WriteTo(pb::CodedOutputStream output) { |
|||
if (MlagentsVersion.Length != 0) { |
|||
output.WriteRawTag(10); |
|||
output.WriteString(MlagentsVersion); |
|||
} |
|||
if (MlagentsEnvsVersion.Length != 0) { |
|||
output.WriteRawTag(18); |
|||
output.WriteString(MlagentsEnvsVersion); |
|||
} |
|||
if (PythonVersion.Length != 0) { |
|||
output.WriteRawTag(26); |
|||
output.WriteString(PythonVersion); |
|||
} |
|||
if (TorchVersion.Length != 0) { |
|||
output.WriteRawTag(34); |
|||
output.WriteString(TorchVersion); |
|||
} |
|||
if (TorchDeviceType.Length != 0) { |
|||
output.WriteRawTag(42); |
|||
output.WriteString(TorchDeviceType); |
|||
} |
|||
if (NumEnvs != 0) { |
|||
output.WriteRawTag(48); |
|||
output.WriteInt32(NumEnvs); |
|||
} |
|||
if (NumEnvironmentParameters != 0) { |
|||
output.WriteRawTag(56); |
|||
output.WriteInt32(NumEnvironmentParameters); |
|||
} |
|||
if (_unknownFields != null) { |
|||
_unknownFields.WriteTo(output); |
|||
} |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public int CalculateSize() { |
|||
int size = 0; |
|||
if (MlagentsVersion.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsVersion); |
|||
} |
|||
if (MlagentsEnvsVersion.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsEnvsVersion); |
|||
} |
|||
if (PythonVersion.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(PythonVersion); |
|||
} |
|||
if (TorchVersion.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchVersion); |
|||
} |
|||
if (TorchDeviceType.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchDeviceType); |
|||
} |
|||
if (NumEnvs != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvs); |
|||
} |
|||
if (NumEnvironmentParameters != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters); |
|||
} |
|||
if (_unknownFields != null) { |
|||
size += _unknownFields.CalculateSize(); |
|||
} |
|||
return size; |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public void MergeFrom(TrainingEnvironmentInitialized other) { |
|||
if (other == null) { |
|||
return; |
|||
} |
|||
if (other.MlagentsVersion.Length != 0) { |
|||
MlagentsVersion = other.MlagentsVersion; |
|||
} |
|||
if (other.MlagentsEnvsVersion.Length != 0) { |
|||
MlagentsEnvsVersion = other.MlagentsEnvsVersion; |
|||
} |
|||
if (other.PythonVersion.Length != 0) { |
|||
PythonVersion = other.PythonVersion; |
|||
} |
|||
if (other.TorchVersion.Length != 0) { |
|||
TorchVersion = other.TorchVersion; |
|||
} |
|||
if (other.TorchDeviceType.Length != 0) { |
|||
TorchDeviceType = other.TorchDeviceType; |
|||
} |
|||
if (other.NumEnvs != 0) { |
|||
NumEnvs = other.NumEnvs; |
|||
} |
|||
if (other.NumEnvironmentParameters != 0) { |
|||
NumEnvironmentParameters = other.NumEnvironmentParameters; |
|||
} |
|||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public void MergeFrom(pb::CodedInputStream input) { |
|||
uint tag; |
|||
while ((tag = input.ReadTag()) != 0) { |
|||
switch(tag) { |
|||
default: |
|||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); |
|||
break; |
|||
case 10: { |
|||
MlagentsVersion = input.ReadString(); |
|||
break; |
|||
} |
|||
case 18: { |
|||
MlagentsEnvsVersion = input.ReadString(); |
|||
break; |
|||
} |
|||
case 26: { |
|||
PythonVersion = input.ReadString(); |
|||
break; |
|||
} |
|||
case 34: { |
|||
TorchVersion = input.ReadString(); |
|||
break; |
|||
} |
|||
case 42: { |
|||
TorchDeviceType = input.ReadString(); |
|||
break; |
|||
} |
|||
case 48: { |
|||
NumEnvs = input.ReadInt32(); |
|||
break; |
|||
} |
|||
case 56: { |
|||
NumEnvironmentParameters = input.ReadInt32(); |
|||
break; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
} |
|||
|
|||
internal sealed partial class TrainingBehaviorInitialized : pb::IMessage<TrainingBehaviorInitialized> { |
|||
private static readonly pb::MessageParser<TrainingBehaviorInitialized> _parser = new pb::MessageParser<TrainingBehaviorInitialized>(() => new TrainingBehaviorInitialized()); |
|||
private pb::UnknownFieldSet _unknownFields; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public static pb::MessageParser<TrainingBehaviorInitialized> Parser { get { return _parser; } } |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public static pbr::MessageDescriptor Descriptor { |
|||
get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[1]; } |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
pbr::MessageDescriptor pb::IMessage.Descriptor { |
|||
get { return Descriptor; } |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public TrainingBehaviorInitialized() { |
|||
OnConstruction(); |
|||
} |
|||
|
|||
partial void OnConstruction(); |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() { |
|||
behaviorName_ = other.behaviorName_; |
|||
trainerType_ = other.trainerType_; |
|||
extrinsicRewardEnabled_ = other.extrinsicRewardEnabled_; |
|||
gailRewardEnabled_ = other.gailRewardEnabled_; |
|||
curiosityRewardEnabled_ = other.curiosityRewardEnabled_; |
|||
rndRewardEnabled_ = other.rndRewardEnabled_; |
|||
behavioralCloningEnabled_ = other.behavioralCloningEnabled_; |
|||
recurrentEnabled_ = other.recurrentEnabled_; |
|||
visualEncoder_ = other.visualEncoder_; |
|||
numNetworkLayers_ = other.numNetworkLayers_; |
|||
numNetworkHiddenUnits_ = other.numNetworkHiddenUnits_; |
|||
trainerThreaded_ = other.trainerThreaded_; |
|||
selfPlayEnabled_ = other.selfPlayEnabled_; |
|||
curriculumEnabled_ = other.curriculumEnabled_; |
|||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public TrainingBehaviorInitialized Clone() { |
|||
return new TrainingBehaviorInitialized(this); |
|||
} |
|||
|
|||
/// <summary>Field number for the "behavior_name" field.</summary>
|
|||
public const int BehaviorNameFieldNumber = 1; |
|||
private string behaviorName_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string BehaviorName { |
|||
get { return behaviorName_; } |
|||
set { |
|||
behaviorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "trainer_type" field.</summary>
|
|||
public const int TrainerTypeFieldNumber = 2; |
|||
private string trainerType_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string TrainerType { |
|||
get { return trainerType_; } |
|||
set { |
|||
trainerType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "extrinsic_reward_enabled" field.</summary>
|
|||
public const int ExtrinsicRewardEnabledFieldNumber = 3; |
|||
private bool extrinsicRewardEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool ExtrinsicRewardEnabled { |
|||
get { return extrinsicRewardEnabled_; } |
|||
set { |
|||
extrinsicRewardEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "gail_reward_enabled" field.</summary>
|
|||
public const int GailRewardEnabledFieldNumber = 4; |
|||
private bool gailRewardEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool GailRewardEnabled { |
|||
get { return gailRewardEnabled_; } |
|||
set { |
|||
gailRewardEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "curiosity_reward_enabled" field.</summary>
|
|||
public const int CuriosityRewardEnabledFieldNumber = 5; |
|||
private bool curiosityRewardEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool CuriosityRewardEnabled { |
|||
get { return curiosityRewardEnabled_; } |
|||
set { |
|||
curiosityRewardEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "rnd_reward_enabled" field.</summary>
|
|||
public const int RndRewardEnabledFieldNumber = 6; |
|||
private bool rndRewardEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool RndRewardEnabled { |
|||
get { return rndRewardEnabled_; } |
|||
set { |
|||
rndRewardEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "behavioral_cloning_enabled" field.</summary>
|
|||
public const int BehavioralCloningEnabledFieldNumber = 7; |
|||
private bool behavioralCloningEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool BehavioralCloningEnabled { |
|||
get { return behavioralCloningEnabled_; } |
|||
set { |
|||
behavioralCloningEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "recurrent_enabled" field.</summary>
|
|||
public const int RecurrentEnabledFieldNumber = 8; |
|||
private bool recurrentEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool RecurrentEnabled { |
|||
get { return recurrentEnabled_; } |
|||
set { |
|||
recurrentEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "visual_encoder" field.</summary>
|
|||
public const int VisualEncoderFieldNumber = 9; |
|||
private string visualEncoder_ = ""; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public string VisualEncoder { |
|||
get { return visualEncoder_; } |
|||
set { |
|||
visualEncoder_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "num_network_layers" field.</summary>
|
|||
public const int NumNetworkLayersFieldNumber = 10; |
|||
private int numNetworkLayers_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public int NumNetworkLayers { |
|||
get { return numNetworkLayers_; } |
|||
set { |
|||
numNetworkLayers_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "num_network_hidden_units" field.</summary>
|
|||
public const int NumNetworkHiddenUnitsFieldNumber = 11; |
|||
private int numNetworkHiddenUnits_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public int NumNetworkHiddenUnits { |
|||
get { return numNetworkHiddenUnits_; } |
|||
set { |
|||
numNetworkHiddenUnits_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "trainer_threaded" field.</summary>
|
|||
public const int TrainerThreadedFieldNumber = 12; |
|||
private bool trainerThreaded_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool TrainerThreaded { |
|||
get { return trainerThreaded_; } |
|||
set { |
|||
trainerThreaded_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "self_play_enabled" field.</summary>
|
|||
public const int SelfPlayEnabledFieldNumber = 13; |
|||
private bool selfPlayEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool SelfPlayEnabled { |
|||
get { return selfPlayEnabled_; } |
|||
set { |
|||
selfPlayEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>Field number for the "curriculum_enabled" field.</summary>
|
|||
public const int CurriculumEnabledFieldNumber = 14; |
|||
private bool curriculumEnabled_; |
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool CurriculumEnabled { |
|||
get { return curriculumEnabled_; } |
|||
set { |
|||
curriculumEnabled_ = value; |
|||
} |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public override bool Equals(object other) { |
|||
return Equals(other as TrainingBehaviorInitialized); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public bool Equals(TrainingBehaviorInitialized other) { |
|||
if (ReferenceEquals(other, null)) { |
|||
return false; |
|||
} |
|||
if (ReferenceEquals(other, this)) { |
|||
return true; |
|||
} |
|||
if (BehaviorName != other.BehaviorName) return false; |
|||
if (TrainerType != other.TrainerType) return false; |
|||
if (ExtrinsicRewardEnabled != other.ExtrinsicRewardEnabled) return false; |
|||
if (GailRewardEnabled != other.GailRewardEnabled) return false; |
|||
if (CuriosityRewardEnabled != other.CuriosityRewardEnabled) return false; |
|||
if (RndRewardEnabled != other.RndRewardEnabled) return false; |
|||
if (BehavioralCloningEnabled != other.BehavioralCloningEnabled) return false; |
|||
if (RecurrentEnabled != other.RecurrentEnabled) return false; |
|||
if (VisualEncoder != other.VisualEncoder) return false; |
|||
if (NumNetworkLayers != other.NumNetworkLayers) return false; |
|||
if (NumNetworkHiddenUnits != other.NumNetworkHiddenUnits) return false; |
|||
if (TrainerThreaded != other.TrainerThreaded) return false; |
|||
if (SelfPlayEnabled != other.SelfPlayEnabled) return false; |
|||
if (CurriculumEnabled != other.CurriculumEnabled) return false; |
|||
return Equals(_unknownFields, other._unknownFields); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public override int GetHashCode() { |
|||
int hash = 1; |
|||
if (BehaviorName.Length != 0) hash ^= BehaviorName.GetHashCode(); |
|||
if (TrainerType.Length != 0) hash ^= TrainerType.GetHashCode(); |
|||
if (ExtrinsicRewardEnabled != false) hash ^= ExtrinsicRewardEnabled.GetHashCode(); |
|||
if (GailRewardEnabled != false) hash ^= GailRewardEnabled.GetHashCode(); |
|||
if (CuriosityRewardEnabled != false) hash ^= CuriosityRewardEnabled.GetHashCode(); |
|||
if (RndRewardEnabled != false) hash ^= RndRewardEnabled.GetHashCode(); |
|||
if (BehavioralCloningEnabled != false) hash ^= BehavioralCloningEnabled.GetHashCode(); |
|||
if (RecurrentEnabled != false) hash ^= RecurrentEnabled.GetHashCode(); |
|||
if (VisualEncoder.Length != 0) hash ^= VisualEncoder.GetHashCode(); |
|||
if (NumNetworkLayers != 0) hash ^= NumNetworkLayers.GetHashCode(); |
|||
if (NumNetworkHiddenUnits != 0) hash ^= NumNetworkHiddenUnits.GetHashCode(); |
|||
if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode(); |
|||
if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode(); |
|||
if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode(); |
|||
if (_unknownFields != null) { |
|||
hash ^= _unknownFields.GetHashCode(); |
|||
} |
|||
return hash; |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public override string ToString() { |
|||
return pb::JsonFormatter.ToDiagnosticString(this); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public void WriteTo(pb::CodedOutputStream output) { |
|||
if (BehaviorName.Length != 0) { |
|||
output.WriteRawTag(10); |
|||
output.WriteString(BehaviorName); |
|||
} |
|||
if (TrainerType.Length != 0) { |
|||
output.WriteRawTag(18); |
|||
output.WriteString(TrainerType); |
|||
} |
|||
if (ExtrinsicRewardEnabled != false) { |
|||
output.WriteRawTag(24); |
|||
output.WriteBool(ExtrinsicRewardEnabled); |
|||
} |
|||
if (GailRewardEnabled != false) { |
|||
output.WriteRawTag(32); |
|||
output.WriteBool(GailRewardEnabled); |
|||
} |
|||
if (CuriosityRewardEnabled != false) { |
|||
output.WriteRawTag(40); |
|||
output.WriteBool(CuriosityRewardEnabled); |
|||
} |
|||
if (RndRewardEnabled != false) { |
|||
output.WriteRawTag(48); |
|||
output.WriteBool(RndRewardEnabled); |
|||
} |
|||
if (BehavioralCloningEnabled != false) { |
|||
output.WriteRawTag(56); |
|||
output.WriteBool(BehavioralCloningEnabled); |
|||
} |
|||
if (RecurrentEnabled != false) { |
|||
output.WriteRawTag(64); |
|||
output.WriteBool(RecurrentEnabled); |
|||
} |
|||
if (VisualEncoder.Length != 0) { |
|||
output.WriteRawTag(74); |
|||
output.WriteString(VisualEncoder); |
|||
} |
|||
if (NumNetworkLayers != 0) { |
|||
output.WriteRawTag(80); |
|||
output.WriteInt32(NumNetworkLayers); |
|||
} |
|||
if (NumNetworkHiddenUnits != 0) { |
|||
output.WriteRawTag(88); |
|||
output.WriteInt32(NumNetworkHiddenUnits); |
|||
} |
|||
if (TrainerThreaded != false) { |
|||
output.WriteRawTag(96); |
|||
output.WriteBool(TrainerThreaded); |
|||
} |
|||
if (SelfPlayEnabled != false) { |
|||
output.WriteRawTag(104); |
|||
output.WriteBool(SelfPlayEnabled); |
|||
} |
|||
if (CurriculumEnabled != false) { |
|||
output.WriteRawTag(112); |
|||
output.WriteBool(CurriculumEnabled); |
|||
} |
|||
if (_unknownFields != null) { |
|||
_unknownFields.WriteTo(output); |
|||
} |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public int CalculateSize() { |
|||
int size = 0; |
|||
if (BehaviorName.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(BehaviorName); |
|||
} |
|||
if (TrainerType.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(TrainerType); |
|||
} |
|||
if (ExtrinsicRewardEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (GailRewardEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (CuriosityRewardEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (RndRewardEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (BehavioralCloningEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (RecurrentEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (VisualEncoder.Length != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeStringSize(VisualEncoder); |
|||
} |
|||
if (NumNetworkLayers != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkLayers); |
|||
} |
|||
if (NumNetworkHiddenUnits != 0) { |
|||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkHiddenUnits); |
|||
} |
|||
if (TrainerThreaded != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (SelfPlayEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (CurriculumEnabled != false) { |
|||
size += 1 + 1; |
|||
} |
|||
if (_unknownFields != null) { |
|||
size += _unknownFields.CalculateSize(); |
|||
} |
|||
return size; |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public void MergeFrom(TrainingBehaviorInitialized other) { |
|||
if (other == null) { |
|||
return; |
|||
} |
|||
if (other.BehaviorName.Length != 0) { |
|||
BehaviorName = other.BehaviorName; |
|||
} |
|||
if (other.TrainerType.Length != 0) { |
|||
TrainerType = other.TrainerType; |
|||
} |
|||
if (other.ExtrinsicRewardEnabled != false) { |
|||
ExtrinsicRewardEnabled = other.ExtrinsicRewardEnabled; |
|||
} |
|||
if (other.GailRewardEnabled != false) { |
|||
GailRewardEnabled = other.GailRewardEnabled; |
|||
} |
|||
if (other.CuriosityRewardEnabled != false) { |
|||
CuriosityRewardEnabled = other.CuriosityRewardEnabled; |
|||
} |
|||
if (other.RndRewardEnabled != false) { |
|||
RndRewardEnabled = other.RndRewardEnabled; |
|||
} |
|||
if (other.BehavioralCloningEnabled != false) { |
|||
BehavioralCloningEnabled = other.BehavioralCloningEnabled; |
|||
} |
|||
if (other.RecurrentEnabled != false) { |
|||
RecurrentEnabled = other.RecurrentEnabled; |
|||
} |
|||
if (other.VisualEncoder.Length != 0) { |
|||
VisualEncoder = other.VisualEncoder; |
|||
} |
|||
if (other.NumNetworkLayers != 0) { |
|||
NumNetworkLayers = other.NumNetworkLayers; |
|||
} |
|||
if (other.NumNetworkHiddenUnits != 0) { |
|||
NumNetworkHiddenUnits = other.NumNetworkHiddenUnits; |
|||
} |
|||
if (other.TrainerThreaded != false) { |
|||
TrainerThreaded = other.TrainerThreaded; |
|||
} |
|||
if (other.SelfPlayEnabled != false) { |
|||
SelfPlayEnabled = other.SelfPlayEnabled; |
|||
} |
|||
if (other.CurriculumEnabled != false) { |
|||
CurriculumEnabled = other.CurriculumEnabled; |
|||
} |
|||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); |
|||
} |
|||
|
|||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] |
|||
public void MergeFrom(pb::CodedInputStream input) { |
|||
uint tag; |
|||
while ((tag = input.ReadTag()) != 0) { |
|||
switch(tag) { |
|||
default: |
|||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); |
|||
break; |
|||
case 10: { |
|||
BehaviorName = input.ReadString(); |
|||
break; |
|||
} |
|||
case 18: { |
|||
TrainerType = input.ReadString(); |
|||
break; |
|||
} |
|||
case 24: { |
|||
ExtrinsicRewardEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 32: { |
|||
GailRewardEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 40: { |
|||
CuriosityRewardEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 48: { |
|||
RndRewardEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 56: { |
|||
BehavioralCloningEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 64: { |
|||
RecurrentEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 74: { |
|||
VisualEncoder = input.ReadString(); |
|||
break; |
|||
} |
|||
case 80: { |
|||
NumNetworkLayers = input.ReadInt32(); |
|||
break; |
|||
} |
|||
case 88: { |
|||
NumNetworkHiddenUnits = input.ReadInt32(); |
|||
break; |
|||
} |
|||
case 96: { |
|||
TrainerThreaded = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 104: { |
|||
SelfPlayEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
case 112: { |
|||
CurriculumEnabled = input.ReadBool(); |
|||
break; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
} |
|||
|
|||
#endregion
|
|||
|
|||
} |
|||
|
|||
#endregion Designer generated code
|
|
|||
fileFormatVersion: 2 |
|||
guid: 9e6ac06a3931742d798cf922de6b99f0 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Identifiers for "built in" sensor types.
|
|||
/// These are only used for analytics, and should not be used for any runtime decisions.
|
|||
///
|
|||
/// NOTE: Do not renumber these, since the values are used for analytics. Renaming is allowed though.
|
|||
/// </summary>
|
|||
public enum BuiltInSensorType |
|||
{ |
|||
Unknown = 0, |
|||
VectorSensor = 1, |
|||
// Note that StackingSensor actually returns the wrapped sensor's type
|
|||
StackingSensor = 2, |
|||
RayPerceptionSensor = 3, |
|||
ReflectionSensor = 4, |
|||
CameraSensor = 5, |
|||
RenderTextureSensor = 6, |
|||
BufferSensor = 7, |
|||
PhysicsBodySensor = 8, |
|||
Match3Sensor = 9, |
|||
GridSensor = 10 |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Interface for sensors that are provided as part of ML-Agents.
|
|||
/// User-implemented sensors don't need to use this interface.
|
|||
/// </summary>
|
|||
public interface IBuiltInSensor |
|||
{ |
|||
/// <summary>
|
|||
/// Return the corresponding BuiltInSensorType for the sensor.
|
|||
/// </summary>
|
|||
/// <returns>A BuiltInSensorType corresponding to the sensor.</returns>
|
|||
BuiltInSensorType GetBuiltInSensorType(); |
|||
} |
|||
|
|||
|
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: c0c4a98bf1c941b381917cb65209beee |
|||
timeCreated: 1611096525 |
|
|||
using System; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.Analytics; |
|||
using Unity.MLAgents.CommunicatorObjects; |
|||
|
|||
namespace Unity.MLAgents.SideChannels |
|||
{ |
|||
public class TrainingAnalyticsSideChannel : SideChannel |
|||
{ |
|||
const string k_TrainingAnalyticsConfigId = "b664a4a9-d86f-5a5f-95cb-e8353a7e8356"; |
|||
|
|||
/// <summary>
|
|||
/// Initializes the side channel. The constructor is internal because only one instance is
|
|||
/// supported at a time, and is created by the Academy.
|
|||
/// </summary>
|
|||
internal TrainingAnalyticsSideChannel() |
|||
{ |
|||
ChannelId = new Guid(k_TrainingAnalyticsConfigId); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected override void OnMessageReceived(IncomingMessage msg) |
|||
{ |
|||
Google.Protobuf.WellKnownTypes.Any anyMessage = null; |
|||
try |
|||
{ |
|||
anyMessage = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(msg.GetRawBytes()); |
|||
} |
|||
catch (Google.Protobuf.InvalidProtocolBufferException) |
|||
{ |
|||
// Bad message, nothing we can do about it, so just ignore.
|
|||
return; |
|||
} |
|||
|
|||
if (anyMessage.Is(TrainingEnvironmentInitialized.Descriptor)) |
|||
{ |
|||
var envInitProto = anyMessage.Unpack<TrainingEnvironmentInitialized>(); |
|||
var envInitEvent = envInitProto.ToTrainingEnvironmentInitializedEvent(); |
|||
TrainingAnalytics.TrainingEnvironmentInitialized(envInitEvent); |
|||
} |
|||
else if (anyMessage.Is(TrainingBehaviorInitialized.Descriptor)) |
|||
{ |
|||
var behaviorInitProto = anyMessage.Unpack<TrainingBehaviorInitialized>(); |
|||
var behaviorTrainingEvent = behaviorInitProto.ToTrainingBehaviorInitializedEvent(); |
|||
TrainingAnalytics.TrainingBehaviorInitialized(behaviorTrainingEvent); |
|||
} |
|||
// Don't do anything for unknown types, since the user probably can't do anything about it.
|
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 13c87198bbd54b40a0b93308eb37933e |
|||
timeCreated: 1608337471 |
|
|||
using System.Collections.Generic; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Sensors; |
|||
using UnityEngine; |
|||
using Unity.Barracuda; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Analytics; |
|||
using Unity.MLAgents.Policies; |
|||
using UnityEditor; |
|||
|
|||
namespace Unity.MLAgents.Tests.Analytics |
|||
{ |
|||
[TestFixture] |
|||
public class TrainingAnalyticsTests |
|||
{ |
|||
[TestCase("foo?team=42", ExpectedResult = "foo")] |
|||
[TestCase("foo", ExpectedResult = "foo")] |
|||
[TestCase("foo?bar?team=1337", ExpectedResult = "foo?bar")] |
|||
public string TestParseBehaviorName(string fullyQualifiedBehaviorName) |
|||
{ |
|||
return TrainingAnalytics.ParseBehaviorName(fullyQualifiedBehaviorName); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestRemotePolicy() |
|||
{ |
|||
if (Academy.IsInitialized) |
|||
{ |
|||
Academy.Instance.Dispose(); |
|||
} |
|||
|
|||
using (new AnalyticsUtils.DisableAnalyticsSending()) |
|||
{ |
|||
var actionSpec = ActionSpec.MakeContinuous(3); |
|||
var policy = new RemotePolicy(actionSpec, "TestBehavior?team=42"); |
|||
policy.RequestDecision(new AgentInfo(), new List<ISensor>()); |
|||
} |
|||
|
|||
Academy.Instance.Dispose(); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 70b8f1544bc34b4e8f1bc1068c64f01c |
|||
timeCreated: 1610419546 |
|
|||
using System; |
|||
using System.Linq; |
|||
using System.Text; |
|||
using NUnit.Framework; |
|||
using Google.Protobuf; |
|||
using Unity.MLAgents.Analytics; |
|||
using Unity.MLAgents.SideChannels; |
|||
using Unity.MLAgents.CommunicatorObjects; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
/// <summary>
|
|||
/// These tests send messages through the event handling code.
|
|||
/// There's no output to test, so just make sure there are no exceptions
|
|||
/// (and get the code coverage above the minimum).
|
|||
/// </summary>
|
|||
public class TrainingAnalyticsSideChannelTests |
|||
{ |
|||
[Test] |
|||
public void TestTrainingEnvironmentReceived() |
|||
{ |
|||
var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingEnvironmentInitialized()); |
|||
var anyMsgBytes = anyMsg.ToByteArray(); |
|||
var sideChannel = new TrainingAnalyticsSideChannel(); |
|||
using (new AnalyticsUtils.DisableAnalyticsSending()) |
|||
{ |
|||
sideChannel.ProcessMessage(anyMsgBytes); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestTrainingBehaviorReceived() |
|||
{ |
|||
var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized()); |
|||
var anyMsgBytes = anyMsg.ToByteArray(); |
|||
var sideChannel = new TrainingAnalyticsSideChannel(); |
|||
using (new AnalyticsUtils.DisableAnalyticsSending()) |
|||
{ |
|||
sideChannel.ProcessMessage(anyMsgBytes); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestInvalidProtobufMessage() |
|||
{ |
|||
// Test an invalid (non-protobuf) message. This should silently ignore the data.
|
|||
var badBytes = Encoding.ASCII.GetBytes("Lorem ipsum"); |
|||
var sideChannel = new TrainingAnalyticsSideChannel(); |
|||
using (new AnalyticsUtils.DisableAnalyticsSending()) |
|||
{ |
|||
sideChannel.ProcessMessage(badBytes); |
|||
} |
|||
|
|||
// Test an almost-valid message. This should silently ignore the data.
|
|||
var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized()); |
|||
var anyMsgBytes = anyMsg.ToByteArray(); |
|||
var truncatedMessage = new ArraySegment<byte>(anyMsgBytes, 0, anyMsgBytes.Length - 1).ToArray(); |
|||
using (new AnalyticsUtils.DisableAnalyticsSending()) |
|||
{ |
|||
sideChannel.ProcessMessage(truncatedMessage); |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: c2a71036ddec4ba4bf83c5e8ba1b8daa |
|||
timeCreated: 1610574895 |
|
|||
# Generated by the protocol buffer compiler. DO NOT EDIT! |
|||
# source: mlagents_envs/communicator_objects/training_analytics.proto |
|||
|
|||
import sys |
|||
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) |
|||
from google.protobuf import descriptor as _descriptor |
|||
from google.protobuf import message as _message |
|||
from google.protobuf import reflection as _reflection |
|||
from google.protobuf import symbol_database as _symbol_database |
|||
from google.protobuf import descriptor_pb2 |
|||
# @@protoc_insertion_point(imports) |
|||
|
|||
_sym_db = _symbol_database.Default() |
|||
|
|||
|
|||
|
|||
|
|||
DESCRIPTOR = _descriptor.FileDescriptor( |
|||
name='mlagents_envs/communicator_objects/training_analytics.proto', |
|||
package='communicator_objects', |
|||
syntax='proto3', |
|||
serialized_pb=_b('\n;mlagents_envs/communicator_objects/training_analytics.proto\x12\x14\x63ommunicator_objects\"\xd9\x01\n\x1eTrainingEnvironmentInitialized\x12\x18\n\x10mlagents_version\x18\x01 \x01(\t\x12\x1d\n\x15mlagents_envs_version\x18\x02 \x01(\t\x12\x16\n\x0epython_version\x18\x03 \x01(\t\x12\x15\n\rtorch_version\x18\x04 \x01(\t\x12\x19\n\x11torch_device_type\x18\x05 \x01(\t\x12\x10\n\x08num_envs\x18\x06 \x01(\x05\x12\"\n\x1anum_environment_parameters\x18\x07 \x01(\x05\"\xad\x03\n\x1bTrainingBehaviorInitialized\x12\x15\n\rbehavior_name\x18\x01 \x01(\t\x12\x14\n\x0ctrainer_type\x18\x02 \x01(\t\x12 \n\x18\x65xtrinsic_reward_enabled\x18\x03 \x01(\x08\x12\x1b\n\x13gail_reward_enabled\x18\x04 \x01(\x08\x12 \n\x18\x63uriosity_reward_enabled\x18\x05 \x01(\x08\x12\x1a\n\x12rnd_reward_enabled\x18\x06 \x01(\x08\x12\"\n\x1a\x62\x65havioral_cloning_enabled\x18\x07 \x01(\x08\x12\x19\n\x11recurrent_enabled\x18\x08 \x01(\x08\x12\x16\n\x0evisual_encoder\x18\t \x01(\t\x12\x1a\n\x12num_network_layers\x18\n \x01(\x05\x12 \n\x18num_network_hidden_units\x18\x0b \x01(\x05\x12\x18\n\x10trainer_threaded\x18\x0c \x01(\x08\x12\x19\n\x11self_play_enabled\x18\r \x01(\x08\x12\x1a\n\x12\x63urriculum_enabled\x18\x0e \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') |
|||
) |
|||
|
|||
|
|||
|
|||
|
|||
_TRAININGENVIRONMENTINITIALIZED = _descriptor.Descriptor( |
|||
name='TrainingEnvironmentInitialized', |
|||
full_name='communicator_objects.TrainingEnvironmentInitialized', |
|||
filename=None, |
|||
file=DESCRIPTOR, |
|||
containing_type=None, |
|||
fields=[ |
|||
_descriptor.FieldDescriptor( |
|||
name='mlagents_version', full_name='communicator_objects.TrainingEnvironmentInitialized.mlagents_version', index=0, |
|||
number=1, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='mlagents_envs_version', full_name='communicator_objects.TrainingEnvironmentInitialized.mlagents_envs_version', index=1, |
|||
number=2, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='python_version', full_name='communicator_objects.TrainingEnvironmentInitialized.python_version', index=2, |
|||
number=3, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='torch_version', full_name='communicator_objects.TrainingEnvironmentInitialized.torch_version', index=3, |
|||
number=4, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='torch_device_type', full_name='communicator_objects.TrainingEnvironmentInitialized.torch_device_type', index=4, |
|||
number=5, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='num_envs', full_name='communicator_objects.TrainingEnvironmentInitialized.num_envs', index=5, |
|||
number=6, type=5, cpp_type=1, label=1, |
|||
has_default_value=False, default_value=0, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='num_environment_parameters', full_name='communicator_objects.TrainingEnvironmentInitialized.num_environment_parameters', index=6, |
|||
number=7, type=5, cpp_type=1, label=1, |
|||
has_default_value=False, default_value=0, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
], |
|||
extensions=[ |
|||
], |
|||
nested_types=[], |
|||
enum_types=[ |
|||
], |
|||
options=None, |
|||
is_extendable=False, |
|||
syntax='proto3', |
|||
extension_ranges=[], |
|||
oneofs=[ |
|||
], |
|||
serialized_start=86, |
|||
serialized_end=303, |
|||
) |
|||
|
|||
|
|||
_TRAININGBEHAVIORINITIALIZED = _descriptor.Descriptor( |
|||
name='TrainingBehaviorInitialized', |
|||
full_name='communicator_objects.TrainingBehaviorInitialized', |
|||
filename=None, |
|||
file=DESCRIPTOR, |
|||
containing_type=None, |
|||
fields=[ |
|||
_descriptor.FieldDescriptor( |
|||
name='behavior_name', full_name='communicator_objects.TrainingBehaviorInitialized.behavior_name', index=0, |
|||
number=1, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='trainer_type', full_name='communicator_objects.TrainingBehaviorInitialized.trainer_type', index=1, |
|||
number=2, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='extrinsic_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.extrinsic_reward_enabled', index=2, |
|||
number=3, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='gail_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.gail_reward_enabled', index=3, |
|||
number=4, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='curiosity_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.curiosity_reward_enabled', index=4, |
|||
number=5, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='rnd_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.rnd_reward_enabled', index=5, |
|||
number=6, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='behavioral_cloning_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.behavioral_cloning_enabled', index=6, |
|||
number=7, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='recurrent_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.recurrent_enabled', index=7, |
|||
number=8, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='visual_encoder', full_name='communicator_objects.TrainingBehaviorInitialized.visual_encoder', index=8, |
|||
number=9, type=9, cpp_type=9, label=1, |
|||
has_default_value=False, default_value=_b("").decode('utf-8'), |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='num_network_layers', full_name='communicator_objects.TrainingBehaviorInitialized.num_network_layers', index=9, |
|||
number=10, type=5, cpp_type=1, label=1, |
|||
has_default_value=False, default_value=0, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='num_network_hidden_units', full_name='communicator_objects.TrainingBehaviorInitialized.num_network_hidden_units', index=10, |
|||
number=11, type=5, cpp_type=1, label=1, |
|||
has_default_value=False, default_value=0, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='trainer_threaded', full_name='communicator_objects.TrainingBehaviorInitialized.trainer_threaded', index=11, |
|||
number=12, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='self_play_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.self_play_enabled', index=12, |
|||
number=13, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
_descriptor.FieldDescriptor( |
|||
name='curriculum_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.curriculum_enabled', index=13, |
|||
number=14, type=8, cpp_type=7, label=1, |
|||
has_default_value=False, default_value=False, |
|||
message_type=None, enum_type=None, containing_type=None, |
|||
is_extension=False, extension_scope=None, |
|||
options=None, file=DESCRIPTOR), |
|||
], |
|||
extensions=[ |
|||
], |
|||
nested_types=[], |
|||
enum_types=[ |
|||
], |
|||
options=None, |
|||
is_extendable=False, |
|||
syntax='proto3', |
|||
extension_ranges=[], |
|||
oneofs=[ |
|||
], |
|||
serialized_start=306, |
|||
serialized_end=735, |
|||
) |
|||
|
|||
DESCRIPTOR.message_types_by_name['TrainingEnvironmentInitialized'] = _TRAININGENVIRONMENTINITIALIZED |
|||
DESCRIPTOR.message_types_by_name['TrainingBehaviorInitialized'] = _TRAININGBEHAVIORINITIALIZED |
|||
_sym_db.RegisterFileDescriptor(DESCRIPTOR) |
|||
|
|||
TrainingEnvironmentInitialized = _reflection.GeneratedProtocolMessageType('TrainingEnvironmentInitialized', (_message.Message,), dict( |
|||
DESCRIPTOR = _TRAININGENVIRONMENTINITIALIZED, |
|||
__module__ = 'mlagents_envs.communicator_objects.training_analytics_pb2' |
|||
# @@protoc_insertion_point(class_scope:communicator_objects.TrainingEnvironmentInitialized) |
|||
)) |
|||
_sym_db.RegisterMessage(TrainingEnvironmentInitialized) |
|||
|
|||
TrainingBehaviorInitialized = _reflection.GeneratedProtocolMessageType('TrainingBehaviorInitialized', (_message.Message,), dict( |
|||
DESCRIPTOR = _TRAININGBEHAVIORINITIALIZED, |
|||
__module__ = 'mlagents_envs.communicator_objects.training_analytics_pb2' |
|||
# @@protoc_insertion_point(class_scope:communicator_objects.TrainingBehaviorInitialized) |
|||
)) |
|||
_sym_db.RegisterMessage(TrainingBehaviorInitialized) |
|||
|
|||
|
|||
DESCRIPTOR.has_options = True |
|||
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\"Unity.MLAgents.CommunicatorObjects')) |
|||
# @@protoc_insertion_point(module_scope) |
|
|||
# @generated by generate_proto_mypy_stubs.py. Do not edit! |
|||
import sys |
|||
from google.protobuf.descriptor import ( |
|||
Descriptor as google___protobuf___descriptor___Descriptor, |
|||
) |
|||
|
|||
from google.protobuf.message import ( |
|||
Message as google___protobuf___message___Message, |
|||
) |
|||
|
|||
from typing import ( |
|||
Optional as typing___Optional, |
|||
Text as typing___Text, |
|||
) |
|||
|
|||
from typing_extensions import ( |
|||
Literal as typing_extensions___Literal, |
|||
) |
|||
|
|||
|
|||
builtin___bool = bool |
|||
builtin___bytes = bytes |
|||
builtin___float = float |
|||
builtin___int = int |
|||
|
|||
|
|||
class TrainingEnvironmentInitialized(google___protobuf___message___Message): |
|||
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... |
|||
mlagents_version = ... # type: typing___Text |
|||
mlagents_envs_version = ... # type: typing___Text |
|||
python_version = ... # type: typing___Text |
|||
torch_version = ... # type: typing___Text |
|||
torch_device_type = ... # type: typing___Text |
|||
num_envs = ... # type: builtin___int |
|||
num_environment_parameters = ... # type: builtin___int |
|||
|
|||
def __init__(self, |
|||
*, |
|||
mlagents_version : typing___Optional[typing___Text] = None, |
|||
mlagents_envs_version : typing___Optional[typing___Text] = None, |
|||
python_version : typing___Optional[typing___Text] = None, |
|||
torch_version : typing___Optional[typing___Text] = None, |
|||
torch_device_type : typing___Optional[typing___Text] = None, |
|||
num_envs : typing___Optional[builtin___int] = None, |
|||
num_environment_parameters : typing___Optional[builtin___int] = None, |
|||
) -> None: ... |
|||
@classmethod |
|||
def FromString(cls, s: builtin___bytes) -> TrainingEnvironmentInitialized: ... |
|||
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... |
|||
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... |
|||
if sys.version_info >= (3,): |
|||
def ClearField(self, field_name: typing_extensions___Literal[u"mlagents_envs_version",u"mlagents_version",u"num_environment_parameters",u"num_envs",u"python_version",u"torch_device_type",u"torch_version"]) -> None: ... |
|||
else: |
|||
def ClearField(self, field_name: typing_extensions___Literal[u"mlagents_envs_version",b"mlagents_envs_version",u"mlagents_version",b"mlagents_version",u"num_environment_parameters",b"num_environment_parameters",u"num_envs",b"num_envs",u"python_version",b"python_version",u"torch_device_type",b"torch_device_type",u"torch_version",b"torch_version"]) -> None: ... |
|||
|
|||
class TrainingBehaviorInitialized(google___protobuf___message___Message): |
|||
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... |
|||
behavior_name = ... # type: typing___Text |
|||
trainer_type = ... # type: typing___Text |
|||
extrinsic_reward_enabled = ... # type: builtin___bool |
|||
gail_reward_enabled = ... # type: builtin___bool |
|||
curiosity_reward_enabled = ... # type: builtin___bool |
|||
rnd_reward_enabled = ... # type: builtin___bool |
|||
behavioral_cloning_enabled = ... # type: builtin___bool |
|||
recurrent_enabled = ... # type: builtin___bool |
|||
visual_encoder = ... # type: typing___Text |
|||
num_network_layers = ... # type: builtin___int |
|||
num_network_hidden_units = ... # type: builtin___int |
|||
trainer_threaded = ... # type: builtin___bool |
|||
self_play_enabled = ... # type: builtin___bool |
|||
curriculum_enabled = ... # type: builtin___bool |
|||
|
|||
def __init__(self, |
|||
*, |
|||
behavior_name : typing___Optional[typing___Text] = None, |
|||
trainer_type : typing___Optional[typing___Text] = None, |
|||
extrinsic_reward_enabled : typing___Optional[builtin___bool] = None, |
|||
gail_reward_enabled : typing___Optional[builtin___bool] = None, |
|||
curiosity_reward_enabled : typing___Optional[builtin___bool] = None, |
|||
rnd_reward_enabled : typing___Optional[builtin___bool] = None, |
|||
behavioral_cloning_enabled : typing___Optional[builtin___bool] = None, |
|||
recurrent_enabled : typing___Optional[builtin___bool] = None, |
|||
visual_encoder : typing___Optional[typing___Text] = None, |
|||
num_network_layers : typing___Optional[builtin___int] = None, |
|||
num_network_hidden_units : typing___Optional[builtin___int] = None, |
|||
trainer_threaded : typing___Optional[builtin___bool] = None, |
|||
self_play_enabled : typing___Optional[builtin___bool] = None, |
|||
curriculum_enabled : typing___Optional[builtin___bool] = None, |
|||
) -> None: ... |
|||
@classmethod |
|||
def FromString(cls, s: builtin___bytes) -> TrainingBehaviorInitialized: ... |
|||
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... |
|||
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... |
|||
if sys.version_info >= (3,): |
|||
def ClearField(self, field_name: typing_extensions___Literal[u"behavior_name",u"behavioral_cloning_enabled",u"curiosity_reward_enabled",u"curriculum_enabled",u"extrinsic_reward_enabled",u"gail_reward_enabled",u"num_network_hidden_units",u"num_network_layers",u"recurrent_enabled",u"rnd_reward_enabled",u"self_play_enabled",u"trainer_threaded",u"trainer_type",u"visual_encoder"]) -> None: ... |
|||
else: |
|||
def ClearField(self, field_name: typing_extensions___Literal[u"behavior_name",b"behavior_name",u"behavioral_cloning_enabled",b"behavioral_cloning_enabled",u"curiosity_reward_enabled",b"curiosity_reward_enabled",u"curriculum_enabled",b"curriculum_enabled",u"extrinsic_reward_enabled",b"extrinsic_reward_enabled",u"gail_reward_enabled",b"gail_reward_enabled",u"num_network_hidden_units",b"num_network_hidden_units",u"num_network_layers",b"num_network_layers",u"recurrent_enabled",b"recurrent_enabled",u"rnd_reward_enabled",b"rnd_reward_enabled",u"self_play_enabled",b"self_play_enabled",u"trainer_threaded",b"trainer_threaded",u"trainer_type",b"trainer_type",u"visual_encoder",b"visual_encoder"]) -> None: ... |
|
|||
import pytest |
|||
from unittest import mock |
|||
|
|||
import torch # noqa I201 |
|||
|
|||
from mlagents.torch_utils import set_torch_config, default_device |
|||
from mlagents.trainers.settings import TorchSettings |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"device_str, expected_type, expected_index, expected_tensor_type", |
|||
[ |
|||
("cpu", "cpu", None, torch.FloatTensor), |
|||
("cuda", "cuda", None, torch.cuda.FloatTensor), |
|||
("cuda:42", "cuda", 42, torch.cuda.FloatTensor), |
|||
("opengl", "opengl", None, torch.FloatTensor), |
|||
], |
|||
) |
|||
@mock.patch.object(torch, "set_default_tensor_type") |
|||
def test_set_torch_device( |
|||
mock_set_default_tensor_type, |
|||
device_str, |
|||
expected_type, |
|||
expected_index, |
|||
expected_tensor_type, |
|||
): |
|||
try: |
|||
torch_settings = TorchSettings(device=device_str) |
|||
set_torch_config(torch_settings) |
|||
assert default_device().type == expected_type |
|||
if expected_index is None: |
|||
assert default_device().index is None |
|||
else: |
|||
assert default_device().index == expected_index |
|||
mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type) |
|||
except Exception: |
|||
raise |
|||
finally: |
|||
# restore the defaults |
|||
torch_settings = TorchSettings(device=None) |
|||
set_torch_config(torch_settings) |
|
|||
import sys |
|||
from typing import Optional |
|||
import uuid |
|||
import mlagents_envs |
|||
import mlagents.trainers |
|||
from mlagents import torch_utils |
|||
from mlagents.trainers.settings import RewardSignalType |
|||
from mlagents_envs.exception import UnityCommunicationException |
|||
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage |
|||
from mlagents_envs.communicator_objects.training_analytics_pb2 import ( |
|||
TrainingEnvironmentInitialized, |
|||
TrainingBehaviorInitialized, |
|||
) |
|||
from google.protobuf.any_pb2 import Any |
|||
|
|||
from mlagents.trainers.settings import TrainerSettings, RunOptions |
|||
|
|||
|
|||
class TrainingAnalyticsSideChannel(SideChannel): |
|||
""" |
|||
Side channel that sends information about the training to the Unity environment so it can be logged. |
|||
""" |
|||
|
|||
def __init__(self) -> None: |
|||
# >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel") |
|||
# UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356') |
|||
super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356")) |
|||
self.run_options: Optional[RunOptions] = None |
|||
|
|||
def on_message_received(self, msg: IncomingMessage) -> None: |
|||
raise UnityCommunicationException( |
|||
"The TrainingAnalyticsSideChannel received a message from Unity, " |
|||
+ "this should not have happened." |
|||
) |
|||
|
|||
def environment_initialized(self, run_options: RunOptions) -> None: |
|||
self.run_options = run_options |
|||
# Tuple of (major, minor, patch) |
|||
vi = sys.version_info |
|||
env_params = run_options.environment_parameters |
|||
|
|||
msg = TrainingEnvironmentInitialized( |
|||
python_version=f"{vi[0]}.{vi[1]}.{vi[2]}", |
|||
mlagents_version=mlagents.trainers.__version__, |
|||
mlagents_envs_version=mlagents_envs.__version__, |
|||
torch_version=torch_utils.torch.__version__, |
|||
torch_device_type=torch_utils.default_device().type, |
|||
num_envs=run_options.env_settings.num_envs, |
|||
num_environment_parameters=len(env_params) if env_params else 0, |
|||
) |
|||
|
|||
any_message = Any() |
|||
any_message.Pack(msg) |
|||
|
|||
env_init_msg = OutgoingMessage() |
|||
env_init_msg.set_raw_bytes(any_message.SerializeToString()) |
|||
super().queue_message_to_send(env_init_msg) |
|||
|
|||
def training_started(self, behavior_name: str, config: TrainerSettings) -> None: |
|||
msg = TrainingBehaviorInitialized( |
|||
behavior_name=behavior_name, |
|||
trainer_type=config.trainer_type.value, |
|||
extrinsic_reward_enabled=( |
|||
RewardSignalType.EXTRINSIC in config.reward_signals |
|||
), |
|||
gail_reward_enabled=(RewardSignalType.GAIL in config.reward_signals), |
|||
curiosity_reward_enabled=( |
|||
RewardSignalType.CURIOSITY in config.reward_signals |
|||
), |
|||
rnd_reward_enabled=(RewardSignalType.RND in config.reward_signals), |
|||
behavioral_cloning_enabled=config.behavioral_cloning is not None, |
|||
recurrent_enabled=config.network_settings.memory is not None, |
|||
visual_encoder=config.network_settings.vis_encode_type.value, |
|||
num_network_layers=config.network_settings.num_layers, |
|||
num_network_hidden_units=config.network_settings.hidden_units, |
|||
trainer_threaded=config.threaded, |
|||
self_play_enabled=config.self_play is not None, |
|||
curriculum_enabled=self._behavior_uses_curriculum(behavior_name), |
|||
) |
|||
|
|||
any_message = Any() |
|||
any_message.Pack(msg) |
|||
|
|||
training_start_msg = OutgoingMessage() |
|||
training_start_msg.set_raw_bytes(any_message.SerializeToString()) |
|||
|
|||
super().queue_message_to_send(training_start_msg) |
|||
|
|||
def _behavior_uses_curriculum(self, behavior_name: str) -> bool: |
|||
if not self.run_options or not self.run_options.environment_parameters: |
|||
return False |
|||
|
|||
for param_settings in self.run_options.environment_parameters.values(): |
|||
for lesson in param_settings.curriculum: |
|||
cc = lesson.completion_criteria |
|||
if cc and cc.behavior == behavior_name: |
|||
return True |
|||
|
|||
return False |
|
|||
syntax = "proto3"; |
|||
|
|||
option csharp_namespace = "Unity.MLAgents.CommunicatorObjects"; |
|||
package communicator_objects; |
|||
|
|||
message TrainingEnvironmentInitialized { |
|||
string mlagents_version = 1; |
|||
string mlagents_envs_version = 2; |
|||
string python_version = 3; |
|||
string torch_version = 4; |
|||
string torch_device_type = 5; |
|||
int32 num_envs = 6; |
|||
int32 num_environment_parameters = 7; |
|||
} |
|||
|
|||
message TrainingBehaviorInitialized { |
|||
string behavior_name = 1; |
|||
string trainer_type = 2; |
|||
bool extrinsic_reward_enabled = 3; |
|||
bool gail_reward_enabled = 4; |
|||
bool curiosity_reward_enabled = 5; |
|||
bool rnd_reward_enabled = 6; |
|||
bool behavioral_cloning_enabled = 7; |
|||
bool recurrent_enabled = 8; |
|||
string visual_encoder = 9; |
|||
int32 num_network_layers = 10; |
|||
int32 num_network_hidden_units = 11; |
|||
bool trainer_threaded = 12; |
|||
bool self_play_enabled = 13; |
|||
bool curriculum_enabled = 14; |
|||
} |
|
|||
[pytest] |
|||
addopts = --strict-markers |
|||
markers = |
|||
check_environment_trains: Slow training tests, do not run on yamato |
|
|||
# Configuration for Lock Threads - https://github.com/dessant/lock-threads |
|||
|
|||
# Number of days of inactivity before a closed issue or pull request is locked |
|||
daysUntilLock: 365 |
|||
|
|||
# Skip issues and pull requests created before a given timestamp. Timestamp must |
|||
# follow ISO 8601 (`YYYY-MM-DD`). Set to `false` to disable |
|||
skipCreatedBefore: false |
|||
|
|||
# Issues and pull requests with these labels will be ignored. Set to `[]` to disable |
|||
exemptLabels: [] |
|||
|
|||
# Label to add before locking, such as `outdated`. Set to `false` to disable |
|||
lockLabel: false |
|||
|
|||
# Comment to post before locking. Set to `false` to disable |
|||
lockComment: > |
|||
This thread has been automatically locked since there has not been |
|||
any recent activity after it was closed. Please open a new issue for |
|||
related bugs. |
|||
|
|||
# Assign `resolved` as the reason for locking. Set to `false` to disable |
|||
setLockReason: true |
|||
|
|||
# Limit to only `issues` or `pulls` |
|||
only: issues |
|||
|
|||
# Optionally, specify configuration settings just for `issues` or `pulls` |
|||
# issues: |
|||
# exemptLabels: |
|||
# - help-wanted |
|||
# lockLabel: outdated |
|||
|
|||
# pulls: |
|||
# daysUntilLock: 30 |
|||
|
|||
# Repository to extend settings from |
|||
# _extends: repo |
撰写
预览
正在加载...
取消
保存
Reference in new issue