浏览代码

Add AgentParametersChannel

/develop/taggedobservations
Ervin Teng 5 年前
当前提交
2fc4fe16
共有 8 个文件被更改,包括 185 次插入1 次删除
  1. 3
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Task3DAgent.cs
  2. 7
      com.unity.ml-agents/Runtime/Academy.cs
  3. 5
      com.unity.ml-agents/Runtime/Agent.cs
  4. 50
      com.unity.ml-agents/Runtime/AgentParameters.cs
  5. 11
      com.unity.ml-agents/Runtime/AgentParameters.cs.meta
  6. 65
      com.unity.ml-agents/Runtime/SideChannels/AgentParametersChannel.cs
  7. 11
      com.unity.ml-agents/Runtime/SideChannels/AgentParametersChannel.cs.meta
  8. 34
      ml-agents-envs/mlagents_envs/side_channel/agent_parameters_channel.py

3
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Task3DAgent.cs


m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
Debug.Log(m_TaskSensor);
}
public override void CollectObservations(VectorSensor sensor)

public void SetResetParameters()
{
SetBall();
// Get agent parameters
Debug.Log(GetParameterWithDefault("test_param", 0));
}
}

7
com.unity.ml-agents/Runtime/Academy.cs


}
EnvironmentParameters m_EnvironmentParameters;
AgentParameters m_AgentParameters;
StatsRecorder m_StatsRecorder;
/// <summary>

public EnvironmentParameters EnvironmentParameters
{
get { return m_EnvironmentParameters; }
}
public AgentParameters AgentParameters
{
get { return m_AgentParameters; }
}
/// <summary>

SideChannelManager.RegisterSideChannel(new EngineConfigurationChannel());
m_EnvironmentParameters = new EnvironmentParameters();
m_AgentParameters = new AgentParameters();
m_StatsRecorder = new StatsRecorder();
// Try to launch the communicator by using the arguments passed at launch

5
com.unity.ml-agents/Runtime/Agent.cs


Array.Copy(action, m_Action.vectorActions, action.Length);
}
}
public float GetParameterWithDefault(string key, float defaultValue)
{
return Academy.Instance.AgentParameters.GetWithDefault(m_EpisodeId, key, defaultValue);
}
}
}

50
com.unity.ml-agents/Runtime/AgentParameters.cs


using System;
using System.Collections.Generic;
using Unity.MLAgents.SideChannels;
namespace Unity.MLAgents
{
/// <summary>
/// A container for the Environment Parameters that may be modified during training.
/// The keys for those parameters are defined in the trainer configurations and the
/// the values are generated from the training process in features such as Curriculum Learning
/// and Environment Parameter Randomization.
///
/// One current assumption for all the environment parameters is that they are of type float.
/// </summary>
public sealed class AgentParameters
{
/// <summary>
/// The side channel that is used to receive the new parameter values.
/// </summary>
readonly AgentParametersChannel m_Channel;
/// <summary>
/// Constructor.
/// </summary>
internal AgentParameters()
{
m_Channel = new AgentParametersChannel();
SideChannelManager.RegisterSideChannel(m_Channel);
}
/// <summary>
/// Returns the parameter value for the specified key. Returns the default value provided
/// if this parameter key does not have a value. Only returns a parameter value if it is
/// of type float.
/// </summary>
/// <param name="key">The parameter key</param>
/// <param name="defaultValue">Default value for this parameter.</param>
/// <returns></returns>
public float GetWithDefault(int episodeId, string key, float defaultValue)
{
return m_Channel.GetWithDefault(episodeId, key, defaultValue);
}
internal void Dispose()
{
SideChannelManager.UnregisterSideChannel(m_Channel);
}
}
}

11
com.unity.ml-agents/Runtime/AgentParameters.cs.meta


fileFormatVersion: 2
guid: c6d4c5ad59e7b4066b64fa47b5205889
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

65
com.unity.ml-agents/Runtime/SideChannels/AgentParametersChannel.cs


using System.Collections.Generic;
using System;
using UnityEngine;
namespace Unity.MLAgents.SideChannels
{
internal class AgentParametersChannel : SideChannel
{
Dictionary<int, Dictionary<string, float>> m_Parameters = new Dictionary<int, Dictionary<string, float>>();
const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860401";
/// <summary>
/// Initializes the side channel. The constructor is internal because only one instance is
/// supported at a time, and is created by the Academy.
/// </summary>
internal AgentParametersChannel()
{
ChannelId = new Guid(k_EnvParamsId);
}
/// <inheritdoc/>
protected override void OnMessageReceived(IncomingMessage msg)
{
var episodeId = msg.ReadInt32();
var key = msg.ReadString();
var value = msg.ReadFloat32();
if(!m_Parameters.ContainsKey(episodeId))
{
m_Parameters[episodeId] = new Dictionary<string, float>();
}
m_Parameters[episodeId][key] = value;
}
/// <summary>
/// Returns the parameter value associated with the provided key. Returns the default
/// value if one doesn't exist.
/// </summary>
/// <param name="key">Parameter key.</param>
/// <param name="defaultValue">Default value to return.</param>
/// <returns></returns>
public float GetWithDefault(int episodeId, string key, float defaultValue)
{
float value = defaultValue;
bool hasKey = false;
Dictionary<string, float> agent_dict;
if(m_Parameters.TryGetValue(episodeId, out agent_dict))
{
agent_dict.TryGetValue(key, out value);
}
return value;
}
/// <summary>
/// Returns all parameter keys that have a registered value.
/// </summary>
/// <returns></returns>
public IList<string> ListParameters(int episodeId)
{
Dictionary<string, float> agent_dict;
m_Parameters.TryGetValue(episodeId, out agent_dict);
return new List<string>(agent_dict.Keys);
}
}
}

11
com.unity.ml-agents/Runtime/SideChannels/AgentParametersChannel.cs.meta


fileFormatVersion: 2
guid: 22884d2b9466b4a589e059247a2f519f
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

34
ml-agents-envs/mlagents_envs/side_channel/agent_parameters_channel.py


from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.exception import UnityCommunicationException
from mlagents_envs.base_env import AgentId
import uuid
class AgentParametersChannel(SideChannel):
"""
This is the SideChannel for sending agent-specific parameters to Unity.
You can send parameters to an environment with the command
set_float_parameter.
"""
def __init__(self) -> None:
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860401"))
super().__init__(channel_id)
def on_message_received(self, msg: IncomingMessage) -> None:
raise UnityCommunicationException(
"The EnvironmentParametersChannel received a message from Unity, "
+ "this should not have happend."
)
def set_float_parameter(self, agent_id: AgentId, key: str, value: float) -> None:
"""
Sets a float environment parameter in the Unity Environment.
:param key: The string identifier of the parameter.
:param value: The float value of the parameter.
"""
msg = OutgoingMessage()
msg.write_int32(agent_id)
msg.write_string(key)
msg.write_float32(value)
super().queue_message_to_send(msg)
正在加载...
取消
保存