GitHub
5 年前
当前提交
652488d9
共有 21 个文件被更改,包括 167 次插入 和 80 次删除
-
4.circleci/config.yml
-
4gym-unity/gym_unity/tests/test_gym.py
-
20ml-agents-envs/mlagents/envs/brain.py
-
6ml-agents-envs/mlagents/envs/tests/test_brain.py
-
7ml-agents/mlagents/trainers/bc/policy.py
-
10ml-agents/mlagents/trainers/buffer.py
-
5ml-agents/mlagents/trainers/components/bc/module.py
-
7ml-agents/mlagents/trainers/components/reward_signals/__init__.py
-
2ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
-
8ml-agents/mlagents/trainers/demo_loader.py
-
8ml-agents/mlagents/trainers/ppo/trainer.py
-
2ml-agents/mlagents/trainers/rl_trainer.py
-
49ml-agents/mlagents/trainers/tests/__init__.py
-
32ml-agents/mlagents/trainers/tests/mock_brain.py
-
12ml-agents/mlagents/trainers/tests/test_bc.py
-
6ml-agents/mlagents/trainers/tests/test_policy.py
-
33ml-agents/mlagents/trainers/tests/test_ppo.py
-
2ml-agents/mlagents/trainers/tests/test_reward_signals.py
-
8ml-agents/mlagents/trainers/tests/test_rl_trainer.py
-
18ml-agents/mlagents/trainers/tests/test_sac.py
-
4ml-agents/mlagents/trainers/tf_policy.py
|
|||
import os |
|||
|
|||
# Opt-in checking mode to ensure that we always create numpy arrays using float32 |
|||
if os.getenv("TEST_ENFORCE_NUMPY_FLOAT32"): |
|||
# This file is importer by pytest multiple times, but this breaks the patching |
|||
# Removing the env variable seems the easiest way to prevent this. |
|||
del os.environ["TEST_ENFORCE_NUMPY_FLOAT32"] |
|||
import numpy as np |
|||
import traceback |
|||
|
|||
__old_np_array = np.array |
|||
__old_np_zeros = np.zeros |
|||
__old_np_ones = np.ones |
|||
|
|||
def _check_no_float64(arr, kwargs_dtype): |
|||
if arr.dtype == np.float64: |
|||
tb = traceback.extract_stack() |
|||
# tb[-1] in the stack is this function. |
|||
# tb[-2] is the wrapper function, e.g. np_array_no_float64 |
|||
# we want the calling function, so use tb[-3] |
|||
filename = tb[-3].filename |
|||
# Only raise if this came from mlagents code, not tensorflow |
|||
if ( |
|||
"ml-agents/mlagents" in filename |
|||
or "ml-agents-envs/mlagents" in filename |
|||
) and "tensorflow_to_barracuda.py" not in filename: |
|||
raise ValueError( |
|||
f"float64 array created. Set dtype=np.float32 instead of current dtype={kwargs_dtype}. " |
|||
f"Run pytest with TEST_ENFORCE_NUMPY_FLOAT32=1 to confirm fix." |
|||
) |
|||
|
|||
def np_array_no_float64(*args, **kwargs): |
|||
res = __old_np_array(*args, **kwargs) |
|||
_check_no_float64(res, kwargs.get("dtype")) |
|||
return res |
|||
|
|||
def np_zeros_no_float64(*args, **kwargs): |
|||
res = __old_np_zeros(*args, **kwargs) |
|||
_check_no_float64(res, kwargs.get("dtype")) |
|||
return res |
|||
|
|||
def np_ones_no_float64(*args, **kwargs): |
|||
res = __old_np_ones(*args, **kwargs) |
|||
_check_no_float64(res, kwargs.get("dtype")) |
|||
return res |
|||
|
|||
np.array = np_array_no_float64 |
|||
np.zeros = np_zeros_no_float64 |
|||
np.ones = np_ones_no_float64 |
撰写
预览
正在加载...
取消
保存
Reference in new issue