from mlagents.tf_utils import tf  # tf ops (e.g. control_dependencies) are used below

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.sac_transfer.network import (
    SACTransferPolicyNetwork,
    SACTransferTargetNetwork,
)
from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork
from mlagents.trainers.models import ModelUtils, ScheduleType
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy.transfer_policy import TransferPolicy
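# Transfer-specific SAC optimizer: alongside the usual SAC policy/value updates
# it tracks a model loss (and optionally a reward-prediction loss) coming from
# the TransferPolicy, and can fold these into the value objective (see below).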
|
        # Which components are optimized during updates (all read from the
        # trainer hyperparameters): the model, the policy, and the value function.
        self.train_model = hyperparameters.train_model
        self.train_policy = hyperparameters.train_policy
        self.train_value = hyperparameters.train_value
        # Fixed weight applied to the model loss when it is added to the value
        # loss (see the value_loss computation further down).
        self.model_weight = hyperparameters.model_weight
|
        # Transfer
        self.use_transfer = hyperparameters.use_transfer

            int(max_step),
            min_value=1e-10,
        )
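        # ModelUtils.create_schedule returns a TF tensor for the scheduled
        # value; with a LINEAR schedule it decays from the configured value
        # toward min_value as the policy's global_step approaches max_step.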
        # Annealed model-weight schedule, kept for reference; the fixed
        # hyperparameter value read above is used instead.
        # self.model_weight = ModelUtils.create_schedule(
        #     ScheduleType.LINEAR,
        #     hyperparameters.model_weight,
        #     self.policy.global_step,
        #     int(max_step),
        #     min_value=1e-10,
        # )
        self._create_losses(
            self.policy_network.q1_heads,
            self.policy_network.q2_heads,

            "Policy/Entropy Coeff": "entropy_coef",
            "Policy/Learning Rate": "learning_rate",
            "Policy/Model Learning Rate": "model_learning_rate",
            # "Policy/Model Weight": "model_weight",
        }
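        # The entries above map TensorBoard stat names to keys of the update
        # fetch dictionary, mirroring the loss/learning-rate stats reported by
        # the base SAC optimizer.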
|
        if self.predict_return:
            # Also fetch the model-related tensors each update when the
            # transferred model predicts returns (the update_dict name is
            # assumed from the base SAC optimizer).
            self.update_dict.update({
                "model_loss": self.model_loss,
                "model_learning_rate": self.model_learning_rate,
                "reward_loss": self.policy.reward_loss,
                # "model_weight": self.model_weight,
            })
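        # These tensors are session-run together with the losses during
        # update() and surface as training stats through the stat-name mapping
        # above.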
|
def _create_inputs_and_outputs(self) -> None: |
|
        # Make sure policy is updated first, then value, then entropy.
        if self.use_transfer:
            # Fold the model loss into the value objective, scaled by the model
            # learning rate; an alternative fixed weighting via model_weight is
            # left commented out.
            # value_loss = self.total_value_loss + self.model_weight * self.model_loss
            value_loss = (
                self.total_value_loss + self.model_learning_rate * self.model_loss
            )
        else:
            value_loss = self.total_value_loss
        with tf.control_dependencies([self.update_batch_policy]):
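            # Illustrative sketch (not taken from this file) of how the chain
            # typically continues in the base SAC optimizer: the value update
            # runs only after the policy update, and the entropy-coefficient
            # update only after the value update. The handles value_optimizer,
            # self.critic_vars, entropy_optimizer, self.entropy_loss and
            # self.log_ent_coef are assumed names and may differ here.
            self.update_batch_value = value_optimizer.minimize(
                value_loss, var_list=self.critic_vars
            )
            # Entropy coefficient is adjusted only after the value update.
            with tf.control_dependencies([self.update_batch_value]):
                self.update_batch_entropy = entropy_optimizer.minimize(
                    self.entropy_loss, var_list=self.log_ent_coef
                )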
|
|
|