
new config

Branch: /develop/bisim-sac-transfer
yanchaosun, 4 years ago
Current commit: 1ebe7054
2 files changed, 15 insertions(+) and 4 deletions(-)

  1. config/sac_transfer/ReacherTransfer.yaml (3 changes)
  2. ml-agents/mlagents/trainers/sac_transfer/optimizer.py (16 changes)

config/sac_transfer/ReacherTransfer.yaml (3 changes)

 hyperparameters:
   learning_rate: 0.003
   learning_rate_schedule: linear
-  model_schedule: constant
+  model_schedule: linear
   model_learning_rate: 0.05
   batch_size: 256
   buffer_size: 6000000
   buffer_init_steps: 0
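
The only functional change in this file is model_schedule moving from constant to linear, so the model learning rate (0.05 here) decays over training instead of staying fixed. A minimal sketch of that kind of linear decay, using a hypothetical max_step and the 1e-10 floor that appears later in this commit:

def linear_decay(initial_value: float, step: int, max_step: int,
                 min_value: float = 1e-10) -> float:
    # Anneal a value from initial_value at step 0 toward min_value at max_step.
    fraction_left = max(0.0, 1.0 - step / float(max_step))
    return max(min_value, initial_value * fraction_left)

# Hypothetical numbers: model_learning_rate 0.05 halfway through a 3e6-step run.
print(linear_decay(0.05, step=1_500_000, max_step=3_000_000))  # 0.025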

ml-agents/mlagents/trainers/sac_transfer/optimizer.py (16 changes)

 from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.sac_transfer.network import SACTransferPolicyNetwork, SACTransferTargetNetwork
 from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork
-from mlagents.trainers.models import ModelUtils
+from mlagents.trainers.models import ModelUtils, ScheduleType
 from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.policy.transfer_policy import TransferPolicy

         self.train_model = hyperparameters.train_model
         self.train_policy = hyperparameters.train_policy
         self.train_value = hyperparameters.train_value
-        self.model_weight = hyperparameters.model_weight
+        # self.model_weight = hyperparameters.model_weight
         # Transfer
         self.use_transfer = hyperparameters.use_transfer
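
These flags come straight off the trainer's hyperparameter settings object. For orientation only, the relevant fields might look like the following dataclass-style sketch; the class name and defaults are assumptions, and only the field names actually appear in this diff:

from dataclasses import dataclass

@dataclass
class SACTransferSettings:  # hypothetical name, not taken from this commit
    train_model: bool = True
    train_policy: bool = True
    train_value: bool = True
    use_transfer: bool = False
    model_weight: float = 1.0           # read here before this commit, now commented out
    model_learning_rate: float = 0.05   # matches ReacherTransfer.yaml above
    model_schedule: str = "linear"      # matches the YAML change above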

             int(max_step),
             min_value=1e-10,
         )
+        # self.model_weight = ModelUtils.create_schedule(
+        #     ModelUtils.ScheduleType.LINEAR,
+        #     hyperparameters.model_weight,
+        #     self.policy.global_step,
+        #     int(max_step),
+        #     min_value=1e-10,
+        # )
         self._create_losses(
             self.policy_network.q1_heads,
             self.policy_network.q2_heads,
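
The commented-out block preserves how model_weight used to be annealed via ModelUtils.create_schedule. A sketch of the analogous call for model_learning_rate, which the newly imported ScheduleType and the model_schedule key in the YAML presumably drive; the exact placement and argument sources are assumptions beyond what this hunk shows:

        # Sketch only: build the model learning-rate schedule the same way the
        # commented-out model_weight schedule was built, but driven by the
        # model_schedule setting (linear for ReacherTransfer) instead of a constant.
        self.model_learning_rate = ModelUtils.create_schedule(
            hyperparameters.model_schedule,
            hyperparameters.model_learning_rate,
            self.policy.global_step,
            int(max_step),
            min_value=1e-10,
        )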

"Policy/Entropy Coeff": "entropy_coef",
"Policy/Learning Rate": "learning_rate",
"Policy/Model Learning Rate": "model_learning_rate",
# "Policy/Model Weight": "model_weight",
}
if self.predict_return:

"model_loss": self.model_loss,
"model_learning_rate": self.model_learning_rate,
"reward_loss": self.policy.reward_loss,
# "model_weight": self.model_weight
})
def _create_inputs_and_outputs(self) -> None:
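
For context: in these TF optimizers, update_dict lists the tensors fetched on each training step, and stats_name_to_update_name maps their keys to the display names reported to TensorBoard. A rough sketch of that reporting step; the method and variable names here are illustrative, not lifted from optimizer.py:

        # Illustrative sketch: run all registered tensors once, then rename the
        # results for stats reporting.
        update_vals = self.sess.run(self.update_dict, feed_dict=feed_dict)
        update_stats = {
            stat_name: update_vals[update_name]
            for stat_name, update_name in self.stats_name_to_update_name.items()
            if update_name in update_vals
        }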

         # Make sure policy is updated first, then value, then entropy.
         if self.use_transfer:
-            value_loss = self.total_value_loss + self.model_weight * self.model_loss
+            value_loss = self.total_value_loss + self.model_learning_rate * self.model_loss
         else:
             value_loss = self.total_value_loss
         with tf.control_dependencies([self.update_batch_policy]):
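
The substantive change: with transfer enabled, the dynamics-model loss is now weighted into the value objective by the scheduled model learning rate rather than the retired model_weight coefficient. A sketch of the update ordering that the control dependency enforces, assuming TF1-style optimizers and hypothetical variable lists:

        # Sketch (assumptions: optimizer choice and var lists are illustrative).
        policy_opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        value_opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate)

        self.update_batch_policy = policy_opt.minimize(self.policy_loss, var_list=policy_vars)
        # The value/model update only runs after the policy update has been applied.
        with tf.control_dependencies([self.update_batch_policy]):
            self.update_batch_value = value_opt.minimize(value_loss, var_list=critic_vars)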
