浏览代码

fix bisim metric

/develop/bisim-review
Andrew Cohen 4 年前
当前提交
b6bf1860
共有 2 个文件被更改,包括 26 次插入14 次删除
  1. 4
      ml-agents/mlagents/trainers/policy/transfer_policy.py
  2. 36
      ml-agents/mlagents/trainers/tests/test_simple_transfer.py

4
ml-agents/mlagents/trainers/policy/transfer_policy.py


return kl
def w_distance(self, another):
return tf.squared_difference(self.mu, another.mu) + tf.squared_difference(
return tf.reduce_sum(tf.squared_difference(self.mu, another.mu) + tf.squared_difference(
)
))
class TransferPolicy(TFPolicy):

36
ml-agents/mlagents/trainers/tests/test_simple_transfer.py


# use_inverse_model=True
)
config = attr.evolve(
config, hyperparameters=new_hyperparams, max_steps=500000, summary_freq=5000
config, hyperparameters=new_hyperparams, max_steps=20000, summary_freq=5000
)
_check_environment_trains(
env, {BRAIN_NAME: config}, run_id=run_id + "_s" + str(seed), seed=seed

train_policy=True,
train_value=True,
train_model=False,
separate_value_train=True,
separate_policy_train=False,
feature_size=16,
use_var_predict=True,
with_prior=False,

use_bisim=True,
)
config = attr.evolve(
config, hyperparameters=new_hyperparams, max_steps=500000, summary_freq=5000
config, hyperparameters=new_hyperparams, max_steps=20000, summary_freq=5000
)
_check_environment_trains(
env, {BRAIN_NAME: config}, run_id=run_id + "_s" + str(seed), seed=seed

if __name__ == "__main__":
for obs in ["normal"]: # ["normal", "rich1", "rich2"]:
test_2d_model(seed=0, obs_spec_type=obs, run_id="model_" + obs)
for seed in range(5):
for obs in ["normal", "rich1", "rich2"]:
test_2d_model(seed=seed, obs_spec_type=obs, run_id="model_" + obs)
# test_2d_model(config=SAC_CONFIG, run_id="sac_rich2_hard", seed=0)
for obs in ["normal"]:
test_2d_transfer(
seed=0,
obs_spec_type="rich1",
transfer_from="./transfer_results/model_" + obs + "_s0/Simple",
run_id="transfer_rich1",
)
# test_2d_model(config=SAC_CONFIG, run_id="sac_rich2_hard", seed=0)
for obs in ["normal", "rich2"]:
test_2d_transfer(
seed=seed,
obs_spec_type="rich1",
transfer_from="./transfer_results/model_" + obs + "_s" + str(seed) + "/Simple",
run_id=obs + "transfer_to_rich1",
)
for obs in ["normal", "rich1"]:
test_2d_transfer(
seed=seed,
obs_spec_type="rich2",
transfer_from="./transfer_results/model_" + obs + "_s" + str(seed) + "/Simple",
run_id=obs + "transfer_to_rich2",
)
# for obs in ["normal"]:
# test_2d_transfer(seed=0, obs_spec_type="rich1",

正在加载...
取消
保存