浏览代码

Fix PyTorch checkpointing: register the Normalizer's running-statistics tensors as parameters so they are saved in the state_dict.

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
b98eb5f3
共有 1 个文件被更改,包括 7 次插入6 次删除
  1. 13
      ml-agents/mlagents/trainers/torch/encoders.py

13
ml-agents/mlagents/trainers/torch/encoders.py


class Normalizer(nn.Module):
    """Tracks a running mean and variance of vector observations.

    The running statistics are stored as non-trainable ``nn.Parameter``s
    (rather than plain tensors) so that they are included in the module's
    ``state_dict`` and therefore survive PyTorch checkpoint save/load.
    """

    def __init__(self, vec_obs_size: int):
        """
        :param vec_obs_size: Dimensionality of the vector observation;
            sets the size of the running mean/variance buffers.
        """
        super().__init__()
        # requires_grad=False: these values are updated manually from
        # observation statistics, never by the optimizer.
        # NOTE(review): nn.Parameter(requires_grad=False) checkpoints
        # correctly, but register_buffer would express the same intent
        # more directly — worth confirming/migrating.
        self.normalization_steps = nn.Parameter(
            torch.tensor(1), requires_grad=False
        )
        self.running_mean = nn.Parameter(
            torch.zeros(vec_obs_size), requires_grad=False
        )
        self.running_variance = nn.Parameter(
            torch.ones(vec_obs_size), requires_grad=False
        )
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
normalized_state = torch.clamp(

new_variance = self.running_variance + (
input_to_new_mean * input_to_old_mean
).sum(0)
self.running_mean = new_mean
self.running_variance = new_variance
self.normalization_steps = total_new_steps
self.running_mean.data = new_mean.data
self.running_variance.data = new_variance.data
self.normalization_steps.data = total_new_steps.data
def copy_from(self, other_normalizer: "Normalizer") -> None:
self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)

self.normalizer: Optional[Normalizer] = None
super().__init__()
self.layers = [nn.Linear(input_size, hidden_size)]
print('-'*10, normalize, '-'*10)
if normalize:
self.normalizer = Normalizer(input_size)

正在加载...
取消
保存