
add SharedActorCritic

/develop/action-slice
Andrew Cohen, 4 years ago
Current commit
c74dca9f
4 files changed, 99 insertions(+), 26 deletions(-)
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (33 lines changed)
  2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (13 lines changed)
  3. ml-agents/mlagents/trainers/sac/optimizer_torch.py (14 lines changed)
  4. ml-agents/mlagents/trainers/torch/networks.py (65 lines changed)

ml-agents/mlagents/trainers/policy/torch_policy.py (33 lines changed)


 from mlagents_envs.timers import timed
 from mlagents.trainers.settings import TrainerSettings
-from mlagents.trainers.torch.networks import SimpleActor, GlobalSteps
+from mlagents.trainers.torch.networks import SimpleActor, SharedActorCritic, GlobalSteps
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.buffer import AgentBuffer

"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
self.actor = SimpleActor(
observation_specs=self.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=behavior_spec.action_spec,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
if separate_critic:
self.actor = SimpleActor(
observation_specs=self.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=behavior_spec.action_spec,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
self.shared_critic = False
else:
reward_signal_configs = trainer_settings.reward_signals
reward_signal_names = [
key.value for key, _ in reward_signal_configs.items()
]
self.actor = SharedActorCritic(
observation_specs=self.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=behavior_spec.action_spec,
stream_names=reward_signal_names,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
self.shared_critic = False
# Save the m_size needed for export
self._export_m_size = self.m_size
# m_size needed for training is determined by network, not trainer settings
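
With this change the policy holds either a plain SimpleActor (and leaves the critic to the optimizer) or a SharedActorCritic whose value heads sit on top of the actor's own encoder, with self.shared_critic telling downstream code which case it is in. The toy sketch below is not part of this commit and uses made-up class names in plain PyTorch; it only illustrates the structural difference between the two configurations.

# Illustrative sketch only: "separate critic" vs. "shared critic" in plain PyTorch.
# None of these classes come from ml-agents.
import torch
from torch import nn


class ToyActor(nn.Module):
    def __init__(self, obs_size: int, act_size: int, hidden: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU())
        self.policy_head = nn.Linear(hidden, act_size)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.policy_head(self.encoder(obs))


class ToySharedActorCritic(ToyActor):
    """Reuses the actor's encoder and adds a value head on the same encoding."""

    def __init__(self, obs_size: int, act_size: int, hidden: int = 64):
        super().__init__(obs_size, act_size, hidden)
        self.value_head = nn.Linear(hidden, 1)

    def critic_pass(self, obs: torch.Tensor) -> torch.Tensor:
        return self.value_head(self.encoder(obs))


separate_critic = False  # analogue of the flag introduced in the diff above
if separate_critic:
    actor = ToyActor(obs_size=8, act_size=2)
    shared_critic = False              # an optimizer would build its own value network
else:
    actor = ToySharedActorCritic(obs_size=8, act_size=2)
    shared_critic = True               # an optimizer can reuse `actor` as its critic

obs = torch.randn(4, 8)
print(actor(obs).shape)                # policy output: (4, 2)
if shared_critic:
    print(actor.critic_pass(obs).shape)  # value from the shared encoder: (4, 1)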

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (13 lines changed)


         reward_signal_configs = trainer_settings.reward_signals
         reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
-        self.value_net = ValueNetwork(
-            reward_signal_names,
-            policy.behavior_spec.observation_specs,
-            network_settings=trainer_settings.network_settings,
-        )
+        if policy.shared_critic:
+            self.value_net = policy.actor
+        else:
+            self.value_net = ValueNetwork(
+                reward_signal_names,
+                policy.behavior_spec.observation_specs,
+                network_settings=trainer_settings.network_settings,
+            )
         params = list(self.policy.actor.parameters()) + list(
             self.value_net.parameters()
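
Because both ValueNetwork and the shared actor expose critic_pass (see the Critic interface added in networks.py below), the PPO optimizer can query self.value_net the same way in either branch. The snippet below is a hedged sketch of that uniform usage, not ml-agents code: ToyCritic, the stream names, and the placeholder returns are all made up; the only assumption carried over from the diff is a critic_pass that returns a dict from reward-stream name to value tensor.

# Sketch (not from ml-agents): any object whose critic_pass returns
# {stream_name: value_tensor} can serve as the optimizer's value network,
# whether it is a standalone ValueNetwork or the shared actor itself.
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn


class ToyCritic(nn.Module):
    def __init__(self, obs_size: int, stream_names: List[str]):
        super().__init__()
        self.encoder = nn.Linear(obs_size, 32)
        self.heads = nn.ModuleDict({name: nn.Linear(32, 1) for name in stream_names})

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        encoding = torch.relu(self.encoder(torch.cat(inputs, dim=-1)))
        return {name: head(encoding).squeeze(-1) for name, head in self.heads.items()}, memories


streams = ["extrinsic", "curiosity"]      # in practice these come from reward_signals
critic = ToyCritic(obs_size=8, stream_names=streams)
obs = [torch.randn(16, 8)]
returns = {name: torch.randn(16) for name in streams}  # placeholder discounted returns

values, _ = critic.critic_pass(obs)
# One MSE term per reward stream, averaged: roughly how a per-stream value loss is assembled.
value_loss = torch.mean(
    torch.stack([torch.nn.functional.mse_loss(values[name], returns[name]) for name in streams])
)
print(float(value_loss))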

ml-agents/mlagents/trainers/sac/optimizer_torch.py (14 lines changed)


         super().__init__(policy, trainer_params)
         reward_signal_configs = trainer_params.reward_signals
         reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
-        self.value_network = ValueNetwork(
-            reward_signal_names,
-            policy.behavior_spec.observation_specs,
-            policy.network_settings,
-        )
+        if policy.shared_critic:
+            self.value_network = policy.actor
+        else:
+            self.value_network = ValueNetwork(
+                reward_signal_names,
+                policy.behavior_spec.observation_specs,
+                policy.network_settings,
+            )
         hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)
         self.tau = hyperparameters.tau
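
The tau read here is SAC's soft-update coefficient for the target network: target parameters are moved a fraction tau toward the current parameters after each update, i.e. theta_target <- tau * theta + (1 - tau) * theta_target. The sketch below shows that conventional Polyak update in isolation; it is a generic illustration, not the ml-agents implementation, and the two Linear layers merely stand in for the value and target networks.

# Generic soft-update sketch: target <- tau * source + (1 - tau) * target.
import torch
from torch import nn

source_net = nn.Linear(8, 4)
target_net = nn.Linear(8, 4)
target_net.load_state_dict(source_net.state_dict())  # targets usually start as a copy

tau = 0.005  # a typical SAC value; ml-agents reads it from hyperparameters.tau

with torch.no_grad():
    for target_param, source_param in zip(target_net.parameters(), source_net.parameters()):
        # Move each target parameter a small step toward its source counterpart.
        target_param.data.mul_(1.0 - tau).add_(tau * source_param.data)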

ml-agents/mlagents/trainers/torch/networks.py (65 lines changed)


         return encoding, memories
-class ValueNetwork(nn.Module):
+class Critic(abc.ABC):
+    @abc.abstractmethod
+    def update_normalization(self, buffer: AgentBuffer) -> None:
+        """
+        Updates normalization of the Critic based on the provided buffer of experiences.
+        :param buffer: AgentBuffer of experiences used to update observation normalization.
+        """
+        pass
+
+    @abc.abstractmethod
+    def critic_pass(
+        self,
+        inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+        """
+        Get value outputs for the given obs.
+        :param inputs: List of inputs as tensors.
+        :param memories: Tensor of memories, if using memory. Otherwise, None.
+        :returns: Dict of reward stream to output tensor for values.
+        """
+        pass
+
+
+class ValueNetwork(nn.Module, Critic):
     def __init__(
         self,
         stream_names: List[str],

         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+        """
+        Get value outputs for the given obs.
+        :param inputs: List of inputs as tensors.
+        :param memories: Tensor of memories, if using memory. Otherwise, None.
+        :returns: Dict of reward stream to output tensor for values.
+        """
         value_outputs, critic_mem_out = self.forward(
             inputs, memories=memories, sequence_length=sequence_length
         )

             self.act_size_vector_deprecated,
         ]
         return tuple(export_out)
+
+
+class SharedActorCritic(SimpleActor, Critic):
+    def __init__(
+        self,
+        observation_specs: List[ObservationSpec],
+        network_settings: NetworkSettings,
+        action_spec: ActionSpec,
+        stream_names: List[str],
+        conditional_sigma: bool = False,
+        tanh_squash: bool = False,
+    ):
+        self.use_lstm = network_settings.memory is not None
+        super().__init__(
+            observation_specs,
+            network_settings,
+            action_spec,
+            conditional_sigma,
+            tanh_squash,
+        )
+        self.stream_names = stream_names
+        self.value_heads = ValueHeads(stream_names, self.encoding_size)
+
+    def critic_pass(
+        self,
+        inputs: List[torch.Tensor],
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+        encoding, memories_out = self.network_body(
+            inputs, memories=memories, sequence_length=sequence_length
+        )
+        return self.value_heads(encoding), memories_out
 
 
 class GlobalSteps(nn.Module):
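
SharedActorCritic reuses SimpleActor's network_body and adds ValueHeads over the same encoding, so a single encoder serves both action selection and per-reward-stream value estimation, while the abc.abstractmethod decorators on Critic guarantee that anything used as a critic actually provides critic_pass. The sketch below illustrates both ideas with stand-in classes; it is not the code from mlagents.trainers.torch.networks, and the head and encoder shapes are arbitrary.

# Illustrative sketch: an abstract critic interface plus a ValueHeads-like module.
import abc
from typing import Dict, List
import torch
from torch import nn


class AbstractCritic(abc.ABC):
    @abc.abstractmethod
    def critic_pass(self, inputs: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return one value estimate per reward stream."""


class ToyValueHeads(nn.Module):
    """One linear head per reward stream, all reading the same encoding."""

    def __init__(self, stream_names: List[str], encoding_size: int):
        super().__init__()
        self.heads = nn.ModuleDict({name: nn.Linear(encoding_size, 1) for name in stream_names})

    def forward(self, encoding: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {name: head(encoding).squeeze(-1) for name, head in self.heads.items()}


class ToySharedCritic(nn.Module, AbstractCritic):
    def __init__(self, obs_size: int, stream_names: List[str]):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_size, 32), nn.ReLU())  # shared trunk
        self.value_heads = ToyValueHeads(stream_names, 32)

    def critic_pass(self, inputs: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        encoding = self.encoder(torch.cat(inputs, dim=-1))
        return self.value_heads(encoding)


critic = ToySharedCritic(obs_size=8, stream_names=["extrinsic", "gail"])
values = critic.critic_pass([torch.randn(5, 8)])
print({name: tuple(v.shape) for name, v in values.items()})  # {'extrinsic': (5,), 'gail': (5,)}
# A subclass of AbstractCritic that omits critic_pass cannot be instantiated (TypeError),
# which is the guarantee abc.abstractmethod gives the Critic interface in the diff above.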
