浏览代码

use int64 steps, check for NaN actions (#4607)

* use int64 steps

* check for NaN actions

Co-authored-by: Ruo-Ping Dong <ruoping.dong@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
87a7ccf8
共有 9 个文件被更改,包括 52 次插入5 次删除
  1. 2
      com.unity.ml-agents/CHANGELOG.md
  2. 10
      ml-agents/mlagents/trainers/policy/policy.py
  3. 2
      ml-agents/mlagents/trainers/policy/tf_policy.py
  4. 1
      ml-agents/mlagents/trainers/policy/torch_policy.py
  5. 20
      ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py
  6. 8
      ml-agents/mlagents/trainers/tests/torch/test_policy.py
  7. 8
      ml-agents/mlagents/trainers/tf/models.py
  8. 2
      ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py
  9. 4
      ml-agents/mlagents/trainers/torch/networks.py

2
com.unity.ml-agents/CHANGELOG.md


Previously, this would result in an infinite loop and cause the editor to hang. (#4573)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Fixed an issue where runs could not be resumed when using TensorFlow and Ghost Training. (#4593)
- Change the tensor type of step count from int32 to int64 to address the overflow issue when step
goes larger than 2^31. Previous Tensorflow checkpoints will become incompatible and cannot be loaded. (#4607)
## [1.5.0-preview] - 2020-10-14

10
ml-agents/mlagents/trainers/policy/policy.py


) -> ActionInfo:
raise NotImplementedError
@staticmethod
def check_nan_action(action: Optional[np.ndarray]) -> None:
# Fast NaN check on the action
# See https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy for background.
if action is not None:
d = np.sum(action)
has_nan = np.isnan(d)
if has_nan:
raise RuntimeError("NaN action detected.")
@abstractmethod
def update_normalization(self, vector_obs: np.ndarray) -> None:
pass

2
ml-agents/mlagents/trainers/policy/tf_policy.py


)
self.save_memories(global_agent_ids, run_out.get("memory_out"))
self.check_nan_action(run_out.get("action"))
return ActionInfo(
action=run_out.get("action"),
value=run_out.get("value"),

1
ml-agents/mlagents/trainers/policy/torch_policy.py


decision_requests, global_agent_ids
) # pylint: disable=assignment-from-no-return
self.save_memories(global_agent_ids, run_out.get("memory_out"))
self.check_nan_action(run_out.get("action"))
return ActionInfo(
action=run_out.get("action"),
value=run_out.get("value"),

20
ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py


enc_func(vis_input, 32, ModelUtils.swish, 1, "test", False)
def test_step_overflow():
behavior_spec = mb.setup_test_behavior_specs(
use_discrete=True, use_visual=False, vector_action_space=[2], vector_obs_space=1
)
policy = TFPolicy(
0,
behavior_spec,
TrainerSettings(network_settings=NetworkSettings(normalize=True)),
create_tf_graph=False,
)
policy.create_input_placeholders()
policy.initialize()
policy.set_step(2 ** 31 - 1)
assert policy.get_current_step() == 2 ** 31 - 1
policy.increment_step(3)
assert policy.get_current_step() == 2 ** 31 + 2
if __name__ == "__main__":
pytest.main()

8
ml-agents/mlagents/trainers/tests/torch/test_policy.py


if rnn:
assert memories.shape == (1, 1, policy.m_size)
def test_step_overflow():
policy = create_policy_mock(TrainerSettings())
policy.set_step(2 ** 31 - 1)
assert policy.get_current_step() == 2 ** 31 - 1 # step = 2147483647
policy.increment_step(3)
assert policy.get_current_step() == 2 ** 31 + 2 # step = 2147483650

8
ml-agents/mlagents/trainers/tf/models.py


def create_global_steps():
"""Creates TF ops to track and increment global training step."""
global_step = tf.Variable(
0, name="global_step", trainable=False, dtype=tf.int32
0, name="global_step", trainable=False, dtype=tf.int64
shape=[], dtype=tf.int32, name="steps_to_increment"
shape=[], dtype=tf.int64, name="steps_to_increment"
)
increment_step = tf.assign(global_step, tf.add(global_step, steps_to_increment))
return global_step, increment_step, steps_to_increment

"normalization_steps",
[],
trainable=False,
dtype=tf.int32,
dtype=tf.int64,
initializer=tf.zeros_initializer(),
)
running_mean = tf.get_variable(

# Based on Welford's algorithm for running mean and standard deviation, for batch updates. Discussion here:
# https://stackoverflow.com/questions/56402955/whats-the-formula-for-welfords-algorithm-for-variance-std-with-batch-updates
steps_increment = tf.shape(vector_input)[0]
total_new_steps = tf.add(steps, steps_increment)
total_new_steps = tf.add(steps, tf.cast(steps_increment, dtype=tf.int64))
# Compute the incremental update and divide by the number of new steps.
input_to_old_mean = tf.subtract(vector_input, running_mean)

2
ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py


data = tensor.float_val
if tensor.int_val:
data = np.array(tensor.int_val, dtype=float)
if tensor.int64_val:
data = np.array(tensor.int64_val, dtype=float)
if tensor.bool_val:
data = np.array(tensor.bool_val, dtype=float)
return np.array(data).reshape(dims)

4
ml-agents/mlagents/trainers/torch/networks.py


class GlobalSteps(nn.Module):
def __init__(self):
super().__init__()
self.__global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)
self.__global_step = nn.Parameter(
torch.Tensor([0]).to(torch.int64), requires_grad=False
)
@property
def current_step(self):

正在加载...
取消
保存