浏览代码

Saving the reward providers

/develop/torch-save-rp
vincentpierre 4 年前
当前提交
9f51ab14
共有 5 个文件被更改,包括 23 次插入2 次删除
  1. 5
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  2. 5
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  3. 9
      ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py
  4. 3
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
  5. 3
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

5
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


return update_stats
def get_modules(self):
return {"Optimizer": self.optimizer}
modules = {"Optimizer": self.optimizer}
for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
return modules

5
ml-agents/mlagents/trainers/sac/optimizer_torch.py


return {}
def get_modules(self):
return {
modules = {
"Optimizer:value_network": self.value_network,
"Optimizer:target_network": self.target_network,
"Optimizer:policy_optimizer": self.policy_optimizer,

for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
return modules

9
ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py


import numpy as np
import torch
from abc import ABC, abstractmethod
from typing import Dict

raise NotImplementedError(
"The reward provider's update method has not been implemented "
)
def get_modules(self) -> Dict[str, torch.nn.Module]:
"""
Returns a dictionary of string identifiers to the torch.nn.Modules used by
the reward providers. This method is used for loading and saving the weights
of the reward providers.
"""
return {}

3
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


"Losses/Curiosity Inverse Loss": inverse_loss.detach().cpu().numpy(),
}
def get_modules(self):
return {f"Optimizer:{self.name}": self._network}
class CuriosityNetwork(torch.nn.Module):
EPSILON = 1e-10

3
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


self.optimizer.step()
return stats_dict
def get_modules(self):
return {f"Optimizer:{self.name}": self._discriminator_network}
class DiscriminatorNetwork(torch.nn.Module):
gradient_penalty_weight = 10.0

正在加载...
取消
保存