浏览代码

centralize tensorflow imports

/develop-gpu-test
Chris Elion 5 年前
当前提交
806c77e4
共有 24 个文件被更改,包括 54 次插入119 次删除
  1. 21
      ml-agents/mlagents/trainers/__init__.py
  2. 12
      ml-agents/mlagents/trainers/bc/models.py
  3. 5
      ml-agents/mlagents/trainers/components/bc/model.py
  4. 5
      ml-agents/mlagents/trainers/components/reward_signals/__init__.py
  5. 5
      ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py
  6. 5
      ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
  7. 5
      ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
  8. 5
      ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
  9. 8
      ml-agents/mlagents/trainers/models.py
  10. 5
      ml-agents/mlagents/trainers/ppo/models.py
  11. 5
      ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
  12. 5
      ml-agents/mlagents/trainers/ppo/policy.py
  13. 12
      ml-agents/mlagents/trainers/sac/models.py
  14. 6
      ml-agents/mlagents/trainers/sac/policy.py
  15. 5
      ml-agents/mlagents/trainers/tensorflow_to_barracuda.py
  16. 5
      ml-agents/mlagents/trainers/tests/test_bc.py
  17. 6
      ml-agents/mlagents/trainers/tests/test_multigpu.py
  18. 5
      ml-agents/mlagents/trainers/tests/test_ppo.py
  19. 5
      ml-agents/mlagents/trainers/tests/test_sac.py
  20. 5
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  21. 23
      ml-agents/mlagents/trainers/tf.py
  22. 5
      ml-agents/mlagents/trainers/tf_policy.py
  23. 5
      ml-agents/mlagents/trainers/trainer.py
  24. 5
      ml-agents/mlagents/trainers/trainer_controller.py

21
ml-agents/mlagents/trainers/__init__.py


import logging
def warnings_as_errors(log_record):
# Raise deprecated warnings as exceptions.
if log_record.levelno == logging.WARNING and "deprecated" in log_record.msg:
merged = log_record.getMessage()
raise RuntimeError(merged)
return True
# TODO only enable this with a environment variable
if False:
logging.getLogger('tensorflow').addFilter(warnings_as_errors)
# TODO better place to put this? move everything to tf.py?
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
tf.disable_v2_behavior()
from mlagents.trainers.tf import tf as tf

12
ml-agents/mlagents/trainers/bc/models.py


try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
if True: # TODO TF2
if True: # TODO TF2
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten

size,
activation=None,
use_bias=False,
kernel_initializer=tf_variance_scaling(
0.01
),
kernel_initializer=tf_variance_scaling(0.01),
)
)
self.action_probs = tf.concat(

5
ml-agents/mlagents/trainers/components/bc/model.py


try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.trainers.models import LearningModel

5
ml-agents/mlagents/trainers/components/reward_signals/__init__.py


import numpy as np
import abc
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.trainer import UnityTrainerException

5
ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py


from typing import List, Tuple
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.trainers.models import LearningModel

5
ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py


from typing import Any, Dict, List
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.envs.brain import BrainInfo

5
ml-agents/mlagents/trainers/components/reward_signals/gail/model.py


from typing import Tuple, List
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.trainers.models import LearningModel

5
ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py


from typing import Any, Dict, List
import logging
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult

8
ml-agents/mlagents/trainers/models.py


from typing import Callable, List
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
if True: # TODO TF2
if True: # TODO TF2
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten
tf_rnn = tf.contrib.rnn

5
ml-agents/mlagents/trainers/ppo/models.py


import logging
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule

5
ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py


import logging
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from tensorflow.python.client import device_lib
from mlagents.envs.timers import timed

5
ml-agents/mlagents/trainers/ppo/policy.py


import numpy as np
from typing import Any, Dict, Optional
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.envs.timers import timed
from mlagents.envs.brain import BrainInfo, BrainParameters

12
ml-agents/mlagents/trainers/sac/models.py


import logging
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
if True: # TODO TF2
if True: # TODO TF2
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten

size,
activation=None,
use_bias=False,
kernel_initializer=tf_variance_scaling(
0.01
),
kernel_initializer=tf_variance_scaling(0.01),
)
)
all_logits = tf.concat(

6
ml-agents/mlagents/trainers/sac/policy.py


import logging
from typing import Dict, Any, Optional
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.envs.timers import timed
from mlagents.envs.brain import BrainInfo, BrainParameters

5
ml-agents/mlagents/trainers/tensorflow_to_barracuda.py


from __future__ import print_function
import numpy as np
import struct # convert from Python values and C structs
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
import re
# import barracuda

5
ml-agents/mlagents/trainers/tests/test_bc.py


import os
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
import yaml
from mlagents.trainers.bc.models import BehavioralCloningModel

6
ml-agents/mlagents/trainers/tests/test_multigpu.py


import unittest.mock as mock
import pytest
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
import yaml
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy

5
ml-agents/mlagents/trainers/tests/test_ppo.py


import pytest
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
import yaml

5
ml-agents/mlagents/trainers/tests/test_sac.py


import yaml
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.trainers.sac.models import SACModel

5
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


from unittest.mock import MagicMock, Mock, patch
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
import yaml
import pytest

23
ml-agents/mlagents/trainers/tf.py


import logging
def warnings_as_errors(log_record):
# Raise deprecated warnings as exceptions.
if log_record.levelno == logging.WARNING and "deprecated" in log_record.msg:
merged = log_record.getMessage()
raise RuntimeError(merged)
return True
# TODO only enable this with a environment variable
if False:
logging.getLogger("tensorflow").addFilter(warnings_as_errors)
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
is_tf2 = True
if is_tf2:
tf.disable_v2_behavior()

5
ml-agents/mlagents/trainers/tf_policy.py


from typing import Any, Dict
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from mlagents.envs.exception import UnityException
from mlagents.envs.policy import Policy

5
ml-agents/mlagents/trainers/trainer.py


from typing import Dict, List, Deque, Any
import os
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
import numpy as np
from collections import deque, defaultdict

5
ml-agents/mlagents/trainers/trainer_controller.py


from typing import Dict, List, Optional
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from mlagents.trainers import tf
from time import time
from mlagents.envs.env_manager import EnvironmentStep

正在加载...
取消
保存