
fix sac transfer problems

/develop/bisim-sac-transfer
yanchaosun, 4 years ago
Commit 00bb821c
7 files changed, 98 insertions, 48 deletions
1. config/sac_transfer/3DBall.yaml (9 changed lines)
2. config/sac_transfer/3DBallHard.yaml (4 changed lines)
3. config/sac_transfer/3DBallHardTransfer.yaml (10 changed lines)
4. ml-agents/mlagents/trainers/policy/transfer_policy.py (10 changed lines)
5. ml-agents/mlagents/trainers/sac_transfer/network.py (89 changed lines)
6. ml-agents/mlagents/trainers/sac_transfer/optimizer.py (23 changed lines)
7. ml-agents/mlagents/trainers/settings.py (1 changed line)

config/sac_transfer/3DBall.yaml (9 changed lines)


  init_entcoef: 0.5
  reward_signal_steps_per_update: 10.0
  encoder_layers: 1
- policy_layers: 1
+ policy_layers: 2
- feature_size: 32
- reuse_encoder: false
+ feature_size: 16
+ separate_policy_train: true
+ reuse_encoder: true
  in_epoch_alter: false
  in_batch_alter: true
  use_op_buffer: false

- use_bisim: true
+ use_bisim: false
  network_settings:
    normalize: true
    hidden_units: 64
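This config (and the two below) turns `use_bisim` off. For context on what that flag typically gates on the `bisim-sac-transfer` branch, here is a rough NumPy sketch of a deep-bisimulation-style encoder loss; the variable names and the exact form of the target are illustrative assumptions, not this repo's implementation (which may, for example, also use the transition model's variance).

import numpy as np

rng = np.random.default_rng(0)
batch, feature_size, gamma = 32, 16, 0.99

# Illustrative inputs: latent features for two randomly paired state batches,
# their rewards, and the means predicted by the latent transition model.
z_i, z_j = rng.normal(size=(2, batch, feature_size))
r_i, r_j = rng.normal(size=(2, batch))
mu_i, mu_j = rng.normal(size=(2, batch, feature_size))

# Bisimulation-style target: two states should be close in latent space when
# their rewards and predicted next-state distributions are close.
z_dist = np.abs(z_i - z_j).sum(axis=1)
target = np.abs(r_i - r_j) + gamma * np.linalg.norm(mu_i - mu_j, axis=1)
bisim_loss = np.mean((z_dist - target) ** 2)
print(f"bisim loss on random data: {bisim_loss:.3f}")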

config/sac_transfer/3DBallHard.yaml (4 changed lines)


  policy_layers: 1
  forward_layers: 1
  value_layers: 2
- feature_size: 32
+ feature_size: 16
  reuse_encoder: false
  in_epoch_alter: false
  in_batch_alter: true

  predict_return: true
- use_bisim: true
+ use_bisim: false
  network_settings:
    normalize: true
    hidden_units: 128

config/sac_transfer/3DBallHardTransfer.yaml (10 changed lines)


  learning_rate: 0.0003
  learning_rate_schedule: linear
  batch_size: 256
- buffer_size: 24000
+ buffer_size: 50000
  buffer_init_steps: 0
  tau: 0.005
  steps_per_update: 10.0

  policy_layers: 1
  forward_layers: 1
  value_layers: 2
- feature_size: 32
+ feature_size: 16
- in_batch_alter: true
+ in_batch_alter: false
- use_bisim: true
+ use_bisim: false
- transfer_path: "results/sac_model_ball_sep_linear_f32/3DBall"
+ transfer_path: "results/sac_model_ball_new/3DBall"
  network_settings:
    normalize: true
    hidden_units: 128
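The transfer run also grows the replay buffer from 24000 to 50000 transitions. As a reminder of what `buffer_size`, `buffer_init_steps`, and `batch_size` control in an off-policy trainer like SAC, a minimal generic sketch (not ml-agents' actual buffer class):

import random
from collections import deque

buffer_size, buffer_init_steps, batch_size = 50000, 0, 256

# FIFO replay buffer: once buffer_size is reached, the oldest transitions are dropped.
replay_buffer = deque(maxlen=buffer_size)

def store(transition):
    replay_buffer.append(transition)

def ready_to_update():
    # Updates only start once the buffer holds buffer_init_steps transitions
    # and at least one full batch can be drawn.
    return len(replay_buffer) >= max(buffer_init_steps, batch_size)

def sample_batch():
    # Uniform random minibatch; copying to a list keeps random.sample happy.
    return random.sample(list(replay_buffer), batch_size)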

ml-agents/mlagents/trainers/policy/transfer_policy.py (10 changed lines)


  # action_layers
  # )
- # if self.inverse_model:
- #     with tf.variable_scope("inverse"):
- #         self.create_inverse_model(
- #             self.encoder, self.targ_encoder, inverse_layers
- #         )
+ if self.inverse_model:
+     with tf.variable_scope("inverse"):
+         self.create_inverse_model(
+             self.encoder, self.targ_encoder, inverse_layers
+         )
  with tf.variable_scope("predict"):

ml-agents/mlagents/trainers/sac_transfer/network.py (89 changed lines)


num_layers=2,
stream_names=None,
vis_encode_type=EncoderType.SIMPLE,
separate_train=False
separate_train=False,
):
super().__init__(
policy,

shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in"
)
self.value_memory_in = self.memory_in
hidden_streams = ModelUtils.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
self.h_size,
0,
vis_encode_type=vis_encode_type,
stream_scopes=["critic/value/"],
)
# hidden_streams = ModelUtils.create_observation_streams(
# self.visual_in,
# self.processed_vector_in,
# 1,
# self.h_size,
# 1,
# vis_encode_type=vis_encode_type,
# stream_scopes=["critic/value/"],
# # reuse=True
# )
hidden_critic = self._create_observation_in(self.visual_in,
self.processed_vector_in,
vis_encode_type)
self._create_cc_critic(hidden_streams[0], hidden_streams[0], TARGET_SCOPE, create_qs=False)
self._create_cc_critic(hidden_critic, hidden_critic, TARGET_SCOPE, create_qs=False)
# self._create_cc_critic(self.policy.targ_encoder, self.policy.targ_encoder, TARGET_SCOPE, create_qs=False)
# self._create_cc_critic(self.policy.targ_encoder, TARGET_SCOPE, create_qs=False)
else:
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)

self.value_memory_out, axis=1
) # Needed for Barracuda to work
# self.critic_vars += self.get_vars("target_enc")
# self.value_vars += self.get_vars("target_enc")
def copy_normalization(self, mean, variance, steps):
"""

update_norm_step = tf.assign(self.normalization_steps, steps)
return tf.group([update_mean, update_variance, update_norm_step])
def _create_observation_in(self, visual_in, vector_in, vis_encode_type):
"""
Creates the observation inputs, and a CNN if needed,
:param vis_encode_type: Type of CNN encoder.
:param share_ac_cnn: Whether or not to share the actor and critic CNNs.
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used
once and thrown away.
"""
hidden = self.policy._create_encoder_general(
visual_in,
vector_in,
self.h_size,
self.policy.feature_size,
1,
vis_encode_type,
scope="target_enc", #"target_network/critic/value",
reuse=True
)
return hidden
class SACTransferPolicyNetwork(SACTransferNetwork):
"""

self.policy.output = self.policy.output
# Use the sequence length of the policy
self.sequence_length_ph = self.policy.sequence_length_ph
# self.hidden = hidden_critic
if separate_train:
hidden = tf.stop_gradient(self.policy.encoder)

self._create_cc_critic(hidden_critic, hidden_critic, POLICY_SCOPE)
# self._create_cc_critic(self.policy.encoder, self.policy.encoder, POLICY_SCOPE)
# self._create_cc_critic(hidden, POLICY_SCOPE)
else:
self._create_dc_critic(hidden_critic, POLICY_SCOPE)

mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out]
self.memory_out = tf.concat(mem_outs, axis=1)
# self.critic_vars += self.get_vars("encoding")
# self.value_vars += self.get_vars("encoding")
def _create_memory_ins(self, m_size):
"""

:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used
once and thrown away.
"""
with tf.variable_scope(POLICY_SCOPE):
hidden_streams = ModelUtils.create_observation_streams(
self.policy.visual_in,
self.policy.processed_vector_in,
1,
self.h_size,
0,
vis_encode_type=vis_encode_type,
stream_scopes=["critic/value/"],
)
hidden_critic = hidden_streams[0]
return hidden_critic
# with tf.variable_scope(POLICY_SCOPE):
# hidden_streams = ModelUtils.create_observation_streams(
# self.policy.visual_in,
# self.policy.processed_vector_in,
# 1,
# self.h_size,
# 1,
# vis_encode_type=vis_encode_type,
# stream_scopes=["critic/value/"],
# )
# hidden_critic = hidden_streams[0]
# return hidden_critic
hidden = self.policy._create_encoder_general(
self.policy.visual_in,
self.policy.processed_vector_in,
self.h_size,
self.policy.feature_size,
1,
vis_encode_type,
scope="encoding", #"critic/value",
reuse=True
)
return hidden
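In this revision the critic inputs appear to be built by re-invoking the policy's encoder under the existing `target_enc` and `encoding` scopes with `reuse=True`, rather than through separate observation streams. The underlying mechanism is TF1 variable-scope reuse; a self-contained sketch with illustrative scope and layer names:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def encoder(obs, feature_size, scope, reuse=False):
    # Build the encoder the first time; with reuse=True, fetch the same variables
    # so both call sites share weights instead of creating a second encoder.
    with tf.variable_scope(scope, reuse=reuse):
        hidden = tf.layers.dense(obs, 64, activation=tf.nn.relu, name="hidden")
        return tf.layers.dense(hidden, feature_size, name="latent")

obs_policy = tf.placeholder(tf.float32, [None, 8], name="obs_policy")
obs_critic = tf.placeholder(tf.float32, [None, 8], name="obs_critic")

z_policy = encoder(obs_policy, 16, scope="encoding")              # creates encoding/* variables
z_critic = encoder(obs_critic, 16, scope="encoding", reuse=True)  # reuses those same weights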

ml-agents/mlagents/trainers/sac_transfer/optimizer.py (23 changed lines)


self.update_batch_value: Optional[tf.Operation] = None
self.update_batch_entropy: Optional[tf.Operation] = None
self.policy_network = SACTransferPolicyNetwork(
policy=self.policy,
m_size=self.policy.m_size, # 3x policy.m_size

)
self.policy.initialize_or_load()
self.policy.run_hard_copy()
print("All variables in the graph:")
for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):

]
policy_vars = self.policy.get_trainable_variables(
train_encoder=self.train_encoder,
train_encoder=not self.separate_policy_train,
train_model=self.train_model,
train_model=False,
train_policy=self.train_policy
)

train_model=self.train_model,
train_policy=self.train_policy
train_policy=False
encoding_vars = self.policy.encoding_variables
critic_vars = self.policy_network.critic_vars + policy_vars
critic_vars = self.policy_network.critic_vars + encoding_vars
critic_vars = policy_vars
critic_vars = encoding_vars
self.target_init_op = [
tf.assign(target, source)

self.update_batch_policy = policy_optimizer.minimize(
self.policy_loss, var_list=policy_vars
)
print("value trainable:", critic_vars)
self.total_value_loss, var_list=self.policy_network.critic_vars
self.total_value_loss, var_list=critic_vars
)
# Add entropy coefficient optimization operation
with tf.control_dependencies([self.update_batch_value]):

stats_needed = self.stats_name_to_update_name
update_stats: Dict[str, float] = {}
# update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]

# Update target network. By default, target update happens at every policy update.
self.sess.run(self.target_update_op)
self.policy.run_soft_copy()
if not self.reuse_encoder:
self.policy.run_soft_copy()
return update_stats
def update_encoder(self, mini_batch1: AgentBuffer, mini_batch2: AgentBuffer):
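After `target_update_op` runs, the policy-level soft copy now appears to be skipped whenever the encoder is shared with the target network (`reuse_encoder`). The soft copy itself is the usual Polyak averaging controlled by `tau` (0.005 in the transfer config); a minimal NumPy sketch of that update, independent of this repo's variable bookkeeping:

import numpy as np

def soft_update(target_params, source_params, tau=0.005):
    """Polyak-average the online parameters into the target parameters."""
    return [tau * src + (1.0 - tau) * tgt
            for tgt, src in zip(target_params, source_params)]

# Toy example: one weight matrix and one bias vector.
target = [np.zeros((4, 4)), np.zeros(4)]
source = [np.ones((4, 4)), np.ones(4)]
target = soft_update(target, source)  # every element moves 0.5% toward the source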

ml-agents/mlagents/trainers/settings.py (1 changed line)


  threaded: bool = True
  self_play: Optional[SelfPlaySettings] = None
  behavioral_cloning: Optional[BehavioralCloningSettings] = None
+ transfer: bool = False
  cattr.register_structure_hook(
      Dict[RewardSignalType, RewardSignalSettings], RewardSignalSettings.structure
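The single settings.py change adds a `transfer` flag next to the other trainer options. The settings classes are attrs classes structured from the parsed YAML via cattr, so a new boolean option only needs a defaulted field; a minimal stand-alone sketch of the pattern (the class below is a hypothetical stand-in, not the real TrainerSettings):

import attr
import cattr

@attr.s(auto_attribs=True)
class TrainerSettingsSketch:
    # Hypothetical, trimmed-down stand-in for the real TrainerSettings.
    threaded: bool = True
    transfer: bool = False

# cattr converts a plain dict (e.g. parsed from a YAML config) into the typed object;
# keys missing from the dict fall back to the attrs defaults.
settings = cattr.structure({"transfer": True}, TrainerSettingsSketch)
assert settings.transfer is True and settings.threaded is True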
