
fix transfer

/develop/bisim-sac-transfer
yanchaosun, 4 years ago
Current commit: e2f0b3ca
5 files changed, 30 insertions(+) and 42 deletions(-)
  1. config/sac_transfer/3DBall.yaml (4 changes)
  2. config/sac_transfer/3DBallHard.yaml (9 changes)
  3. config/sac_transfer/3DBallHardTransfer.yaml (8 changes)
  4. ml-agents/mlagents/trainers/sac_transfer/network.py (42 changes)
  5. ml-agents/mlagents/trainers/sac_transfer/optimizer.py (9 changes)

config/sac_transfer/3DBall.yaml (4 changes)

  init_entcoef: 0.5
  reward_signal_steps_per_update: 10.0
  encoder_layers: 1
- policy_layers: 2
+ policy_layers: 1
- value_layers: 2
+ value_layers: 1
  feature_size: 16
  separate_policy_train: true
  reuse_encoder: true

config/sac_transfer/3DBallHard.yaml (9 changes)

  learning_rate: 0.0003
  learning_rate_schedule: linear
  batch_size: 256
- buffer_size: 24000
+ buffer_size: 50000
  buffer_init_steps: 0
  tau: 0.005
  steps_per_update: 10.0

  encoder_layers: 1
  policy_layers: 1
  forward_layers: 1
- value_layers: 2
+ value_layers: 1
  separate_policy_train: true
  reuse_encoder: false
  in_epoch_alter: false
  in_batch_alter: true

  use_bisim: false
  network_settings:
    normalize: true
-   hidden_units: 128
+   hidden_units: 64
    num_layers: 2
    vis_encode_type: simple
  reward_signals:

  keep_checkpoints: 5
- max_steps: 1000000
+ max_steps: 500000
  time_horizon: 1000
  summary_freq: 12000
  threaded: true

config/sac_transfer/3DBallHardTransfer.yaml (8 changes)

  encoder_layers: 1
  policy_layers: 1
  forward_layers: 1
- value_layers: 2
+ value_layers: 1
  feature_size: 16
  reuse_encoder: false
  in_epoch_alter: false

  use_transfer: true
  load_model: true
  train_model: false
- transfer_path: "results/sac_model_ball_new/3DBall"
+ transfer_path: "results/sac_model_ball/3DBall"
-   hidden_units: 128
+   hidden_units: 64
    num_layers: 2
    vis_encode_type: simple
  reward_signals:

  keep_checkpoints: 5
- max_steps: 1000000
+ max_steps: 500000
  time_horizon: 1000
  summary_freq: 12000
  threaded: true
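Taken together, the config changes above reduce value_layers to 1, halve hidden_units to 64 and max_steps to 500000, grow the 3DBallHard replay buffer to 50000, and point the transfer run at results/sac_model_ball/3DBall. As a small illustrative check (not part of the commit; it only assumes PyYAML and the repository-relative paths shown above), the changed keys can be printed straight from a parsed config, wherever they nest:

# Illustrative helper (not from the repository): walk a parsed config and
# report the hyperparameters touched in this commit.
import yaml

CHANGED_KEYS = {"policy_layers", "value_layers", "buffer_size",
                "hidden_units", "max_steps", "transfer_path"}

def report(node, path=""):
    if isinstance(node, dict):
        for key, value in node.items():
            child = f"{path}/{key}"
            if key in CHANGED_KEYS:
                print(f"{child}: {value}")
            report(value, child)
    elif isinstance(node, list):
        for i, value in enumerate(node):
            report(value, f"{path}[{i}]")

with open("config/sac_transfer/3DBallHardTransfer.yaml") as f:
    report(yaml.safe_load(f))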

ml-agents/mlagents/trainers/sac_transfer/network.py (42 changes)

          self.value_heads[name] = value
      self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
- def _create_cc_critic(self, encoder, hidden_value, scope, create_qs=True):
+ def _create_cc_critic(self, hidden_value, scope, create_qs=True):
      """
      Creates just the critic network
      """

-         encoder,
+         hidden_value,
          self.num_layers,
          self.h_size,
          self.join_scopes(scope, "value"),

          shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in"
      )
      self.value_memory_in = self.memory_in
-     # hidden_streams = ModelUtils.create_observation_streams(
-     #     self.visual_in,
-     #     self.processed_vector_in,
-     #     1,
-     #     self.h_size,
-     #     1,
-     #     vis_encode_type=vis_encode_type,
-     #     stream_scopes=["critic/value/"],
-     #     # reuse=True
-     # )
-     hidden_critic = self._create_observation_in(self.visual_in,
-         vis_encode_type)
+     hidden_critic = self._create_observation_in(
+         self.visual_in,
+         vis_encode_type
+     )
-     self._create_cc_critic(hidden_critic, hidden_critic, TARGET_SCOPE, create_qs=False)
+     self._create_cc_critic(hidden_critic, TARGET_SCOPE, create_qs=False)
-     self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
+     self._create_dc_critic(hidden_critic, TARGET_SCOPE, create_qs=False)
-     # self._create_dc_critic(self.policy.targ_encoder, TARGET_SCOPE, create_qs=False)
      if self.use_recurrent:
          self.memory_out = tf.concat(

      else:
          hidden = self.policy.encoder
      if self.policy.use_continuous_act:
-         self._create_cc_critic(hidden_critic, hidden_critic, POLICY_SCOPE)
+         self._create_cc_critic(hidden_critic, POLICY_SCOPE)
-         # self._create_cc_critic(self.policy.encoder, self.policy.encoder, POLICY_SCOPE)
-         # self._create_cc_critic(hidden, POLICY_SCOPE)
      else:

      if self.use_recurrent:
          mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out]
          self.memory_out = tf.concat(mem_outs, axis=1)
-     # self.critic_vars += self.get_vars("encoding")
-     # self.value_vars += self.get_vars("encoding")
  def _create_memory_ins(self, m_size):
      """

      :return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used
      once and thrown away.
      """
-     # with tf.variable_scope(POLICY_SCOPE):
-     #     hidden_streams = ModelUtils.create_observation_streams(
-     #         self.policy.visual_in,
-     #         self.policy.processed_vector_in,
-     #         1,
-     #         self.h_size,
-     #         1,
-     #         vis_encode_type=vis_encode_type,
-     #         stream_scopes=["critic/value/"],
-     #     )
-     #     hidden_critic = hidden_streams[0]
-     # return hidden_critic
      hidden = self.policy._create_encoder_general(
          self.policy.visual_in,
          self.policy.processed_vector_in,
ml-agents/mlagents/trainers/sac_transfer/optimizer.py (9 changes)

)
self.policy.initialize_or_load()
if self.use_transfer:
    self.policy.load_graph_partial(
        self.transfer_path,
        hyperparameters.load_model,
        hyperparameters.load_policy,
        hyperparameters.load_value,
        hyperparameters.load_encoder,
        hyperparameters.load_action,
    )
    self.policy.run_hard_copy()
print("All variables in the graph:")
