浏览代码

[skip-ci] small refactors

/bug-failed-api-check
Anupam Bhatnagar 4 年前
当前提交
07b15ae7
共有 5 个文件被更改,包括 22 次插入12 次删除
  1. 18
      ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
  2. 2
      ml-agents/mlagents/trainers/policy/tf_policy.py
  3. 5
      ml-agents/mlagents/trainers/ppo/optimizer.py
  4. 6
      ml-agents/mlagents/trainers/trainer/trainer.py
  5. 3
      ml-agents/mlagents/trainers/trainer_controller.py

18
ml-agents/mlagents/trainers/optimizer/tf_optimizer.py


from mlagents.trainers.components.bc.module import BCModule
try:
import horovod.tensorflow as hvd
except ImportError:
hvd = None
class TFOptimizer(Optimizer): # pylint: disable=W0223
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):
self.sess = policy.sess

def create_optimizer_op(
self, learning_rate: tf.Tensor, name: str = "Adam"
) -> tf.train.Optimizer:
return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name)
if hvd is not None:
adam_optimizer = tf.train.AdamOptimizer(
learning_rate=learning_rate * hvd.size(), name=name
)
horovod_optimizer = hvd.DistributedOptimizer(adam_optimizer)
else:
adam_optimizer = tf.train.AdamOptimizer(
learning_rate=learning_rate, name=name
)
return horovod_optimizer if hvd is not None else adam_optimizer
def _execute_model(
self, feed_dict: Dict[tf.Tensor, np.ndarray], out_dict: Dict[str, tf.Tensor]

2
ml-agents/mlagents/trainers/policy/tf_policy.py


self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
init = tf.global_variables_initializer()
self.sess.run(init)
self.sess.run(hvd.broadcast_global_variables(0))
def _load_graph(self):
with self.graph.as_default():

self._load_graph()
else:
self._initialize_graph()
self.sess.run(hvd.broadcast_global_variables(0))
def get_weights(self):
with self.graph.as_default():

5
ml-agents/mlagents/trainers/ppo/optimizer.py


from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
import horovod.tensorflow as hvd
class PPOOptimizer(TFOptimizer):

)
def _create_ppo_optimizer_ops(self):
self.tf_optimizer = self.create_optimizer_op(self.learning_rate * hvd.size())
if hvd is not None:
self.tf_optimizer = hvd.DistributedOptimizer(self.tf_optimizer)
self.tf_optimizer = self.create_optimizer_op(self.learning_rate)
self.grads = self.tf_optimizer.compute_gradients(self.loss)
self.update_batch = self.tf_optimizer.minimize(self.loss)

6
ml-agents/mlagents/trainers/trainer/trainer.py


from mlagents.trainers.policy import Policy
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.timers import hierarchical_timer
import horovod.tensorflow as hvd
logger = logging.getLogger("mlagents.trainers")

"""
Saves the model
"""
if hvd.rank() != 0:
return
self.get_policy(name_behavior_id).save_model(self.get_step)
def export_model(self, name_behavior_id: str) -> None:

if hvd.rank() != 0:
return
policy = self.get_policy(name_behavior_id)
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
export_policy_model(settings, policy.graph, policy.sess)

3
ml-agents/mlagents/trainers/trainer_controller.py


"""
Exports latest saved models to .nn format for Unity embedding.
"""
if hvd.rank() != 0:
return
for brain_name in self.trainers.keys():
for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].export_model(name_behavior_id)

正在加载...
取消
保存