
model weights and large transfer learning weight

/develop/bisim-sac-transfer
yanchaosun, 4 years ago
Commit 1a9aaaf6
7 changed files: 140 insertions, 11 deletions
  1. Project/Assets/ML-Agents/Examples/Reacher/Prefabs/NewAgent.prefab (2 changes)
  2. config/sac_transfer/Reacher.yaml (8 changes)
  3. config/sac_transfer/ReacherTransfer.yaml (9 changes)
  4. ml-agents/mlagents/trainers/sac_transfer/optimizer.py (5 changes)
  5. ml-agents/mlagents/trainers/settings.py (2 changes)
  6. ml-agents/mlagents/trainers/tests/tsne_plot.ipynb (112 changes)
  7. run.sh (13 changes)

Project/Assets/ML-Agents/Examples/Reacher/Prefabs/NewAgent.prefab (2 changes)


  VectorActionSize: 04000000
  VectorActionDescriptions: []
  VectorActionSpaceType: 1
- m_Model: {fileID: 11400000, guid: 79166d8db14004ba0aa1ffdce3f62667, type: 3}
+ m_Model: {fileID: 11400000, guid: fccf649e8d1dc4ad1b7d28f21b572cbe, type: 3}
  m_InferenceDevice: 0
  m_BehaviorType: 0
  m_BehaviorName: Reacher

config/sac_transfer/Reacher.yaml (8 changes)


  learning_rate_schedule: linear
  model_schedule: constant
  batch_size: 256
- buffer_size: 10000000
+ buffer_size: 6000000
  buffer_init_steps: 0
  tau: 0.005
  steps_per_update: 20.0

  encoder_layers: 1
  policy_layers: 2
- forward_layers: 1
+ forward_layers: 2
- feature_size: 32
+ feature_size: 64
  action_feature_size: 16
  separate_policy_train: true
  separate_policy_net: true

  gamma: 0.99
  strength: 1.0
  keep_checkpoints: 5
- max_steps: 10000000
+ max_steps: 6000000
  time_horizon: 1000
  summary_freq: 60000
  threaded: true

config/sac_transfer/ReacherTransfer.yaml (9 changes)


  Reacher:
    trainer_type: sac_transfer
    hyperparameters:
-     learning_rate: 0.0003
+     learning_rate: 0.003
+     model_weight: 0.05
      batch_size: 256
      buffer_size: 6000000
      buffer_init_steps: 0

      reward_signal_steps_per_update: 20.0
      encoder_layers: 1
      policy_layers: 2
-     forward_layers: 1
+     forward_layers: 2
-     feature_size: 32
+     feature_size: 64
      action_feature_size: 16
      separate_policy_train: true
      separate_policy_net: true

      train_model: false
      load_action: true
      train_action: false
-     transfer_path: "results/reacher-f32/Reacher"
+     transfer_path: "results/reacher-lr/Reacher"
    network_settings:
      normalize: true
      hidden_units: 128
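For orientation, a minimal sketch of what loading a pretrained model from transfer_path with train_model: false could amount to in TF1: restore only the model/encoder variables from the checkpoint and keep them out of the optimizer's update. The "encoding" name filter and the session handling here are illustrative assumptions, not the trainer's actual restore logic.

    # Hedged sketch: restore only pretrained model/encoder variables from
    # transfer_path and keep them frozen. The "encoding" name filter is a
    # hypothetical stand-in for however the trainer scopes its model variables.
    import tensorflow as tf

    transfer_path = "results/reacher-lr/Reacher"  # as in the YAML above
    model_vars = [v for v in tf.trainable_variables() if "encoding" in v.name]
    saver = tf.train.Saver(var_list=model_vars)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, tf.train.latest_checkpoint(transfer_path))
        # train_model: false -> exclude model_vars from the variables the
        # optimizer is allowed to update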

ml-agents/mlagents/trainers/sac_transfer/optimizer.py (5 changes)


  self.train_model = hyperparameters.train_model
  self.train_policy = hyperparameters.train_policy
  self.train_value = hyperparameters.train_value
+ self.model_weight = hyperparameters.model_weight
  # Transfer
  self.use_transfer = hyperparameters.use_transfer

  )
  self.model_learning_rate = ModelUtils.create_schedule(
      hyperparameters.model_schedule,
      lr,
      hyperparameters.model_learning_rate,
      self.policy.global_step,
      int(max_step),
      min_value=1e-10,
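As context for the call above, a linear schedule with a floor of 1e-10 can be expressed in TF1 with tf.train.polynomial_decay at power=1.0; whether ModelUtils.create_schedule wraps it exactly this way is an assumption.

    # Hedged sketch of a linear schedule decaying to a small floor, mirroring
    # create_schedule(..., min_value=1e-10) above. power=1.0 makes the decay
    # linear; the use of polynomial_decay internally is an assumption.
    import tensorflow as tf

    def linear_schedule(initial_value, global_step, max_step, min_value=1e-10):
        return tf.train.polynomial_decay(
            learning_rate=initial_value,
            global_step=global_step,
            decay_steps=max_step,
            end_learning_rate=min_value,
            power=1.0,
        )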

  # Make sure policy is updated first, then value, then entropy.
  if self.use_transfer:
-     value_loss = self.total_value_loss + self.model_loss
+     value_loss = self.total_value_loss + self.model_weight * self.model_loss
  else:
      value_loss = self.total_value_loss
  with tf.control_dependencies([self.update_batch_policy]):
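The fix above is the core of the commit: the model loss previously entered the SAC value objective at full strength, and is now scaled by model_weight. A toy calculation with made-up loss values shows the effect:

    # Hedged sketch with hypothetical numbers; names follow the diff above.
    total_value_loss = 1.2  # SAC value loss (illustrative)
    model_loss = 4.0        # forward/bisimulation model loss (illustrative)
    model_weight = 0.05     # as set in ReacherTransfer.yaml

    value_loss_before = total_value_loss + model_loss                # 5.2: model term dominates
    value_loss_after = total_value_loss + model_weight * model_loss  # 1.4: model term is auxiliary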

ml-agents/mlagents/trainers/settings.py (2 changes)


  @attr.s(auto_attribs=True)
  class SACTransferSettings(SACSettings):
      model_schedule: ScheduleType = ScheduleType.LINEAR
      model_learning_rate: float = 3.0e-4
+     model_weight: float = 0.5
      separate_value_train: bool = False
      separate_policy_train: bool = False
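The new model_weight field defaults to 0.5, while ReacherTransfer.yaml sets 0.05. A standalone stub (SACSettings and ScheduleType omitted; the stub is illustrative, not the real class) shows the attrs default/override behavior:

    # Hedged sketch: with attrs auto_attribs, model_weight defaults to 0.5
    # unless the YAML hyperparameters block overrides it.
    import attr

    @attr.s(auto_attribs=True)
    class SACTransferSettingsStub:
        model_learning_rate: float = 3.0e-4
        model_weight: float = 0.5  # default when the YAML omits the key

    settings = SACTransferSettingsStub(model_weight=0.05)  # YAML value wins
    assert settings.model_weight == 0.05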

ml-agents/mlagents/trainers/tests/tsne_plot.ipynb (112 changes)
(File diff suppressed because it is too large.)

run.sh (13 changes)


#!/bin/bash
# mlagents-learn config/sac/3DBallHard.yaml --run-id=hardball_sac --env=envs/3dballhard_qudrew --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/3DBall.yaml --run-id=ball_2f --env=envs/3dball_qudrew --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/3DBall.yaml --run-id=ball_1f --env=envs/3dball1f_qudrew --num-envs=4 --no-graphics
mlagents-learn config/sac_transfer/3DBall1fTransfer.yaml --run-id=ball_transfer_2f_noload-model --env=envs/3dball_qudrew --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/3DBallHardTransfer.yaml --run-id=transfer_action-enc_linear --env=envs/3dballhard_qudrew --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/3DBallHard.yaml --run-id=hardball_action-enc_linear --env=envs/3dballhard_qudrew --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/3DBallHardTransfer1.yaml --run-id=sac_transfer_hardball_fixpol_ov --env=envs/3dballhard --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/CrawlerStatic.yaml --run-id=oldcs --env=envs/old_crawler_static --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/CrawlerStaticTransfer.yaml --run-id=transfer_newcs --env=envs/new_crawler_static --num-envs=4 --no-graphics
# mlagents-learn config/sac_transfer/CrawlerStatic.yaml --run-id=newcs --env=envs/new_crawler_static --num-envs=4 --no-graphics