|
|
|
|
|
|
return normalized_state |
|
|
|
|
|
|
|
def update(self, vector_input): |
|
|
|
mean_current_observation = vector_input.mean(0).type(torch.float32) |
|
|
|
steps_increment = vector_input.size()[0] |
|
|
|
total_new_steps = self.normalization_steps + steps_increment |
|
|
|
|
|
|
|
input_to_old_mean = vector_input - self.running_mean |
|
|
|
mean_current_observation - self.running_mean |
|
|
|
) / (self.normalization_steps + 1).type(torch.float32) |
|
|
|
new_variance = self.running_variance + (mean_current_observation - new_mean) * ( |
|
|
|
mean_current_observation - self.running_mean |
|
|
|
) |
|
|
|
input_to_old_mean / total_new_steps.type(torch.float32) |
|
|
|
).sum(0) |
|
|
|
|
|
|
|
input_to_new_mean = vector_input - new_mean |
|
|
|
new_variance = self.running_variance + ( |
|
|
|
input_to_new_mean * input_to_old_mean |
|
|
|
).sum(0) |
|
|
|
self.normalization_steps = self.normalization_steps + 1 |
|
|
|
self.normalization_steps = total_new_steps |
|
|
|
|
|
|
|
|
|
|
|
class ValueHeads(nn.Module): |
|
|
|