
Initialize normalizer with mean/variance from first trajectory (#4299)

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 years ago
Current commit
9dc1d99e
5 files changed, with 147 additions and 14 deletions
  1. com.unity.ml-agents/CHANGELOG.md (3)
  2. ml-agents/mlagents/trainers/policy/tf_policy.py (16)
  3. ml-agents/mlagents/trainers/tests/mock_brain.py (3)
  4. ml-agents/mlagents/trainers/tests/test_nn_policy.py (113)
  5. ml-agents/mlagents/trainers/tf/models.py (26)

com.unity.ml-agents/CHANGELOG.md (3)


Previously, this would result in an infinite loop and cause the editor to hang.
(#4226)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The algorithm used to normalize observations could introduce NaNs if the initial observations were very large,
  because the running statistics were incorrectly initialized. The normalizer is now initialized with the
  observation mean and variance from the first trajectory processed (see the sketch below). (#4299)
## [1.2.0-preview] - 2020-07-15

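The fix described in the changelog entry above can be illustrated outside TensorFlow. The following is an illustrative NumPy sketch, not code from this commit; it assumes the normalizer divides the stored variance by the step count when it is applied, which is consistent with the assertions in test_nn_policy.py further down, and EPSILON here is a stand-in for the small constant used in that test.

import numpy as np

EPSILON = 1e-7  # small constant to avoid division by zero; mirrors the test constant


def init_from_first_trajectory(first_obs: np.ndarray):
    """Seed the running statistics from the first trajectory (the new behavior)."""
    steps = first_obs.shape[0]
    running_mean = first_obs.mean(axis=0)
    # The stored variance is scaled by the step count because it is divided
    # by the step count again when observations are normalized.
    running_variance = (first_obs.var(axis=0) + EPSILON) * steps
    return steps, running_mean, running_variance


def normalize(obs, steps, running_mean, running_variance):
    return (obs - running_mean) / np.sqrt(running_variance / steps)


# Observations around 1800, as in the Walker repro used by the test below. With
# the old zero-mean / unit-variance starting point the first normalized values
# were on the order of the raw observations, which destabilized training;
# seeding from the first trajectory keeps them near zero.
first_trajectory = np.array([[1800.0], [1799.9], [1800.1]], dtype=np.float32)
steps, mean, variance = init_from_first_trajectory(first_trajectory)
print(normalize(first_trajectory, steps, mean, variance))
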
ml-agents/mlagents/trainers/policy/tf_policy.py (16)


        self.assign_ops: List[tf.Operation] = []
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.first_normalization_update: bool = False
        self.graph = tf.Graph()
        self.sess = tf.Session(

        :param vector_obs: The vector observations to add to the running estimate of the distribution.
        """
        if self.use_vec_obs and self.normalize:
-             self.sess.run(
-                 self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
-             )
+             if self.first_normalization_update:
+                 self.sess.run(
+                     self.init_normalization_op, feed_dict={self.vector_in: vector_obs}
+                 )
+                 self.first_normalization_update = False
+             else:
+                 self.sess.run(
+                     self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
+                 )

    @property
    def use_vis_obs(self):

        self.normalization_steps: Optional[tf.Variable] = None
        self.running_mean: Optional[tf.Variable] = None
        self.running_variance: Optional[tf.Variable] = None
        self.init_normalization_op: Optional[tf.Operation] = None
        self.update_normalization_op: Optional[tf.Operation] = None
        self.value: Optional[tf.Tensor] = None
        self.all_log_probs: tf.Tensor = None

                self.behavior_spec.observation_shapes
            )
            if self.normalize:
                self.first_normalization_update = True
                self.init_normalization_op = normalization_tensors.init_op
                self.normalization_steps = normalization_tensors.steps
                self.running_mean = normalization_tensors.running_mean
                self.running_variance = normalization_tensors.running_variance

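The policy-side change above amounts to a one-shot gate: the first batch of observations initializes the running statistics, and every later batch updates them. Below is a minimal, framework-free sketch of that pattern; the merge formula in the else branch is a standard batched mean/variance combination used for illustration, not necessarily the exact update formula in models.py.

import numpy as np


class RunningNormalizer:
    """Running mean/variance tracker whose first update seeds the statistics."""

    def __init__(self, eps: float = 1e-7):
        self.first_update = True  # plays the role of TFPolicy.first_normalization_update
        self.steps = 0
        self.mean = None
        self.var = None
        self.eps = eps

    def update(self, obs: np.ndarray) -> None:
        if self.first_update:
            # First trajectory: initialize directly from the batch statistics.
            self.steps = len(obs)
            self.mean = obs.mean(axis=0)
            self.var = obs.var(axis=0) + self.eps
            self.first_update = False
        else:
            # Later trajectories: merge the batch statistics into the running ones
            # (standard parallel mean/variance merge; an assumption, not the repo formula).
            n, m = self.steps, len(obs)
            batch_mean, batch_var = obs.mean(axis=0), obs.var(axis=0)
            delta = batch_mean - self.mean
            total = n + m
            self.var = (n * self.var + m * batch_var + delta ** 2 * n * m / total) / total
            self.mean = self.mean + delta * m / total
            self.steps = total


# Six observations split between 0 and 1: 6 steps, mean 0.5, variance ~0.25,
# matching the expectations asserted in test_nn_policy.py below.
norm = RunningNormalizer()
norm.update(np.zeros((3, 1), dtype=np.float32))
norm.update(np.ones((3, 1), dtype=np.float32))
print(norm.steps, norm.mean, norm.var)
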
ml-agents/mlagents/trainers/tests/mock_brain.py (3)


            memory=memory,
        )
        steps_list.append(experience)
    obs = []
    for _shape in observation_shapes:
        obs.append(np.ones(_shape, dtype=np.float32))
    last_experience = AgentExperience(
        obs=obs,
        reward=reward,

ml-agents/mlagents/trainers/tests/test_nn_policy.py (113)


DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32
NUM_AGENTS = 12
EPSILON = 1e-7
def create_policy_mock(

    assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)


def test_large_normalization():
    behavior_spec = mb.setup_test_behavior_specs(
        use_discrete=True, use_visual=False, vector_action_space=[2], vector_obs_space=1
    )
    # Taken from Walker seed 3713 which causes NaN without proper initialization
    large_obs1 = [
        1800.00036621,
        1799.96972656,
        1800.01245117,
        1800.07214355,
        1800.02758789,
        1799.98303223,
        1799.88647461,
        1799.89575195,
        1800.03479004,
        1800.14025879,
        1800.17675781,
        1800.20581055,
        1800.33740234,
        1800.36450195,
        1800.43457031,
        1800.45544434,
        1800.44604492,
        1800.56713867,
        1800.73901367,
    ]
    large_obs2 = [
        1799.99975586,
        1799.96679688,
        1799.92980957,
        1799.89550781,
        1799.93774414,
        1799.95300293,
        1799.94067383,
        1799.92993164,
        1799.84057617,
        1799.69873047,
        1799.70605469,
        1799.82849121,
        1799.85095215,
        1799.76977539,
        1799.78283691,
        1799.76708984,
        1799.67163086,
        1799.59191895,
        1799.5135498,
        1799.45556641,
        1799.3717041,
    ]
    policy = TFPolicy(
        0,
        behavior_spec,
        TrainerSettings(network_settings=NetworkSettings(normalize=True)),
        "testdir",
        False,
    )
    time_horizon = len(large_obs1)
    trajectory = make_fake_trajectory(
        length=time_horizon,
        max_step_complete=True,
        observation_shapes=[(1,)],
        action_space=[2],
    )
    for i in range(time_horizon):
        trajectory.steps[i].obs[0] = np.array([large_obs1[i]], dtype=np.float32)
    trajectory_buffer = trajectory.to_agentbuffer()
    policy.update_normalization(trajectory_buffer["vector_obs"])

    # Check that the running mean and variance are correct
    steps, mean, variance = policy.sess.run(
        [policy.normalization_steps, policy.running_mean, policy.running_variance]
    )
    assert mean[0] == pytest.approx(np.mean(large_obs1, dtype=np.float32), abs=0.01)
    assert variance[0] / steps == pytest.approx(
        np.var(large_obs1, dtype=np.float32), abs=0.01
    )

    time_horizon = len(large_obs2)
    trajectory = make_fake_trajectory(
        length=time_horizon,
        max_step_complete=True,
        observation_shapes=[(1,)],
        action_space=[2],
    )
    for i in range(time_horizon):
        trajectory.steps[i].obs[0] = np.array([large_obs2[i]], dtype=np.float32)
    trajectory_buffer = trajectory.to_agentbuffer()
    policy.update_normalization(trajectory_buffer["vector_obs"])
    steps, mean, variance = policy.sess.run(
        [policy.normalization_steps, policy.running_mean, policy.running_variance]
    )
    assert mean[0] == pytest.approx(
        np.mean(large_obs1 + large_obs2, dtype=np.float32), abs=0.01
    )
    assert variance[0] / steps == pytest.approx(
        np.var(large_obs1 + large_obs2, dtype=np.float32), abs=0.01
    )

    time_horizon = 6
    trajectory = make_fake_trajectory(
        length=time_horizon,

    assert steps == 6
    assert mean[0] == 0.5
-     # Note: variance is divided by number of steps, and initialized to 1 to avoid
-     # divide by 0. The right answer is 0.25
-     assert (variance[0] - 1) / steps == 0.25
+     # Note: variance is initialized to the variance of the initial trajectory + EPSILON
+     # (to avoid divide by 0) and multiplied by the number of steps. The correct answer is 0.25
+     assert variance[0] / steps == pytest.approx(0.25, abs=0.01)
    # Make another update, this time with all 1's
    time_horizon = 10
    trajectory = make_fake_trajectory(

    assert steps == 16
    assert mean[0] == 0.8125
-     assert (variance[0] - 1) / steps == pytest.approx(0.152, abs=0.01)
+     assert variance[0] / steps == pytest.approx(0.152, abs=0.01)


def test_min_visual_size():

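The new assertions rely on the variance being stored pre-multiplied by the step count, so variance / steps recovers the plain variance. A short sketch of the arithmetic for a trajectory with mean 0.5 and variance 0.25 (the three-zeros / three-ones split is an assumption consistent with the asserted values; EPSILON plays the same role as the constant in the test module):

import numpy as np

EPSILON = 1e-7

obs = np.array([0, 0, 0, 1, 1, 1], dtype=np.float32)  # mean 0.5, variance 0.25
steps = len(obs)

# New scheme: variance is stored as (variance + EPSILON) * steps,
# so dividing by the step count recovers it (what `variance[0] / steps` checks).
stored_variance = (obs.var() + EPSILON) * steps        # ~1.5
print(stored_variance / steps)                         # ~0.25

# Old scheme, per the removed comment: the variance started at 1, so the old
# assertion subtracted it first: (variance[0] - 1) / steps == 0.25.
old_stored_variance = 1.0 + obs.var() * steps          # ~2.5
print((old_stored_variance - 1.0) / steps)             # ~0.25
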
ml-agents/mlagents/trainers/tf/models.py (26)


class NormalizerTensors(NamedTuple):
    init_op: tf.Operation
    update_op: tf.Operation
    steps: tf.Tensor
    running_mean: tf.Tensor

        :return: A NormalizerTensors tuple that holds running mean, running variance, number of steps,
        and the update operation.
        """
        steps = tf.get_variable(
            "normalization_steps",
            [],

            dtype=tf.float32,
            initializer=tf.ones_initializer(),
        )
-         update_normalization = ModelUtils.create_normalizer_update(
+         initialize_normalization, update_normalization = ModelUtils.create_normalizer_update(
-             update_normalization, steps, running_mean, running_variance
+             initialize_normalization,
+             update_normalization,
+             steps,
+             running_mean,
+             running_variance,
        )

    @staticmethod

        running_mean: tf.Tensor,
        running_variance: tf.Tensor,
-     ) -> tf.Operation:
+     ) -> Tuple[tf.Operation, tf.Operation]:
        """
        Creates the update operation for the normalizer.
        :param vector_input: Vector observation to use for updating the running mean and variance.

        update_mean = tf.assign(running_mean, new_mean)
        update_variance = tf.assign(running_variance, new_variance)
        update_norm_step = tf.assign(steps, total_new_steps)
-         return tf.group([update_mean, update_variance, update_norm_step])
+         # First mean and variance calculated normally
+         initial_mean, initial_variance = tf.nn.moments(vector_input, axes=[0])
+         initialize_mean = tf.assign(running_mean, initial_mean)
+         # Multiplied by total_new_step because it is divided by total_new_step in the normalization
+         initialize_variance = tf.assign(
+             running_variance,
+             (initial_variance + EPSILON) * tf.cast(total_new_steps, dtype=tf.float32),
+         )
+         return (
+             tf.group([initialize_mean, initialize_variance, update_norm_step]),
+             tf.group([update_mean, update_variance, update_norm_step]),
+         )

    @staticmethod
    def create_vector_observation_encoder(
