
make critic a property

Branch: develop/action-slice
Andrew Cohen, 4 years ago
Commit 8efdeeb0
3 files changed: 19 insertions, 12 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (13 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (12 changes)
  3. ml-agents/mlagents/trainers/sac/optimizer_torch.py (6 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (13 changes)


 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
-from mlagents.trainers.torch.networks import ValueNetwork
 from mlagents.trainers.torch.components.bc.module import BCModule
 from mlagents.trainers.torch.components.reward_providers import create_reward_provider

 class TorchOptimizer(Optimizer):
-    def __init__(
-        self,
-        policy: TorchPolicy,
-        critic: ValueNetwork,
-        trainer_settings: TrainerSettings,
-    ):
+    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
-        self.critic = critic
         self.trainer_settings = trainer_settings
         self.update_dict: Dict[str, torch.Tensor] = {}
         self.value_heads: Dict[str, torch.Tensor] = {}

             default_batch_size=trainer_settings.hyperparameters.batch_size,
             default_num_epoch=3,
         )

+    @property
+    def critic(self):
+        raise NotImplementedError
+
     def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         pass
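
Read with the +/- markers restored above, this file stops accepting a critic through the constructor and instead exposes critic as an abstract property that every concrete optimizer must override. A minimal, standalone sketch of that pattern (stand-in classes with simplified signatures, not the actual ml-agents types):

    class ValueNetwork:
        """Stand-in for mlagents.trainers.torch.networks.ValueNetwork."""


    class TorchOptimizer:
        """Simplified stand-in for the base optimizer after this commit."""

        def __init__(self, policy, trainer_settings):
            self.policy = policy
            self.trainer_settings = trainer_settings

        @property
        def critic(self):
            # Concrete optimizers (PPO, SAC) must return their own value network here.
            raise NotImplementedError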

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (12 changes)


"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
super().__init__(policy, trainer_settings)
critic = ValueNetwork(
self.value_net = ValueNetwork(
super().__init__(policy, critic, trainer_settings)
params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
params = list(self.policy.actor.parameters()) + list(
self.value_net.parameters()
)
self.hyperparameters: PPOSettings = cast(
PPOSettings, trainer_settings.hyperparameters
)

}
self.stream_names = list(self.reward_signals.keys())
@property
def critic(self):
return self.value_net
def ppo_value_loss(
self,
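
On the PPO side the optimizer now builds and owns the value network itself: super().__init__ is called without a critic, self.value_net is created inside __init__, the parameter list handed to the gradient optimizer is gathered from self.policy.actor.parameters() plus self.value_net.parameters() (actor and critic are separate modules now), and the new critic property returns self.value_net. Continuing the stand-in sketch from above, the subclass side looks roughly like this (parameter gathering omitted to keep the stand-ins trivial):

    class TorchPPOOptimizer(TorchOptimizer):
        """Simplified mirror of the PPO hunk above."""

        def __init__(self, policy, trainer_settings):
            # The critic is no longer passed to the base class...
            super().__init__(policy, trainer_settings)
            # ...the subclass creates and owns it instead.
            self.value_net = ValueNetwork()

        @property
        def critic(self):
            # Satisfies the base-class property with the PPO value network.
            return self.value_net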

ml-agents/mlagents/trainers/sac/optimizer_torch.py (6 changes)


         self.continuous = continuous

     def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
+        super().__init__(policy, trainer_params)
         reward_signal_configs = trainer_params.reward_signals
         reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

             policy.network_settings,
         )
-        super().__init__(policy, self.value_network, trainer_params)
         hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)
         self.tau = hyperparameters.tau
         self.init_entcoef = hyperparameters.init_entcoef

             self._log_ent_coef.parameters(), lr=hyperparameters.learning_rate
         )
         self._move_to_device(default_device())
+
+    @property
+    def critic(self):
+        return self.value_network
+
     def _move_to_device(self, device: torch.device) -> None:
         self._log_ent_coef.to(device)
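
The SAC hunk makes the same move under a different attribute name: super().__init__ is hoisted to the top of __init__ since the value network no longer has to exist before it, self.value_network is still built where it was, and the new critic property returns it. Extending the same stand-in sketch, the payoff is that shared trainer code can read opt.critic without knowing which optimizer, or which attribute name, it is dealing with:

    class TorchSACOptimizer(TorchOptimizer):
        """Simplified mirror of the SAC hunk above; note the different attribute name."""

        def __init__(self, policy, trainer_params):
            super().__init__(policy, trainer_params)  # now called before the networks are built
            self.value_network = ValueNetwork()

        @property
        def critic(self):
            return self.value_network


    # Shared code only touches the property, never the concrete attribute:
    for opt in (TorchPPOOptimizer(None, None), TorchSACOptimizer(None, None)):
        print(type(opt).__name__, "->", type(opt.critic).__name__)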
