
Fix crash when SAC is used with Curiosity and Continuous Actions (#2740)

* Add test for curiosity + SAC

* Use actions for all curiosity (need to test on PPO)

* Fix issue with reward signals updating multiple times

* Put curiosity actions in the right placeholder

* Test PPO curiosity update
/develop-gpu-test
GitHub · 5 years ago
Current commit: 619465e1
5 files changed: 33 insertions, 11 deletions
1. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (2 changes)
2. ml-agents/mlagents/trainers/sac/trainer.py (10 changes)
3. ml-agents/mlagents/trainers/tests/mock_brain.py (7 changes)
4. ml-agents/mlagents/trainers/tests/test_ppo.py (10 changes)
5. ml-agents/mlagents/trainers/tests/test_sac.py (15 changes)

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (2 changes)


             policy_model.mask_input: mini_batch["masks"],
         }
         if self.policy.use_continuous_act:
-            feed_dict[policy_model.output_pre] = mini_batch["actions_pre"]
+            feed_dict[policy_model.selected_actions] = mini_batch["actions"]
         else:
             feed_dict[policy_model.action_holder] = mini_batch["actions"]
         if self.policy.use_vec_obs:
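The crash came from this continuous-action branch: it fed the PPO-only buffer key "actions_pre" into the curiosity model, and SAC never writes that key, so the lookup failed. A minimal sketch of the corrected branching, assuming a hypothetical helper name; the placeholder attributes and buffer keys follow the diff above, everything else is illustrative:

def add_action_inputs(feed_dict, policy, policy_model, mini_batch):
    # Sketch only: selected_actions / action_holder come from the diff above;
    # the helper itself and the mini_batch layout are assumptions.
    if policy.use_continuous_act:
        # SAC buffers store post-squash actions under "actions" and have no
        # "actions_pre" entry, so feed the selected-actions placeholder.
        feed_dict[policy_model.selected_actions] = mini_batch["actions"]
    else:
        # Discrete actions go to the action-holder placeholder unchanged.
        feed_dict[policy_model.action_holder] = mini_batch["actions"]
    return feed_dict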

ml-agents/mlagents/trainers/sac/trainer.py (10 changes)


                     self.trainer_parameters["batch_size"],
                     sequence_length=self.policy.sequence_length,
                 )
-                update_stats = self.policy.update_reward_signals(
-                    reward_signal_minibatches, n_sequences
-                )
-                for stat_name, value in update_stats.items():
-                    batch_update_stats[stat_name].append(value)
+            update_stats = self.policy.update_reward_signals(
+                reward_signal_minibatches, n_sequences
+            )
+            for stat_name, value in update_stats.items():
+                batch_update_stats[stat_name].append(value)
         for stat, stat_list in batch_update_stats.items():
             self.stats[stat].append(np.mean(stat_list))
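Per the commit notes, the old placement called update_reward_signals inside the per-signal loop, so each pass updated the reward signals multiple times. A rough sketch of the intended flow, using the attribute names visible in the diff; the buffer sampling call, the signal attributes, and the helper name are assumptions:

from collections import defaultdict

import numpy as np


def run_reward_signal_updates(trainer, num_updates, n_sequences):
    """Sketch: sample one minibatch per signal, then do a single joint update."""
    batch_update_stats = defaultdict(list)
    for _ in range(num_updates):
        # One minibatch per reward signal that actually needs an update.
        reward_signal_minibatches = {
            name: trainer.update_buffer.sample_mini_batch(
                trainer.trainer_parameters["batch_size"],
                sequence_length=trainer.policy.sequence_length,
            )
            for name, signal in trainer.policy.reward_signals.items()
            if signal.update_dict
        }
        # Single call outside the per-signal loop: each signal updates once.
        update_stats = trainer.policy.update_reward_signals(
            reward_signal_minibatches, n_sequences
        )
        for stat_name, value in update_stats.items():
            batch_update_stats[stat_name].append(value)
    # Report the mean of each stat across the update passes.
    for stat, stat_list in batch_update_stats.items():
        trainer.stats[stat].append(np.mean(stat_list))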

ml-agents/mlagents/trainers/tests/mock_brain.py (7 changes)


     mock_env.return_value.step.return_value = {brain_name: mock_braininfo}


-def simulate_rollout(env, policy, buffer_init_samples):
+def simulate_rollout(env, policy, buffer_init_samples, exclude_key_list=None):
+    # If a key_list was given, remove those keys
+    if exclude_key_list:
+        for key in exclude_key_list:
+            if key in buffer.update_buffer:
+                buffer.update_buffer.pop(key)
     return buffer
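A usage sketch of the new parameter, matching how the SAC test below calls it (env, policy, and BUFFER_INIT_SAMPLES are the test's own fixtures; the excluded keys are the PPO-only buffer entries a SAC policy never produces):

# Build a rollout buffer, then drop PPO-specific keys so the buffer matches
# what a SAC policy would actually have collected.
buffer = simulate_rollout(
    env,
    policy,
    BUFFER_INIT_SAMPLES,
    exclude_key_list=["advantages", "actions_pre", "random_normal_epsilon"],
)
assert "actions_pre" not in buffer.update_buffer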

ml-agents/mlagents/trainers/tests/test_ppo.py (10 changes)


     trainer_params = dummy_config
     trainer_params["use_recurrent"] = True
+    # Test curiosity reward signal
+    trainer_params["reward_signals"]["curiosity"] = {}
+    trainer_params["reward_signals"]["curiosity"]["strength"] = 1.0
+    trainer_params["reward_signals"]["curiosity"]["gamma"] = 0.99
+    trainer_params["reward_signals"]["curiosity"]["encoding_size"] = 128
     trainer = PPOTrainer(mock_brain, 0, trainer_params, True, False, 0, "0", False)
     # Test update with sequence length smaller than batch size
     buffer = mb.simulate_rollout(env, trainer.policy, BUFFER_INIT_SAMPLES)

     buffer.update_buffer["extrinsic_value_estimates"] = buffer.update_buffer["rewards"]
+    buffer.update_buffer["curiosity_rewards"] = buffer.update_buffer["rewards"]
+    buffer.update_buffer["curiosity_returns"] = buffer.update_buffer["rewards"]
+    buffer.update_buffer["curiosity_value_estimates"] = buffer.update_buffer["rewards"]
     trainer.training_buffer = buffer
     trainer.update_policy()
     # Make batch length a larger multiple of sequence length
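For reference, the same curiosity settings written as a single nested dict (only the three fields the test exercises are shown; treat the shape as a sketch, not the full set of supported options):

# Assumed shape: mirrors the trainer_params entries set line by line above.
curiosity_signal_config = {
    "curiosity": {
        "strength": 1.0,       # scale applied to the intrinsic reward
        "gamma": 0.99,         # discount used for curiosity returns
        "encoding_size": 128,  # size of the forward/inverse model encoding
    }
}
trainer_params["reward_signals"].update(curiosity_signal_config)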

ml-agents/mlagents/trainers/tests/test_sac.py (15 changes)


 env.close()


+@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
-def test_sac_update_reward_signals(mock_env, dummy_config):
+def test_sac_update_reward_signals(mock_env, dummy_config, discrete):
     # Test evaluate
     tf.reset_default_graph()
+    # Add a Curiosity module
+    dummy_config["reward_signals"]["curiosity"]["encoding_size"] = 128
     env, policy = create_sac_policy_mock(
-        mock_env, dummy_config, use_rnn=False, use_discrete=False, use_visual=False
+        mock_env, dummy_config, use_rnn=False, use_discrete=discrete, use_visual=False
     )
+    # Test update, while removing PPO-specific buffer elements.
+    buffer = mb.simulate_rollout(
+        env,
+        policy,
+        BUFFER_INIT_SAMPLES,
+        exclude_key_list=["advantages", "actions_pre", "random_normal_epsilon"],
+    )
-    # Test update
-    buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
     # Mock out reward signal eval
     buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
     buffer.update_buffer["curiosity_rewards"] = buffer.update_buffer["rewards"]
