
new config

/develop/bisim-sac-transfer
yanchaosun, 4 years ago
Current commit: a505cb16
7 changed files, with 168 additions and 48 deletions
  1. config/ppo_transfer/CrawlerStatic.yaml (4 changes)
  2. config/ppo_transfer/TransferCrawlerStatic.yaml (6 changes)
  3. ml-agents/mlagents/trainers/policy/transfer_policy.py (4 changes)
  4. ml-agents/mlagents/trainers/sac_transfer/optimizer.py (36 changes)
  5. ml-agents/mlagents/trainers/sac_transfer/trainer.py (7 changes)
  6. ml-agents/mlagents/trainers/tests/reward_plot.ipynb (126 changes)
  7. ml-agents/mlagents/trainers/tests/test_simple_transfer.py (33 changes)

config/ppo_transfer/CrawlerStatic.yaml (4 changes)


  epsilon: 0.2
  lambd: 0.95
  num_epoch: 3
- learning_rate_schedule: constant
- model_schedule: constant
+ learning_rate_schedule: linear
+ model_schedule: linear
  encoder_layers: 2
  action_layers: 2
  policy_layers: 2
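
Both schedules move from constant to linear here, so the learning rate and the model learning rate now anneal over training instead of staying fixed. A minimal sketch of what a linear schedule computes, assuming the usual convention of decaying from the initial value toward a near-zero floor by max_steps (the exact floor used on this branch is not shown in this diff):

def linear_schedule(initial_value, current_step, max_steps, min_value=1e-10):
    # Fraction of training completed, clamped to [0, 1].
    progress = min(max(current_step / float(max_steps), 0.0), 1.0)
    # Interpolate from the initial value down to the floor.
    return initial_value + (min_value - initial_value) * progress

# Example: a 3e-4 learning rate halfway through a 5e6-step run decays to roughly 1.5e-4.
print(linear_schedule(3e-4, 2.5e6, 5e6))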

config/ppo_transfer/TransferCrawlerStatic.yaml (6 changes)


  epsilon: 0.2
  lambd: 0.95
  num_epoch: 3
- learning_rate_schedule: constant
- model_schedule: constant
+ learning_rate_schedule: linear
+ model_schedule: linear
  encoder_layers: 2
  action_layers: 2
  policy_layers: 2

  load_model: true
  train_action: false
  train_model: false
- transfer_path: "results/csold-bisim-l2/CrawlerStatic"
+ transfer_path: "results/csold-bisim-linear/CrawlerStatic"
  network_settings:
    normalize: true
    hidden_units: 512
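
In this transfer config load_model stays true while train_model and train_action stay false, i.e. the dynamics and action components restored from transfer_path are meant to stay frozen while the rest of the network keeps training. A toy sketch of how such booleans could map onto a trainable-variable selection (hypothetical helper and scope names, not this branch's actual code):

def select_trainable(vars_by_scope, train_model=False, train_action=False, train_encoder=True):
    # vars_by_scope: dict mapping a top-level variable scope to its list of variables.
    trainable = list(vars_by_scope.get("policy", []))     # policy head always trains here
    if train_encoder:
        trainable += vars_by_scope.get("encoding", [])    # state encoder
    if train_model:
        trainable += vars_by_scope.get("predict", [])     # forward/dynamics model
    if train_action:
        trainable += vars_by_scope.get("action_enc", [])  # action encoder
    return trainable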

ml-agents/mlagents/trainers/policy/transfer_policy.py (4 changes)


              # activation=tf.tanh,
              # kernel_initializer=tf.initializers.variance_scaling(1.0),
          )
-         if not self.transfer:
-             encoded_next_state = tf.stop_gradient(encoded_next_state)
+         # if not self.transfer:
+         encoded_next_state = tf.stop_gradient(encoded_next_state)
          squared_difference = 0.5 * tf.reduce_sum(
              tf.squared_difference(tf.tanh(self.predict), encoded_next_state), axis=1
          )
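
With the condition commented out, the next-state encoding is wrapped in tf.stop_gradient unconditionally, so this squared-difference term only sends gradients into the prediction branch and never into the target encoder, whether or not the run is a transfer run. A small standalone sketch of the same loss shape, assuming the TF1-style API used in this file:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

predict = tf.placeholder(tf.float32, [None, 16], name="predict")              # predicted next-state feature
encoded_next_state = tf.placeholder(tf.float32, [None, 16], name="encoded")   # encoder output for the next state

# Block gradients on the target so only the prediction side is trained by this loss.
target = tf.stop_gradient(encoded_next_state)
forward_loss = 0.5 * tf.reduce_sum(
    tf.squared_difference(tf.tanh(predict), target), axis=1
)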

ml-agents/mlagents/trainers/sac_transfer/optimizer.py (36 changes)


  class SACTransferOptimizer(TFOptimizer):
-     def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):
+     def __init__(self, policy: TransferPolicy, trainer_params: TrainerSettings):
          """
          Takes a Unity environment and model-specific hyper-parameters and returns the
          appropriate PPO agent model for the environment.

          # Create the graph here to give more granular control of the TF graph to the Optimizer.
          policy.create_tf_graph(
-             # hyperparameters.encoder_layers,
-             # hyperparameters.action_layers,
-             # hyperparameters.policy_layers,
-             # hyperparameters.forward_layers,
-             # hyperparameters.inverse_layers,
-             # hyperparameters.feature_size,
-             # hyperparameters.action_feature_size,
-             # self.use_transfer,
-             # self.separate_policy_train,
-             # self.use_var_encoder,
-             # self.use_var_predict,
-             # self.predict_return,
-             # self.use_inverse_model,
-             # self.reuse_encoder,
-             # self.use_bisim,
+             hyperparameters.encoder_layers,
+             hyperparameters.action_layers,
+             hyperparameters.policy_layers,
+             hyperparameters.forward_layers,
+             hyperparameters.inverse_layers,
+             hyperparameters.feature_size,
+             hyperparameters.action_feature_size,
+             self.use_transfer,
+             self.separate_policy_train,
+             self.use_var_encoder,
+             self.use_var_predict,
+             self.predict_return,
+             self.use_inverse_model,
+             self.reuse_encoder,
+             self.use_bisim,
          )
          with policy.graph.as_default():

              )
          self.policy.initialize_or_load()
+         print("All variables in the graph:")
+         for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
+             print(variable)
          self.stats_name_to_update_name = {
              "Losses/Value Loss": "value_loss",

ml-agents/mlagents/trainers/sac_transfer/trainer.py (7 changes)


  from mlagents_envs.logging_util import get_logger
  from mlagents_envs.timers import timed
  from mlagents.trainers.policy.tf_policy import TFPolicy
+ from mlagents.trainers.policy.transfer_policy import TransferPolicy
  from mlagents.trainers.policy.nn_policy import NNPolicy
  from mlagents.trainers.sac_transfer.optimizer import SACTransferOptimizer
  from mlagents.trainers.trainer.rl_trainer import RLTrainer

      def create_policy(
          self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
      ) -> TFPolicy:
-         policy = NNPolicy(
+         policy = TransferPolicy(
              self.seed,
              brain_parameters,
              self.trainer_settings,

                      self.__class__.__name__
                  )
              )
-         if not isinstance(policy, NNPolicy):
-             raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
+         if not isinstance(policy, TransferPolicy):
+             raise RuntimeError("Non-TransferPolicy passed to SACTransferTrainer.add_policy()")
          self.policy = policy
          self.optimizer = SACTransferOptimizer(self.policy, self.trainer_settings)
          for _reward_signal in self.optimizer.reward_signals.keys():
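
Together with the create_policy change, the stricter isinstance check means a run that hands the sac_transfer trainer a plain NNPolicy now fails immediately instead of erroring deep inside the optimizer. The shape of the guard, condensed from the diff above (the real add_policy takes more arguments and does additional bookkeeping):

from mlagents.trainers.policy.transfer_policy import TransferPolicy
from mlagents.trainers.sac_transfer.optimizer import SACTransferOptimizer

def add_policy(self, parsed_behavior_id, policy):
    # Fail fast if the policy was not built by this trainer's create_policy().
    if not isinstance(policy, TransferPolicy):
        raise RuntimeError("Non-TransferPolicy passed to SACTransferTrainer.add_policy()")
    self.policy = policy
    self.optimizer = SACTransferOptimizer(self.policy, self.trainer_settings)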

ml-agents/mlagents/trainers/tests/reward_plot.ipynb (126 changes)
File diff is too large to display.

ml-agents/mlagents/trainers/tests/test_simple_transfer.py (33 changes)


          trainer_type=TrainerType.PPO_Transfer,
          hyperparameters=PPOTransferSettings(
              learning_rate=5.0e-3,
-             # learning_rate_schedule=ScheduleType.CONSTANT,
+             learning_rate_schedule=ScheduleType.CONSTANT,
              batch_size=16,
              buffer_size=64,
              feature_size=4,

          batch_size=1200,
          buffer_size=12000,
          learning_rate=5.0e-3,
-         use_bisim=False,
+         use_bisim=True,
          predict_return=True,
          reuse_encoder=True,
          separate_value_train=True,
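
Flipping use_bisim to True enables the bisimulation-style representation loss for this 2D test run. The loss itself lives elsewhere in the fork and is not part of this diff, so the following is only a toy sketch of the common formulation: encode states so that latent distance matches the reward gap plus the discounted gap between predicted next-state encodings.

import numpy as np

def bisim_loss(phi_i, phi_j, r_i, r_j, next_phi_i, next_phi_j, gamma=0.99):
    # Latent distance between the two encoded states.
    z_dist = np.abs(phi_i - phi_j).sum(axis=1)
    # Bisimulation target: reward difference plus discounted distance of predicted next encodings.
    target = np.abs(r_i - r_j) + gamma * np.linalg.norm(next_phi_i - next_phi_j, axis=1)
    return 0.5 * np.mean((z_dist - target) ** 2)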

          buffer_size=12000,
          use_transfer=True,
          transfer_path=transfer_from,  # separate_policy_train=True, separate_value_train=True,
-         use_op_buffer=False,
+         use_op_buffer=True,
-         train_model=False,
-         train_action=False,
+         train_model=True,
+         train_action=True,
          train_encoder=True,
          reuse_encoder=True,
          separate_value_train=True,

  if __name__ == "__main__":
-     # if seed > -1:
-     #     for obs in ["normal", "long", "longpre"]:
-     #         test_2d_model(seed=seed, obs_spec_type=obs, run_id="model_" + obs)
-     #         test_2d_ppo(seed=seed, obs_spec_type=obs, run_id="ppo_" + obs)
+     if seed > -1:
+         for obs in ["long-n", "longpre-n"]:
+             test_2d_model(seed=seed, obs_spec_type=obs, run_id="model_bisim_" + obs)
+             # test_2d_ppo(seed=seed, obs_spec_type=obs, run_id="ppo_" + obs)
+         for obs in ["long", "longpre"]:
+             test_2d_transfer(
+                 seed=seed,
+                 obs_spec_type=obs,
+                 transfer_from="./transfer_results/model_normal_s" + str(seed) + "/Simple",
+                 run_id="normal_transfer_linear_fix_to_" + obs,
+             )
      # for obs in ["long-n", "longpre-n"]:
      #     test_2d_transfer(
      #         seed=seed,
      #         obs_spec_type=obs,
      #         transfer_from="./transfer_results/model_normal_s" + str(seed) + "/Simple",
      #         run_id="normal_transfer_bisim_to_" + obs,
      #     )
      # # test_2d_model(config=SAC_CONFIG, run_id="sac_rich2_hard", seed=0)
      # for obs in ["normal", "rich2"]:
