import os # Opt-in checking mode to ensure that we always create numpy arrays using float32 if os.getenv("TEST_ENFORCE_NUMPY_FLOAT32"): # This file is importer by pytest multiple times, but this breaks the patching # Removing the env variable seems the easiest way to prevent this. del os.environ["TEST_ENFORCE_NUMPY_FLOAT32"] import numpy as np import traceback __old_np_array = np.array __old_np_zeros = np.zeros __old_np_ones = np.ones def _check_no_float64(arr, kwargs_dtype): if arr.dtype == np.float64: tb = traceback.extract_stack() # tb[-1] in the stack is this function. # tb[-2] is the wrapper function, e.g. np_array_no_float64 # we want the calling function, so use tb[-3] filename = tb[-3].filename # Only raise if this came from mlagents code if ( "ml-agents/mlagents" in filename or "ml-agents-envs/mlagents" in filename ) and "tensorflow_to_barracuda.py" not in filename: raise ValueError( f"float64 array created. Set dtype=np.float32 instead of current dtype={kwargs_dtype}. " f"Run pytest with TEST_ENFORCE_NUMPY_FLOAT32=1 to confirm fix." ) def np_array_no_float64(*args, **kwargs): res = __old_np_array(*args, **kwargs) _check_no_float64(res, kwargs.get("dtype")) return res def np_zeros_no_float64(*args, **kwargs): res = __old_np_zeros(*args, **kwargs) _check_no_float64(res, kwargs.get("dtype")) return res def np_ones_no_float64(*args, **kwargs): res = __old_np_ones(*args, **kwargs) _check_no_float64(res, kwargs.get("dtype")) return res np.array = np_array_no_float64 np.zeros = np_zeros_no_float64 np.ones = np_ones_no_float64