from typing import Any, Dict, List, Optional

import tensorflow as tf
from tensorflow.python.client import device_lib

from mlagents.trainers.brain import BrainParameters
from mlagents_envs.timers import timed
from mlagents.trainers.models import EncoderType, LearningRateSchedule
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
    create_reward_signal,
)
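
# device_lib above is the usual hook for enumerating the GPUs that the per-device
# towers below are placed on. A minimal sketch of such a helper is given here for
# context; the helper name (get_devices), the CPU fallback, and its use to populate
# self.devices are assumptions, not code shown in this excerpt.
def get_devices() -> List[str]:
    """Return the local GPU device strings, or the CPU if no GPU is visible."""
    local_devices = device_lib.list_local_devices()
    gpus = [d.name for d in local_devices if d.device_type == "GPU"]
    return gpus if gpus else ["/cpu:0"]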


class MultiGpuPPOPolicy(PPOPolicy):
    def __init__(
        self,
        seed: int,
        brain: BrainParameters,
        trainer_params: Dict[str, Any],
        is_training: bool,
        load: bool,
    ):
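        # Data-parallel layout: one PPO model ("tower") is created per device, each
        # paired with its own dict of reward signals; the update step later feeds
        # every tower its own slice of the mini-batch (see the zip loop below).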
        self.towers: List[PPOModel] = []
        self.model: Optional[PPOModel] = None
        self.total_policy_loss: Optional[tf.Tensor] = None
        self.reward_signal_towers: List[Dict[str, RewardSignal]] = []
        self.reward_signals: Dict[str, RewardSignal] = {}

        # ...

        # Build one PPO model per device; all towers share the hyperparameters in
        # trainer_params and the reward signal configuration.
        for device in self.devices:
            with tf.device(device):
                self.towers.append(
                    PPOModel(
                        lr=float(trainer_params["learning_rate"]),
                        lr_schedule=LearningRateSchedule(
                            trainer_params.get(
                                "learning_rate_schedule", "linear"
                            )
                        ),
                        h_size=int(trainer_params["hidden_units"]),
                        epsilon=float(trainer_params["epsilon"]),
                        beta=float(trainer_params["beta"]),
                        max_step=float(trainer_params["max_steps"]),
                        normalize=trainer_params["normalize"],
                        use_recurrent=trainer_params["use_recurrent"],
                        num_layers=int(trainer_params["num_layers"]),
                        m_size=self.m_size,
                        seed=seed,
                        stream_names=list(reward_signal_configs.keys()),
                        vis_encode_type=EncoderType(
                            trainer_params.get("vis_encode_type", "simple")
                        ),
                    )
                )
                # Every tower gets its own PPO loss and optimizer ops.
                self.towers[-1].create_ppo_optimizer()
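
# With each tower owning its optimizer ops, multi-tower training typically averages
# the per-tower gradients variable-by-variable before a single apply step. This is a
# hedged sketch of that standard pattern; the helper name and its call site are
# assumptions, not code shown in this excerpt.
def average_gradients(tower_grads):
    """Average (gradient, variable) pairs variable-wise across towers."""
    averaged = []
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars if g is not None]
        if not grads:
            continue
        # All towers share the variable, so take the mean gradient across towers.
        mean_grad = tf.reduce_mean(tf.stack(grads), axis=0)
        _, var = grad_and_vars[0]
        averaged.append((mean_grad, var))
    return averaged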
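
# The update step below consumes device_batches, one slice per tower. A hedged
# sketch of how such a split is commonly built from the full mini-batch follows;
# the helper name and the interleaved slicing scheme are assumptions, not code
# shown in this excerpt.
def split_batch_by_device(mini_batch, devices):
    """Give each device an interleaved slice of every buffer in the mini-batch."""
    device_batches = []
    for i, _ in enumerate(devices):
        device_batches.append(
            {name: values[i :: len(devices)] for name, values in mini_batch.items()}
        )
    return device_batches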

        # ...

        for batch, tower, reward_tower in zip(
            device_batches, self.towers, self.reward_signal_towers
        ):
            # Accumulate a single feed dict that covers every tower's batch slice.
            feed_dict.update(self.construct_feed_dict(tower, batch, num_sequences))
            # TODO: Fix multi-GPU optimizer
            stats_needed.update(self.stats_name_to_update_name)
            for _, reward_signal in reward_tower.items():
                feed_dict.update(