您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
59 行
1.7 KiB
59 行
1.7 KiB
import uuid
|
|
from typing import Tuple, List, Mapping
|
|
from enum import Enum
|
|
from collections import defaultdict
|
|
|
|
from mlagents_envs.side_channel import SideChannel, IncomingMessage
|
|
|
|
|
|
# Determines the behavior of how multiple stats within the same summary period are combined.
|
|
class StatsAggregationMethod(Enum):
|
|
# Values within the summary period are averaged before reporting.
|
|
AVERAGE = 0
|
|
|
|
# Only the most recent value is reported.
|
|
MOST_RECENT = 1
|
|
|
|
# Values within the summary period are summed up before reporting.
|
|
SUM = 2
|
|
|
|
|
|
StatList = List[Tuple[float, StatsAggregationMethod]]
|
|
EnvironmentStats = Mapping[str, StatList]
|
|
|
|
|
|
class StatsSideChannel(SideChannel):
|
|
"""
|
|
Side channel that receives (string, float) pairs from the environment, so that they can eventually
|
|
be passed to a StatsReporter.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
# >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/StatsSideChannel")
|
|
# UUID('a1d8f7b7-cec8-50f9-b78b-d3e165a78520')
|
|
super().__init__(uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520"))
|
|
|
|
self.stats: EnvironmentStats = defaultdict(list)
|
|
|
|
def on_message_received(self, msg: IncomingMessage) -> None:
|
|
"""
|
|
Receive the message from the environment, and save it for later retrieval.
|
|
|
|
:param msg:
|
|
:return:
|
|
"""
|
|
key = msg.read_string()
|
|
val = msg.read_float32()
|
|
agg_type = StatsAggregationMethod(msg.read_int32())
|
|
|
|
self.stats[key].append((val, agg_type))
|
|
|
|
def get_and_reset_stats(self) -> EnvironmentStats:
|
|
"""
|
|
Returns the current stats, and resets the internal storage of the stats.
|
|
|
|
:return:
|
|
"""
|
|
s = self.stats
|
|
self.stats = defaultdict(list)
|
|
return s
|