浏览代码

passing sampler configs to c#

/sampler-refactor-copy
Andrew Cohen 5 年前
当前提交
fe0a077e
共有 5 个文件被更改,包括 68 次插入2 次删除
  1. 6
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
  2. 27
      com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
  3. 15
      ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
  4. 17
      ml-agents/mlagents/trainers/learn.py
  5. 5
      ml-agents/mlagents/trainers/subprocess_env_manager.py

6
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
var samplerType = m_ResetParams.GetWithDefault("mass-sampler-type", -1.0f);
var min = m_ResetParams.GetWithDefault("mass-min", -1.0f);
var max = m_ResetParams.GetWithDefault("mass-max", -1.0f);
Debug.Log(samplerType);
Debug.Log(min);
Debug.Log(max);
SetResetParameters();
}

27
com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs


/// </summary>
internal enum EnvironmentDataTypes
{
    /// <summary>
    /// The message payload is a single float value.
    /// </summary>
    Float = 0,

    /// <summary>
    /// The message payload describes a sampler as three consecutive floats.
    /// </summary>
    Sampler = 1
}
/// <summary>

m_RegisteredActions.TryGetValue(key, out action);
action?.Invoke(value);
}
else if ((int)EnvironmentDataTypes.Sampler == type)
{
var samplerType = msg.ReadFloat32();
var statOne = msg.ReadFloat32();
var statTwo = msg.ReadFloat32();
m_Parameters[key+"-sampler-type"] = samplerType;
m_Parameters[key+"-min"] = statOne;
m_Parameters[key+"-max"] = statTwo;
}
else
{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");

/// <param name="defaultValue">Default value to return.</param>
/// <returns></returns>
public float GetWithDefault(string key, float defaultValue)
{
    // Single lookup: hand back the stored value when the key is present,
    // otherwise fall through to the caller-supplied default.
    float storedValue;
    if (m_Parameters.TryGetValue(key, out storedValue))
    {
        return storedValue;
    }

    return defaultValue;
}
/// <summary>
/// Returns the parameter value associated with the provided key. Returns the default
/// value if one doesn't exist.
/// </summary>
/// <param name="key">Parameter key.</param>
/// <param name="defaultValue">Default value to return.</param>
/// <returns></returns>
public float GetListWithDefault(string key, float defaultValue)
{
float valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);

15
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


from mlagents_envs.exception import UnityCommunicationException
import uuid
from enum import IntEnum
from typing import List
class EnvironmentParametersChannel(SideChannel):

class EnvironmentDataTypes(IntEnum):
    # Integer tags written right after the key of each outgoing message so the
    # receiver knows how to decode the rest of the payload.
    FLOAT = 0  # payload is a single float32 value
    SAMPLER = 1  # payload is the sequence of float32 values describing a sampler
def __init__(self) -> None:
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400"))

msg.write_int32(self.EnvironmentDataTypes.FLOAT)
msg.write_float32(value)
super().queue_message_to_send(msg)
def set_sampler_parameters(self, key: str, values: List[float]) -> None:
    """
    Sends the numeric description of a sampler for an environment parameter
    to the Unity Environment.
    The queued message contains the key, the SAMPLER data-type tag, and then
    every entry of ``values`` written as a float32, in list order.
    :param key: The string identifier of the parameter.
    :param values: The floats describing the sampler.
        NOTE(review): the receiver presumably expects exactly three entries
        (sampler type, min, max) -- confirm against the C# channel.
    """
    msg = OutgoingMessage()
    msg.write_string(key)
    msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
    # The list length is not encoded; the values are simply appended in order.
    for value in values:
        msg.write_float32(value)
    super().queue_message_to_send(msg)

17
ml-agents/mlagents/trainers/learn.py


maybe_meta_curriculum = try_create_meta_curriculum(
options.curriculum, env_manager, checkpoint_settings.lesson
)
maybe_add_samplers(options.parameter_randomization, env_manager)
trainer_factory = TrainerFactory(
options.behaviors,
checkpoint_settings.run_id,

logger.warning(
f"Unable to save to {timing_path}. Make sure the directory exists"
)
def maybe_add_samplers(sampler_config, env):
    """
    Flatten a parameter-randomization config into per-parameter float lists and
    pass them to the environment through a reset.

    Every entry except "resampling-interval" becomes
    ``[sampler_type, min_value, max_value]``, where ``sampler_type`` is 0.0 for
    a "uniform" sampler and 1.0 for anything else.

    :param sampler_config: Mapping of parameter name to its sampler settings,
        or None when no randomization is configured.
    :param env: Object exposing ``reset(config=...)``.
    """
    # NOTE(review): assumes the reset only happens when a config was supplied;
    # the flattened source does not show the original nesting -- confirm.
    if sampler_config is None:
        return
    restructured_sampler_config: Dict[str, List[float]] = {
        name: [
            0.0 if settings["sampler-type"] == "uniform" else 1.0,
            settings["min_value"],
            settings["max_value"],
        ]
        for name, settings in sampler_config.items()
        if name != "resampling-interval"
    }
    env.reset(config=restructured_sampler_config)
def create_sampler_manager(sampler_config, run_seed=None):

5
ml-agents/mlagents/trainers/subprocess_env_manager.py


_send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains())
elif req.cmd == EnvironmentCommand.RESET:
for k, v in req.payload.items():
env_parameters.set_float_parameter(k, v)
if isinstance(v, float):
env_parameters.set_float_parameter(k, v)
elif isinstance(v, list):
env_parameters.set_sampler_parameters(k, v)
env.reset()
all_step_result = _generate_all_results()
_send_response(EnvironmentCommand.RESET, all_step_result)

正在加载...
取消
保存