GitHub
5 年前
当前提交
35c995e9
共有 130 个文件被更改,包括 1384 次插入 和 1423 次删除
-
4.circleci/config.yml
-
11CONTRIBUTING.md
-
3README.md
-
1UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
-
1UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
-
8UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
-
9UnitySDK/Assets/ML-Agents/Editor/Tests/TimerTest.cs
-
7UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAcademy.cs
-
8UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
-
8UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
-
2UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
-
6UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
-
5UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
-
12UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAcademy.cs
-
4UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
-
14UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
-
13UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
-
3UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAcademy.cs
-
9UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
-
4UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerAcademy.cs
-
2UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs
-
5UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAcademy.cs
-
8UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
-
6UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAcademy.cs
-
14UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
-
11UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
-
197UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
-
52UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInitializationOutput.cs
-
119UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlInput.cs
-
44UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/UnityRlOutput.cs
-
23UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
-
126UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
-
34UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
-
70UnitySDK/Assets/ML-Agents/Scripts/Timer.cs
-
1UnitySDK/UnitySDK.sln.DotSettings
-
38docs/Basic-Guide.md
-
23docs/Getting-Started-with-Balance-Ball.md
-
6docs/Installation-Windows.md
-
4docs/Installation.md
-
12docs/Learning-Environment-Design-Academy.md
-
6docs/Learning-Environment-Design.md
-
28docs/Learning-Environment-Examples.md
-
18docs/Learning-Environment-Executable.md
-
13docs/Migrating.md
-
81docs/Python-API.md
-
5docs/Training-Curriculum-Learning.md
-
4docs/Training-Generalized-Reinforcement-Learning-Agents.md
-
17docs/Training-ML-Agents.md
-
2docs/Training-on-Amazon-Web-Service.md
-
24docs/Using-Virtual-Environment.md
-
4gym-unity/gym_unity/tests/test_gym.py
-
14ml-agents-envs/mlagents/envs/base_unity_environment.py
-
21ml-agents-envs/mlagents/envs/brain.py
-
17ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_initialization_output_pb2.py
-
14ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_initialization_output_pb2.pyi
-
37ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.py
-
18ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_input_pb2.pyi
-
19ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.py
-
6ml-agents-envs/mlagents/envs/communicator_objects/unity_rl_output_pb2.pyi
-
11ml-agents-envs/mlagents/envs/env_manager.py
-
124ml-agents-envs/mlagents/envs/environment.py
-
29ml-agents-envs/mlagents/envs/simple_env_manager.py
-
71ml-agents-envs/mlagents/envs/subprocess_env_manager.py
-
6ml-agents-envs/mlagents/envs/tests/test_brain.py
-
35ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py
-
7ml-agents/mlagents/trainers/bc/policy.py
-
17ml-agents/mlagents/trainers/bc/trainer.py
-
470ml-agents/mlagents/trainers/buffer.py
-
15ml-agents/mlagents/trainers/components/bc/module.py
-
7ml-agents/mlagents/trainers/components/reward_signals/__init__.py
-
2ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
-
5ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
-
9ml-agents/mlagents/trainers/curriculum.py
-
100ml-agents/mlagents/trainers/demo_loader.py
-
70ml-agents/mlagents/trainers/learn.py
-
8ml-agents/mlagents/trainers/meta_curriculum.py
-
7ml-agents/mlagents/trainers/ppo/policy.py
-
61ml-agents/mlagents/trainers/ppo/trainer.py
-
44ml-agents/mlagents/trainers/rl_trainer.py
-
47ml-agents/mlagents/trainers/sac/trainer.py
-
49ml-agents/mlagents/trainers/tests/__init__.py
-
48ml-agents/mlagents/trainers/tests/mock_brain.py
-
12ml-agents/mlagents/trainers/tests/test_bc.py
-
71ml-agents/mlagents/trainers/tests/test_buffer.py
-
8ml-agents/mlagents/trainers/tests/test_curriculum.py
-
4ml-agents/mlagents/trainers/tests/test_demo_loader.py
-
4ml-agents/mlagents/trainers/tests/test_learn.py
-
9ml-agents/mlagents/trainers/tests/test_meta_curriculum.py
-
6ml-agents/mlagents/trainers/tests/test_policy.py
-
51ml-agents/mlagents/trainers/tests/test_ppo.py
-
4ml-agents/mlagents/trainers/tests/test_reward_signals.py
-
22ml-agents/mlagents/trainers/tests/test_rl_trainer.py
-
63ml-agents/mlagents/trainers/tests/test_sac.py
-
4ml-agents/mlagents/trainers/tests/test_simple_rl.py
-
2ml-agents/mlagents/trainers/tests/test_trainer_controller.py
-
16ml-agents/mlagents/trainers/tf_policy.py
-
4ml-agents/mlagents/trainers/trainer_controller.py
-
3protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_initialization_output.proto
-
6protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_input.proto
-
1protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_output.proto
|
|||
import os |
|||
|
|||
# Opt-in checking mode to ensure that we always create numpy arrays using float32 |
|||
if os.getenv("TEST_ENFORCE_NUMPY_FLOAT32"): |
|||
# This file is importer by pytest multiple times, but this breaks the patching |
|||
# Removing the env variable seems the easiest way to prevent this. |
|||
del os.environ["TEST_ENFORCE_NUMPY_FLOAT32"] |
|||
import numpy as np |
|||
import traceback |
|||
|
|||
__old_np_array = np.array |
|||
__old_np_zeros = np.zeros |
|||
__old_np_ones = np.ones |
|||
|
|||
def _check_no_float64(arr, kwargs_dtype): |
|||
if arr.dtype == np.float64: |
|||
tb = traceback.extract_stack() |
|||
# tb[-1] in the stack is this function. |
|||
# tb[-2] is the wrapper function, e.g. np_array_no_float64 |
|||
# we want the calling function, so use tb[-3] |
|||
filename = tb[-3].filename |
|||
# Only raise if this came from mlagents code, not tensorflow |
|||
if ( |
|||
"ml-agents/mlagents" in filename |
|||
or "ml-agents-envs/mlagents" in filename |
|||
) and "tensorflow_to_barracuda.py" not in filename: |
|||
raise ValueError( |
|||
f"float64 array created. Set dtype=np.float32 instead of current dtype={kwargs_dtype}. " |
|||
f"Run pytest with TEST_ENFORCE_NUMPY_FLOAT32=1 to confirm fix." |
|||
) |
|||
|
|||
def np_array_no_float64(*args, **kwargs): |
|||
res = __old_np_array(*args, **kwargs) |
|||
_check_no_float64(res, kwargs.get("dtype")) |
|||
return res |
|||
|
|||
def np_zeros_no_float64(*args, **kwargs): |
|||
res = __old_np_zeros(*args, **kwargs) |
|||
_check_no_float64(res, kwargs.get("dtype")) |
|||
return res |
|||
|
|||
def np_ones_no_float64(*args, **kwargs): |
|||
res = __old_np_ones(*args, **kwargs) |
|||
_check_no_float64(res, kwargs.get("dtype")) |
|||
return res |
|||
|
|||
np.array = np_array_no_float64 |
|||
np.zeros = np_zeros_no_float64 |
|||
np.ones = np_ones_no_float64 |
部分文件因为文件数量过多而无法显示
撰写
预览
正在加载...
取消
保存
Reference in new issue