
Speed up processing large vector observations (#2717)

* faster NaN checks for large float obs

* add unit test, remove is_nan check on obs

* add comment

* rename var
/develop-gpu-test
GitHub · 5 years ago
Commit af3ed33f
2 files changed, 88 insertions(+), 7 deletions(-)
1. ml-agents-envs/mlagents/envs/brain.py (33 changes)
2. ml-agents-envs/mlagents/envs/tests/test_brain.py (62 changes)
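
The commit message's claim of faster NaN checks for large observations can be illustrated with an informal timing sketch; dot_check below and the array size are illustrative choices, not code from this commit:

import timeit

import numpy as np

# Stand-in for a large stacked vector observation; the size is arbitrary.
obs = np.random.default_rng(0).random(100_000)

def dot_check(a):
    # Single reduction: any NaN or Inf in `a` makes the self dot product NaN or Inf.
    d = np.dot(a, a)
    return bool(np.isnan(d) or not np.isfinite(d))

# Approach removed by this commit: scan every element for NaN.
t_scan = timeit.timeit(lambda: np.isnan(obs).any(), number=1000)
# Approach added by this commit: one dot product plus two scalar checks.
t_dot = timeit.timeit(lambda: dot_check(obs), number=1000)
print(f"isnan().any(): {t_scan:.4f}s   dot check: {t_dot:.4f}s")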

ml-agents-envs/mlagents/envs/brain.py (33 changes)


         logger.warning(
             "An agent had a NaN reward for brain " + brain_params.brain_name
         )
-        if any(np.isnan(x.stacked_vector_observation).any() for x in agent_info_list):
-            logger.warning(
-                "An agent had a NaN observation for brain " + brain_params.brain_name
-            )
         if len(agent_info_list) == 0:
             vector_obs = np.zeros(
                 (
                     0,
                     brain_params.vector_observation_space_size
                     * brain_params.num_stacked_vector_observations,
                 )
             )
         else:
-            vector_obs = np.nan_to_num(
-                np.array([x.stacked_vector_observation for x in agent_info_list])
-            )
+            stacked_obs = []
+            has_nan = False
+            has_inf = False
+            for x in agent_info_list:
+                np_obs = np.array(x.stacked_vector_observation)
+                # Check for NaNs or Infs in the observations.
+                # If there's a NaN in the observations, the dot() result will be NaN.
+                # If there's an Inf (either sign), the result will be Inf.
+                # See https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy for background.
+                # Note that very large values (larger than sqrt(float_max)) will result in an Inf value here.
+                # This is OK though; worst case it results in an unnecessary (but harmless) nan_to_num call.
+                d = np.dot(np_obs, np_obs)
+                has_nan = has_nan or np.isnan(d)
+                has_inf = has_inf or not np.isfinite(d)
+                stacked_obs.append(np_obs)
+            vector_obs = np.array(stacked_obs)
+            # If we have any NaNs or Infs, use np.nan_to_num to replace them with finite values.
+            if has_nan or has_inf:
+                vector_obs = np.nan_to_num(vector_obs)
+            if has_nan:
+                logger.warning(
+                    f"An agent had a NaN observation for brain {brain_params.brain_name}"
+                )
         agents = [f"${worker_id}-{x.id}" for x in agent_info_list]
         brain_info = BrainInfo(
             visual_observation=vis_obs,
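
To make the comment above concrete, here is the same dot-product check pulled out as a hypothetical standalone helper (check is not a function in brain.py); it shows which inputs set each flag, including the harmless false positive for very large finite values:

import numpy as np

def check(obs):
    # Hypothetical standalone version of the check in from_agent_proto.
    # Every squared term is non-negative, so an Inf cannot cancel out;
    # a NaN anywhere makes d NaN, an Inf makes d Inf.
    d = np.dot(obs, obs)
    return bool(np.isnan(d)), bool(not np.isfinite(d))  # (has_nan, has_inf)

print(check(np.array([1.0, float("nan"), 3.0])))  # (True, True): NaN is also non-finite
print(check(np.array([1.0, float("inf"), 3.0])))  # (False, True)
print(check(np.array([1.0, 2.0, 3.0])))           # (False, False): fast path, nan_to_num skipped
# A finite value above sqrt(float_max) (~1.34e154) overflows when squared,
# so has_inf is a false positive; the only cost is an unneeded nan_to_num call.
print(check(np.array([1e200, 2.0, 3.0])))         # (False, True)

Either flag triggers the same nan_to_num cleanup, so a NaN also setting has_inf does not change behavior.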

ml-agents-envs/mlagents/envs/tests/test_brain.py (62 changes)


import logging
import numpy as np
import sys
from unittest import mock

from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents.envs.brain import BrainInfo, BrainParameters

test_brain = BrainParameters(
    brain_name="test_brain",
    vector_observation_space_size=3,
    num_stacked_vector_observations=1,
    camera_resolutions=[],
    vector_action_space_size=[],
    vector_action_descriptions=[],
    vector_action_space_type=1,
)


@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_nan(mock_warning, mock_nan_to_num):
    agent_info_proto = AgentInfoProto()
    agent_info_proto.stacked_vector_observation.extend([1.0, 2.0, float("nan")])
    brain_info = BrainInfo.from_agent_proto(1, [agent_info_proto], test_brain)
    # NaN gets replaced with 0.0
    expected = [1.0, 2.0, 0.0]
    assert (brain_info.vector_observations == expected).all()
    mock_nan_to_num.assert_called()
    mock_warning.assert_called()


@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_inf(mock_warning, mock_nan_to_num):
    agent_info_proto = AgentInfoProto()
    agent_info_proto.stacked_vector_observation.extend([1.0, float("inf"), 0.0])
    brain_info = BrainInfo.from_agent_proto(1, [agent_info_proto], test_brain)
    # Inf gets replaced with float_max
    expected = [1.0, sys.float_info.max, 0.0]
    assert (brain_info.vector_observations == expected).all()
    mock_nan_to_num.assert_called()
    # We don't warn on Inf, just NaN
    mock_warning.assert_not_called()


@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_fast_path(mock_warning, mock_nan_to_num):
    """
    Check that all-finite values skip the nan_to_num call.
    """
    agent_info_proto = AgentInfoProto()
    agent_info_proto.stacked_vector_observation.extend([1.0, 2.0, 3.0])
    brain_info = BrainInfo.from_agent_proto(1, [agent_info_proto], test_brain)
    expected = [1.0, 2.0, 3.0]
    assert (brain_info.vector_observations == expected).all()
    mock_nan_to_num.assert_not_called()
    mock_warning.assert_not_called()
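
The expected values in these tests follow from NumPy's default nan_to_num replacements, which can be confirmed directly:

import sys

import numpy as np

a = np.array([float("nan"), float("inf"), -float("inf"), 1.0])
print(np.nan_to_num(a))
# NaN -> 0.0, +Inf -> sys.float_info.max, -Inf -> -sys.float_info.max; finite values unchanged.
assert np.nan_to_num(a)[1] == sys.float_info.max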