         return normalized_state
 
     def update(self, vector_input: torch.Tensor) -> None:
-        steps_increment = vector_input.size()[0]
-        total_new_steps = self.normalization_steps + steps_increment
+        with torch.no_grad():
+            steps_increment = vector_input.size()[0]
+            total_new_steps = self.normalization_steps + steps_increment
 
-        input_to_old_mean = vector_input - self.running_mean
-        new_mean = self.running_mean + (input_to_old_mean / total_new_steps).sum(0)
+            input_to_old_mean = vector_input - self.running_mean
+            new_mean: torch.Tensor = self.running_mean + (
+                input_to_old_mean / total_new_steps
+            ).sum(0)
 
-        input_to_new_mean = vector_input - new_mean
-        new_variance = self.running_variance + (
-            input_to_new_mean * input_to_old_mean
-        ).sum(0)
-        # Update in-place
-        self.running_mean.data.copy_(new_mean.data)
-        self.running_variance.data.copy_(new_variance.data)
-        self.normalization_steps.data.copy_(total_new_steps.data)
+            input_to_new_mean = vector_input - new_mean
+            new_variance = self.running_variance + (
+                input_to_new_mean * input_to_old_mean
+            ).sum(0)
+            # Update references. This is much faster than in-place data update.
+            self.running_mean: torch.Tensor = new_mean
+            self.running_variance: torch.Tensor = new_variance
+            self.normalization_steps: torch.Tensor = total_new_steps
 
     def copy_from(self, other_normalizer: "Normalizer") -> None:
         self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)
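For context, the recurrence in update() is the batched form of Welford's running mean/variance update: the variance buffer accumulates a sum of squared deviations rather than the variance itself. The sketch below is illustrative only, not library code; it assumes zero-initialised statistics (the class's actual initial buffer values are not shown in this diff) and checks numerically that folding in one batch at a time reproduces the mean and population variance of all data seen so far.

import torch

def batched_update(mean, m2, steps, batch):
    # Same recurrence as Normalizer.update(): fold one batch into the
    # running statistics (mean, sum of squared deviations, step count).
    new_steps = steps + batch.size(0)
    to_old_mean = batch - mean                        # x - old mean
    new_mean = mean + (to_old_mean / new_steps).sum(0)
    to_new_mean = batch - new_mean                    # x - new mean
    new_m2 = m2 + (to_new_mean * to_old_mean).sum(0)
    return new_mean, new_m2, new_steps

torch.manual_seed(0)
obs_size = 4
mean = torch.zeros(obs_size)   # assumption: statistics start from zero here
m2 = torch.zeros(obs_size)
steps = torch.tensor(0)

batches = [torch.randn(32, obs_size) for _ in range(5)]
for b in batches:
    mean, m2, steps = batched_update(mean, m2, steps, b)

all_data = torch.cat(batches, dim=0)
assert torch.allclose(mean, all_data.mean(0), atol=1e-4)
assert torch.allclose(m2 / steps, all_data.var(0, unbiased=False), atol=1e-4)

Inside the class the same arithmetic runs under torch.no_grad(), and the updated tensors replace the old ones by reassignment rather than .data.copy_(), which the comment in the new code notes is faster than an in-place data update.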