浏览代码

Fix tests

/reward-dist
Arthur Juliani 4 年前
当前提交
372c784c
共有 3 个文件被更改,包括 12 次插入10 次删除
  1. 14
      ml-agents/mlagents/trainers/stats.py
  2. 3
      ml-agents/mlagents/trainers/tests/test_agent_processor.py
  3. 5
      ml-agents/mlagents/trainers/tests/test_stats.py

14
ml-agents/mlagents/trainers/stats.py


std: float
num: int
sum: float
full_dist: np.array
full_dist: List[float]
return StatsSummary(
0.0, 0.0, 0, 0.0, np.zeros(1), StatsAggregationMethod.AVERAGE
)
return StatsSummary(0.0, 0.0, 0, 0.0, [0.0], StatsAggregationMethod.AVERAGE)
@property
def aggregated_value(self):

)
if key == "Environment/Cumulative Reward":
self.summary_writers[category].add_histogram(
f"{key}_hist", value.full_dist, step
f"{key}_hist", np.array(value.full_dist), step
)
self.summary_writers[category].flush()

if len(stat_values) == 0:
return StatsSummary.empty()
if key == "Environment/Cumulative Reward":
full = np.array(stat_values)
else:
full = np.zeros(1)
full_dist=full,
full_dist=stat_values,
aggregation_method=StatsReporter.stats_aggregation[self.category][key],
)

3
ml-agents/mlagents/trainers/tests/test_agent_processor.py


std=mock.ANY,
num=2,
sum=4.0,
full_dist=mock.ANY,
aggregation_method=StatsAggregationMethod.AVERAGE,
),
"most_recent": StatsSummary(

sum=4.0,
full_dist=mock.ANY,
aggregation_method=StatsAggregationMethod.MOST_RECENT,
),
"summed": StatsSummary(

sum=4.2,
full_dist=mock.ANY,
aggregation_method=StatsAggregationMethod.SUM,
),
}

5
ml-agents/mlagents/trainers/tests/test_stats.py


std=1.0,
num=1,
sum=1.0,
full_dist=[0.0],
aggregation_method=StatsAggregationMethod.AVERAGE,
)
tb_writer.write_stats("category1", {"key1": statssummary1}, 10)

std=1.0,
num=1,
sum=1.0,
full_dist=[0.0],
aggregation_method=StatsAggregationMethod.AVERAGE,
)
tb_writer.write_stats("category1", {"key1": statssummary1}, 10)

std=1.0,
num=1,
sum=1.0,
full_dist=[1.0],
aggregation_method=StatsAggregationMethod.AVERAGE,
)
console_writer.write_stats(

std=0.0,
num=1,
sum=0.0,
full_dist=[0.0],
aggregation_method=StatsAggregationMethod.AVERAGE,
)
console_writer.write_stats(

std=1.0,
num=1,
sum=1.0,
full_dist=[1.0],
aggregation_method=StatsAggregationMethod.AVERAGE,
)
console_writer.write_stats(

正在加载...
取消
保存