
Add normalizer update context

/develop/add-fire/normalize-context
Ervin Teng, 4 years ago
Current commit
bd97532d
2 files changed, 44 insertions and 6 deletions
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (15 changes)
  2. ml-agents/mlagents/trainers/torch/encoders.py (35 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (15 changes)


 from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
 from mlagents.trainers.trajectory import SplitObservations
 from mlagents.trainers.torch.networks import ActorCritic
+from mlagents.trainers.torch.encoders import Normalizer
 EPSILON = 1e-7  # Small value to avoid divide by zero

         If this policy normalizes vector observations, this will update the norm values in the graph.
         :param vector_obs: The vector observations to add to the running estimate of the distribution.
         """
-        vector_obs = [torch.as_tensor(vector_obs)]
-        if self.use_vec_obs and self.normalize:
-            self.actor_critic.update_normalization(vector_obs)
+        pass
+        # vector_obs = [torch.as_tensor(vector_obs)]
+        # if self.use_vec_obs and self.normalize:
+        #     self.actor_critic.update_normalization(vector_obs)

     @timed
     def sample_actions(

         run_out = {}
         with torch.no_grad():
-            action, log_probs, entropy, value_heads, memories = self.sample_actions(
-                vec_obs, vis_obs, masks=masks, memories=memories
-            )
+            with Normalizer.update_normalizer():
+                action, log_probs, entropy, value_heads, memories = self.sample_actions(
+                    vec_obs, vis_obs, masks=masks, memories=memories
+                )
         run_out["action"] = action.detach().cpu().numpy()
         run_out["pre_action"] = action.detach().cpu().numpy()
         # Todo - make pre_action difference
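For orientation (not part of the diff): a minimal, self-contained sketch of the call-site pattern this hunk moves to. Instead of pushing observations into the normalizer explicitly via `update_normalization()`, any normalizer forward pass executed inside the `with Normalizer.update_normalizer():` block updates its own running statistics. The `ToyNormalizer` below is an illustrative stand-in for the real class in encoders.py, not code from this commit.

```python
import threading

import torch
from torch import nn


class ToyNormalizer(nn.Module):
    """Illustrative stand-in for the Normalizer defined in encoders.py."""

    class update_normalizer:
        # Thread-local flag: only forward passes run inside this context,
        # on this thread, update the running statistics.
        _local = threading.local()

        def __enter__(self):
            ToyNormalizer.update_normalizer._local.must_update = True

        def __exit__(self, *exc):
            ToyNormalizer.update_normalizer._local.must_update = False

        @staticmethod
        def must_update() -> bool:
            return getattr(ToyNormalizer.update_normalizer._local, "must_update", False)

    def __init__(self, size: int):
        super().__init__()
        self.count = torch.tensor(0)
        self.running_mean = torch.zeros(size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if ToyNormalizer.update_normalizer.must_update():
            # Incremental mean update; the real class also tracks variance.
            self.count = self.count + x.shape[0]
            self.running_mean = self.running_mean + (x - self.running_mean).sum(0) / self.count
        return x - self.running_mean


norm = ToyNormalizer(3)
obs = torch.randn(8, 3)

norm(obs)                                 # outside the context: statistics untouched
with ToyNormalizer.update_normalizer():   # mirrors the evaluate() hunk above
    norm(obs)                             # same call, but statistics are updated
print(int(norm.count))                    # 8
```

Because the flag lives in `threading.local()`, a forward pass running on another thread is unaffected by this thread entering the context.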

ml-agents/mlagents/trainers/torch/encoders.py (35 changes)


 from typing import Tuple, Optional
+import threading
 from mlagents.trainers.exception import UnityTrainerException

 class Normalizer(nn.Module):
+    class update_normalizer:
+        """
+        Helper class that allows context. Set this context by calling
+        ```
+        with Normalizer.update_normalizer()
+        ```
+        All Normalizers executed with this context will also be updated.
+        """
+
+        _local_data = threading.local()
+        _local_data.must_update = False
+
+        @classmethod
+        def __enter__(cls):
+            Normalizer.update_normalizer._local_data.must_update = True
+
+        @classmethod
+        def __exit__(cls, *args):
+            Normalizer.update_normalizer._local_data.must_update = False
+
+        @staticmethod
+        def must_update():
+            try:
+                return Normalizer.update_normalizer._local_data.must_update
+            except AttributeError:
+                return False
+
     def __init__(self, vec_obs_size: int):
         super().__init__()
         self.normalization_steps = torch.tensor(1)

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        if Normalizer.update_normalizer.must_update():
+            self.update(inputs)
         normalized_state = torch.clamp(
             (inputs - self.running_mean)
             / torch.sqrt(self.running_variance / self.normalization_steps),
             -5,
             5,
         )
         return normalized_state
     def update(self, vector_input: torch.Tensor) -> None:
         """
         Updates the normalizer based on the input.
         Note: this will be made a private method in the future.
         """
         steps_increment = vector_input.size()[0]
         total_new_steps = self.normalization_steps + steps_increment
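The `update()` hunk is cut off above after `total_new_steps`. For reference, a hedged sketch of the incremental (Chan/Welford-style) batch update that a running normalizer tracked this way typically performs; the remainder of the method is not shown in this diff, so treat the function below as an assumed continuation rather than the file's actual code.

```python
import torch


def running_stats_update(
    vector_input: torch.Tensor,
    running_mean: torch.Tensor,
    running_variance: torch.Tensor,      # accumulated sum of squared deviations
    normalization_steps: torch.Tensor,   # number of observations seen so far
):
    """Assumed Chan/Welford-style continuation of Normalizer.update()."""
    steps_increment = vector_input.size()[0]
    total_new_steps = normalization_steps + steps_increment

    # Shift the mean toward the new batch, weighted by the total step count.
    input_to_old_mean = vector_input - running_mean
    new_mean = running_mean + (input_to_old_mean / total_new_steps).sum(0)

    # Accumulate squared deviations measured against the old and new means.
    input_to_new_mean = vector_input - new_mean
    new_variance = running_variance + (input_to_new_mean * input_to_old_mean).sum(0)

    return new_mean, new_variance, total_new_steps
```

Tracking `running_variance` as a sum of squared deviations is consistent with the `forward()` shown above, which recovers a standard deviation as `sqrt(running_variance / normalization_steps)` before normalizing.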
