[bug-fix] Fixes for Torch SAC and tests (#4408)

* Fixes for Torch SAC and tests

* Fix recurrent SAC test

* Properly update normalization for SAC-continuous

* Fix issue with log ent coef reporting in SAC Torch
GitHub · 4 years ago
Current commit: beb5eb30
5 files changed, 125 insertions(+), 19 deletions(-)
  1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (25 changes)
  2. ml-agents/mlagents/trainers/sac/trainer.py (6 changes)
  3. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (3 changes)
  4. ml-agents/mlagents/trainers/torch/networks.py (4 changes)
  5. ml-agents/mlagents/trainers/tests/torch/test_sac.py (106 changes)

ml-agents/mlagents/trainers/sac/optimizer_torch.py (25 changes)


from typing import Dict, List, Mapping, cast, Tuple, Optional
import torch
from torch import nn
import attr
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType

            for name in self.stream_names
        }

        # Critics should have 1/2 of the memory of the policy
        critic_memory = policy_network_settings.memory
        if critic_memory is not None:
            critic_memory = attr.evolve(
                critic_memory, memory_size=critic_memory.memory_size // 2
            )
        value_network_settings = attr.evolve(
            policy_network_settings, memory=critic_memory
        )

            value_network_settings,
            policy_network_settings,
            self.policy.behavior_spec.action_type,
            self.act_size,
        )

            self.policy.behavior_spec.observation_shapes,
            value_network_settings,
            policy_network_settings,
        )
        self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)

"Losses/Value Loss": value_loss.item(),
"Losses/Q1 Loss": q1_loss.item(),
"Losses/Q2 Loss": q2_loss.item(),
"Policy/Entropy Coeff": torch.exp(self._log_ent_coef).item(),
"Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
for signal in self.reward_signals.values():
signal.update(batch)
return update_stats

-        return {}
+        update_stats: Dict[str, float] = {}
+        for name, update_buffer in reward_signal_minibatches.items():
+            update_stats.update(self.reward_signals[name].update(update_buffer))
+        return update_stats

    def get_modules(self):
        modules = {
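For context on the "Policy/Entropy Coeff" change above: with discrete actions the Torch SAC optimizer keeps one entropy coefficient per action branch, so self._log_ent_coef can be a multi-element tensor; calling .item() directly on torch.exp(...) then raises a RuntimeError, while reducing with torch.mean first yields a single scalar for the stats dict. A minimal sketch of the difference (the tensor values below are made up for illustration, not taken from the trainer):

import torch

# Hypothetical per-branch log entropy coefficients, as used by discrete-action SAC.
log_ent_coef = torch.tensor([0.1, -0.2, 0.05])

# Old reporting: .item() only works on a one-element tensor, so this would raise
# "RuntimeError: only one element tensors can be converted to Python scalars".
# torch.exp(log_ent_coef).item()

# Fixed reporting: reduce to a scalar first, then convert for the stats dict.
entropy_coeff_stat = torch.mean(torch.exp(log_ent_coef)).item()
print(entropy_coeff_stat)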

ml-agents/mlagents/trainers/sac/trainer.py (6 changes)


                    self.hyperparameters.batch_size,
                    sequence_length=self.policy.sequence_length,
                )
            else:
                if name != "extrinsic":
                    reward_signal_minibatches[name] = buffer.sample_mini_batch(
                        self.hyperparameters.batch_size,
                        sequence_length=self.policy.sequence_length,
                    )
        update_stats = self.optimizer.update_reward_signals(
            reward_signal_minibatches, n_sequences
        )
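The else branch above skips the "extrinsic" signal when collecting reward-signal minibatches, since the extrinsic (environment) reward signal has nothing to train. A minimal sketch of that sampling pattern; the helper name and the buffers dict are illustrative, not actual trainer attributes:

# Illustrative only: `buffers` maps reward signal names to AgentBuffer-like objects
# exposing the same sample_mini_batch(batch_size, sequence_length=...) call used above.
def sample_reward_signal_minibatches(buffers, batch_size, sequence_length):
    # Every reward signal except "extrinsic" gets its own minibatch for its update.
    return {
        name: buffer.sample_mini_batch(batch_size, sequence_length=sequence_length)
        for name, buffer in buffers.items()
        if name != "extrinsic"
    }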

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (3 changes)


    summary_freq=100,
    max_steps=1000,
    threaded=False,
    framework=FrameworkType.PYTORCH,
)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_sac(use_discrete):
-    step_size = 0.5 if use_discrete else 0.2
+    step_size = 0.2 if use_discrete else 0.5
    env = MemoryEnvironment(
        [BRAIN_NAME], use_discrete=use_discrete, step_size=step_size
    )

ml-agents/mlagents/trainers/torch/networks.py (4 changes)


            mem_out = None
        return dists, value_outputs, mem_out

+    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
+        super().update_normalization(vector_obs)
+        self.critic.network_body.update_normalization(vector_obs)


class GlobalSteps(nn.Module):
    def __init__(self):
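The added update_normalization override corresponds to the "Properly update normalization for SAC-continuous" item in the commit message: with a separate critic network, updating only the actor's observation normalizer leaves the critic normalizing with stale statistics. A toy sketch of the pattern with made-up classes (not the actual ml-agents types):

from typing import List
import torch

class _ToyNormalizer:
    # Stand-in for a network body's running observation normalizer.
    def __init__(self, size: int):
        self.mean = torch.zeros(size)
        self.count = 0

    def update(self, obs: torch.Tensor) -> None:
        n = obs.shape[0]
        self.mean = (self.mean * self.count + obs.sum(dim=0)) / (self.count + n)
        self.count += n

class _ToySeparateActorCritic:
    # Hypothetical actor-critic where actor and critic each own a normalizer.
    def __init__(self, obs_size: int):
        self.actor_normalizer = _ToyNormalizer(obs_size)
        self.critic_normalizer = _ToyNormalizer(obs_size)

    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
        # The point of the fix: propagate the update to the critic as well,
        # so both networks normalize observations with the same statistics.
        for obs in vector_obs:
            self.actor_normalizer.update(obs)
            self.critic_normalizer.update(obs)

ac = _ToySeparateActorCritic(obs_size=8)
ac.update_normalization([torch.randn(5, 8)])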

ml-agents/mlagents/trainers/tests/torch/test_sac.py (106 changes)


import pytest
import copy
import torch

from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.torch.test_simple_rl import SAC_CONFIG
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests.test_reward_signals import (  # noqa: F401; pylint: disable=unused-variable
    curiosity_dummy_config,
)


@pytest.fixture
def dummy_config():
    return copy.deepcopy(SAC_CONFIG)


VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 12


def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual):
    mock_brain = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE if not use_visual else 0,
    )
    trainer_settings = dummy_config
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings(sequence_length=16, memory_size=12)
        if use_rnn
        else None
    )
    policy = TorchPolicy(0, mock_brain, trainer_settings, "test", False)
    optimizer = TorchSACOptimizer(policy, trainer_settings)
    return optimizer


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_sac_optimizer_update(dummy_config, rnn, visual, discrete):
    torch.manual_seed(0)
    # Test evaluate
    optimizer = create_sac_optimizer_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Test update
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec, memory_size=24
    )
    # Mock out reward signal eval
    update_buffer["extrinsic_rewards"] = update_buffer["environment_rewards"]
    return_stats = optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )
    # Make sure we have the right stats
    required_stats = [
        "Losses/Policy Loss",
        "Losses/Value Loss",
        "Losses/Q1 Loss",
        "Losses/Q2 Loss",
        "Policy/Entropy Coeff",
        "Policy/Learning Rate",
    ]
    for stat in required_stats:
        assert stat in return_stats.keys()


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_sac_update_reward_signals(
    dummy_config, curiosity_dummy_config, discrete  # noqa: F811
):
    # Add a Curiosity module
    dummy_config.reward_signals = curiosity_dummy_config
    optimizer = create_sac_optimizer_mock(
        dummy_config, use_rnn=False, use_discrete=discrete, use_visual=False
    )
    # Test update, while removing PPO-specific buffer elements.
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec
    )
    # Mock out reward signal eval
    update_buffer["extrinsic_rewards"] = update_buffer["environment_rewards"]
    update_buffer["curiosity_rewards"] = update_buffer["environment_rewards"]
    return_stats = optimizer.update_reward_signals(
        {"curiosity": update_buffer}, num_sequences=update_buffer.num_experiences
    )
    required_stats = ["Losses/Curiosity Forward Loss", "Losses/Curiosity Inverse Loss"]
    for stat in required_stats:
        assert stat in return_stats.keys()


if __name__ == "__main__":
    pytest.main()
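The three stacked parametrize decorators on test_sac_optimizer_update expand it into 2 × 2 × 2 = 8 cases (rnn/no_rnn × visual/vector × discrete/continuous). To run just this module, invoking pytest programmatically along these lines should work, assuming it is launched from the repository root with the test dependencies installed:

import pytest

# "-v" lists every parametrized case; the path matches the file shown above.
pytest.main(["-v", "ml-agents/mlagents/trainers/tests/torch/test_sac.py"])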