浏览代码

Fix normalization

/develop/add-fire
Ervin Teng 5 年前
当前提交
21a8de45
共有 1 个文件被更改,包括 12 次插入7 次删除
  1. 19
      ml-agents/mlagents/trainers/models_torch.py

19
ml-agents/mlagents/trainers/models_torch.py


return normalized_state
def update(self, vector_input):
mean_current_observation = vector_input.mean(0).type(torch.float32)
steps_increment = vector_input.size()[0]
total_new_steps = self.normalization_steps + steps_increment
input_to_old_mean = vector_input - self.running_mean
mean_current_observation - self.running_mean
) / (self.normalization_steps + 1).type(torch.float32)
new_variance = self.running_variance + (mean_current_observation - new_mean) * (
mean_current_observation - self.running_mean
)
input_to_old_mean / total_new_steps.type(torch.float32)
).sum(0)
input_to_new_mean = vector_input - new_mean
new_variance = self.running_variance + (
input_to_new_mean * input_to_old_mean
).sum(0)
self.normalization_steps = self.normalization_steps + 1
self.normalization_steps = total_new_steps
class ValueHeads(nn.Module):

正在加载...
取消
保存