您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
73 行
2.5 KiB
73 行
2.5 KiB
from typing import List
|
|
import logging
|
|
import numpy as np
|
|
from unittest import mock
|
|
|
|
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
|
|
from mlagents.envs.communicator_objects.observation_pb2 import (
|
|
ObservationProto,
|
|
NONE as COMPRESSION_TYPE_NONE,
|
|
)
|
|
from mlagents.envs.brain import BrainInfo, BrainParameters
|
|
|
|
# Brain configuration shared by every test in this module.  The vector
# observation size (3) matches the three-element observation lists that the
# tests feed through _make_agent_info_proto; cameras and action metadata are
# left empty.
test_brain = BrainParameters(
    brain_name="test_brain",
    vector_observation_space_size=3,
    vector_action_space_type=1,
    vector_action_space_size=[],
    vector_action_descriptions=[],
    camera_resolutions=[],
)
|
|
|
|
|
|
def _make_agent_info_proto(vector_obs: List[float]) -> AgentInfoProto:
    """Wrap a flat list of floats in an AgentInfoProto with one observation.

    The values are stored as uncompressed float data, and the observation's
    shape is the one-dimensional length of ``vector_obs``.
    """
    float_payload = ObservationProto.FloatData(data=vector_obs)
    observation = ObservationProto(
        compression_type=COMPRESSION_TYPE_NONE,
        shape=[len(vector_obs)],
        float_data=float_payload,
    )
    return AgentInfoProto(observations=[observation])
|
|
|
|
|
|
@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_nan(mock_warning, mock_nan_to_num):
    """
    Check that a NaN observation comes back as 0.0, and that both
    np.nan_to_num and Logger.warning were invoked along the way.
    """
    protos = [_make_agent_info_proto([1.0, 2.0, float("nan")])]
    info = BrainInfo.from_agent_proto(1, protos, test_brain)

    # The NaN in the last slot is expected to be replaced by 0.0.
    matches = info.vector_observations == [1.0, 2.0, 0.0]
    assert matches.all()

    mock_nan_to_num.assert_called()
    mock_warning.assert_called()
|
|
|
|
|
|
@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_inf(mock_warning, mock_nan_to_num):
    """
    Check that an infinite observation is clamped to the largest float32
    value via np.nan_to_num, without any warning being logged.
    """
    protos = [_make_agent_info_proto([1.0, float("inf"), 0.0])]
    info = BrainInfo.from_agent_proto(1, protos, test_brain)

    # +inf is expected to be replaced by the float32 maximum.
    largest_float32 = np.finfo(np.float32).max
    matches = info.vector_observations == [1.0, largest_float32, 0.0]
    assert matches.all()

    mock_nan_to_num.assert_called()
    # Only NaN triggers a warning; inf is sanitised silently.
    mock_warning.assert_not_called()
|
|
|
|
|
|
@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_fast_path(mock_warning, mock_nan_to_num):
    """
    Check that when every value is finite the observations pass through
    unchanged, and neither nan_to_num nor Logger.warning is called.
    """
    finite_values = [1.0, 2.0, 3.0]
    protos = [_make_agent_info_proto([1.0, 2.0, 3.0])]
    info = BrainInfo.from_agent_proto(1, protos, test_brain)

    matches = info.vector_observations == finite_values
    assert matches.all()

    mock_nan_to_num.assert_not_called()
    mock_warning.assert_not_called()
|