Andrew Cohen
5 年前
当前提交
a7a372b9
共有 18 个文件被更改,包括 333 次插入 和 22 次删除
-
12Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs
-
3com.unity.ml-agents/CHANGELOG.md
-
28com.unity.ml-agents/Runtime/Academy.cs
-
16com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs
-
28com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
-
7com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs
-
4docs/Getting-Started.md
-
7docs/Using-Tensorboard.md
-
18ml-agents-envs/mlagents_envs/environment.py
-
21ml-agents/mlagents/trainers/agent_processor.py
-
10ml-agents/mlagents/trainers/env_manager.py
-
6ml-agents/mlagents/trainers/simple_env_manager.py
-
19ml-agents/mlagents/trainers/subprocess_env_manager.py
-
34ml-agents/mlagents/trainers/tests/test_agent_processor.py
-
11ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
-
72com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs
-
11com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta
-
48ml-agents-envs/mlagents_envs/side_channel/stats_side_channel.py
|
|||
using System; |
|||
namespace MLAgents.SideChannels |
|||
{ |
|||
/// <summary>
|
|||
/// Determines the behavior of how multiple stats within the same summary period are combined.
|
|||
/// </summary>
|
|||
public enum StatAggregationMethod |
|||
{ |
|||
/// <summary>
|
|||
/// Values within the summary period are averaged before reporting.
|
|||
/// Note that values from the same C# environment in the same step may replace each other.
|
|||
/// </summary>
|
|||
Average = 0, |
|||
|
|||
/// <summary>
|
|||
/// Only the most recent value is reported.
|
|||
/// To avoid conflicts between multiple environments, the ML Agents environment will only
|
|||
/// keep stats from worker index 0.
|
|||
/// </summary>
|
|||
MostRecent = 1 |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add stats (key-value pairs) for reporting. The ML Agents environment will send these to a StatsReporter
|
|||
/// instance, which means the values will appear in the Tensorboard summary, as well as trainer gauges.
|
|||
/// Note that stats are only written every summary_frequency steps; See <see cref="StatAggregationMethod"/>
|
|||
/// for options on how multiple values are handled.
|
|||
/// </summary>
|
|||
public class StatsSideChannel : SideChannel |
|||
{ |
|||
const string k_StatsSideChannelDefaultId = "a1d8f7b7-cec8-50f9-b78b-d3e165a78520"; |
|||
|
|||
/// <summary>
|
|||
/// Initializes the side channel with the provided channel ID.
|
|||
/// The constructor is internal because only one instance is
|
|||
/// supported at a time, and is created by the Academy.
|
|||
/// </summary>
|
|||
internal StatsSideChannel() |
|||
{ |
|||
ChannelId = new Guid(k_StatsSideChannelDefaultId); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add a stat value for reporting. This will appear in the Tensorboard summary and trainer gauges.
|
|||
/// You can nest stats in Tensorboard with "/".
|
|||
/// Note that stats are only written to Tensorboard each summary_frequency steps; if a stat is
|
|||
/// received multiple times, only the most recent version is used.
|
|||
/// To avoid conflicts between multiple environments, only stats from worker index 0 are used.
|
|||
/// </summary>
|
|||
/// <param name="key">The stat name.</param>
|
|||
/// <param name="value">The stat value. You can nest stats in Tensorboard by using "/". </param>
|
|||
/// <param name="aggregationMethod">How multiple values should be treated.</param>
|
|||
public void AddStat( |
|||
string key, float value, StatAggregationMethod aggregationMethod = StatAggregationMethod.Average |
|||
) |
|||
{ |
|||
using (var msg = new OutgoingMessage()) |
|||
{ |
|||
msg.WriteString(key); |
|||
msg.WriteFloat32(value); |
|||
msg.WriteInt32((int)aggregationMethod); |
|||
QueueMessageToSend(msg); |
|||
} |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override void OnMessageReceived(IncomingMessage msg) |
|||
{ |
|||
throw new UnityAgentsException("StatsSideChannel should never receive messages."); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 83a07fdb9e8f04536908a51447dfe548 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
from mlagents_envs.side_channel import SideChannel, IncomingMessage |
|||
import uuid |
|||
from typing import Dict, Tuple |
|||
from enum import Enum |
|||
|
|||
|
|||
# Determines the behavior of how multiple stats within the same summary period are combined. |
|||
class StatsAggregationMethod(Enum): |
|||
# Values within the summary period are averaged before reporting. |
|||
AVERAGE = 0 |
|||
|
|||
# Only the most recent value is reported. |
|||
MOST_RECENT = 1 |
|||
|
|||
|
|||
class StatsSideChannel(SideChannel): |
|||
""" |
|||
Side channel that receives (string, float) pairs from the environment, so that they can eventually |
|||
be passed to a StatsReporter. |
|||
""" |
|||
|
|||
def __init__(self) -> None: |
|||
# >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/StatsSideChannel") |
|||
# UUID('a1d8f7b7-cec8-50f9-b78b-d3e165a78520') |
|||
super().__init__(uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520")) |
|||
|
|||
self.stats: Dict[str, Tuple[float, StatsAggregationMethod]] = {} |
|||
|
|||
def on_message_received(self, msg: IncomingMessage) -> None: |
|||
""" |
|||
Receive the message from the environment, and save it for later retrieval. |
|||
:param msg: |
|||
:return: |
|||
""" |
|||
key = msg.read_string() |
|||
val = msg.read_float32() |
|||
agg_type = StatsAggregationMethod(msg.read_int32()) |
|||
|
|||
self.stats[key] = (val, agg_type) |
|||
|
|||
def get_and_reset_stats(self) -> Dict[str, Tuple[float, StatsAggregationMethod]]: |
|||
""" |
|||
Returns the current stats, and resets the internal storage of the stats. |
|||
:return: |
|||
""" |
|||
s = self.stats |
|||
self.stats = {} |
|||
return s |
撰写
预览
正在加载...
取消
保存
Reference in new issue