Ervin Teng
4 年前
当前提交
2fc4fe16
共有 8 个文件被更改,包括 185 次插入 和 1 次删除
-
3Project/Assets/ML-Agents/Examples/3DBall/Scripts/Task3DAgent.cs
-
7com.unity.ml-agents/Runtime/Academy.cs
-
5com.unity.ml-agents/Runtime/Agent.cs
-
50com.unity.ml-agents/Runtime/AgentParameters.cs
-
11com.unity.ml-agents/Runtime/AgentParameters.cs.meta
-
65com.unity.ml-agents/Runtime/SideChannels/AgentParametersChannel.cs
-
11com.unity.ml-agents/Runtime/SideChannels/AgentParametersChannel.cs.meta
-
34ml-agents-envs/mlagents_envs/side_channel/agent_parameters_channel.py
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents.SideChannels; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// A container for the Environment Parameters that may be modified during training.
|
|||
/// The keys for those parameters are defined in the trainer configurations and the
|
|||
/// the values are generated from the training process in features such as Curriculum Learning
|
|||
/// and Environment Parameter Randomization.
|
|||
///
|
|||
/// One current assumption for all the environment parameters is that they are of type float.
|
|||
/// </summary>
|
|||
public sealed class AgentParameters |
|||
{ |
|||
/// <summary>
|
|||
/// The side channel that is used to receive the new parameter values.
|
|||
/// </summary>
|
|||
readonly AgentParametersChannel m_Channel; |
|||
|
|||
/// <summary>
|
|||
/// Constructor.
|
|||
/// </summary>
|
|||
internal AgentParameters() |
|||
{ |
|||
m_Channel = new AgentParametersChannel(); |
|||
SideChannelManager.RegisterSideChannel(m_Channel); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the parameter value for the specified key. Returns the default value provided
|
|||
/// if this parameter key does not have a value. Only returns a parameter value if it is
|
|||
/// of type float.
|
|||
/// </summary>
|
|||
/// <param name="key">The parameter key</param>
|
|||
/// <param name="defaultValue">Default value for this parameter.</param>
|
|||
/// <returns></returns>
|
|||
public float GetWithDefault(int episodeId, string key, float defaultValue) |
|||
{ |
|||
return m_Channel.GetWithDefault(episodeId, key, defaultValue); |
|||
} |
|||
|
|||
|
|||
internal void Dispose() |
|||
{ |
|||
SideChannelManager.UnregisterSideChannel(m_Channel); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: c6d4c5ad59e7b4066b64fa47b5205889 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Collections.Generic; |
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.SideChannels |
|||
{ |
|||
internal class AgentParametersChannel : SideChannel |
|||
{ |
|||
Dictionary<int, Dictionary<string, float>> m_Parameters = new Dictionary<int, Dictionary<string, float>>(); |
|||
|
|||
const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860401"; |
|||
|
|||
/// <summary>
|
|||
/// Initializes the side channel. The constructor is internal because only one instance is
|
|||
/// supported at a time, and is created by the Academy.
|
|||
/// </summary>
|
|||
internal AgentParametersChannel() |
|||
{ |
|||
ChannelId = new Guid(k_EnvParamsId); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected override void OnMessageReceived(IncomingMessage msg) |
|||
{ |
|||
var episodeId = msg.ReadInt32(); |
|||
var key = msg.ReadString(); |
|||
var value = msg.ReadFloat32(); |
|||
if(!m_Parameters.ContainsKey(episodeId)) |
|||
{ |
|||
m_Parameters[episodeId] = new Dictionary<string, float>(); |
|||
} |
|||
m_Parameters[episodeId][key] = value; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the parameter value associated with the provided key. Returns the default
|
|||
/// value if one doesn't exist.
|
|||
/// </summary>
|
|||
/// <param name="key">Parameter key.</param>
|
|||
/// <param name="defaultValue">Default value to return.</param>
|
|||
/// <returns></returns>
|
|||
public float GetWithDefault(int episodeId, string key, float defaultValue) |
|||
{ |
|||
float value = defaultValue; |
|||
bool hasKey = false; |
|||
Dictionary<string, float> agent_dict; |
|||
if(m_Parameters.TryGetValue(episodeId, out agent_dict)) |
|||
{ |
|||
agent_dict.TryGetValue(key, out value); |
|||
} |
|||
return value; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns all parameter keys that have a registered value.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public IList<string> ListParameters(int episodeId) |
|||
{ |
|||
Dictionary<string, float> agent_dict; |
|||
m_Parameters.TryGetValue(episodeId, out agent_dict); |
|||
return new List<string>(agent_dict.Keys); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 22884d2b9466b4a589e059247a2f519f |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage |
|||
from mlagents_envs.exception import UnityCommunicationException |
|||
from mlagents_envs.base_env import AgentId |
|||
import uuid |
|||
|
|||
|
|||
class AgentParametersChannel(SideChannel): |
|||
""" |
|||
This is the SideChannel for sending agent-specific parameters to Unity. |
|||
You can send parameters to an environment with the command |
|||
set_float_parameter. |
|||
""" |
|||
|
|||
def __init__(self) -> None: |
|||
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860401")) |
|||
super().__init__(channel_id) |
|||
|
|||
def on_message_received(self, msg: IncomingMessage) -> None: |
|||
raise UnityCommunicationException( |
|||
"The EnvironmentParametersChannel received a message from Unity, " |
|||
+ "this should not have happend." |
|||
) |
|||
|
|||
def set_float_parameter(self, agent_id: AgentId, key: str, value: float) -> None: |
|||
""" |
|||
Sets a float environment parameter in the Unity Environment. |
|||
:param key: The string identifier of the parameter. |
|||
:param value: The float value of the parameter. |
|||
""" |
|||
msg = OutgoingMessage() |
|||
msg.write_int32(agent_id) |
|||
msg.write_string(key) |
|||
msg.write_float32(value) |
|||
super().queue_message_to_send(msg) |
撰写
预览
正在加载...
取消
保存
Reference in new issue