浏览代码

Fix TF Nan bug (#1178)

* Fix for TF NaNs

* New soccer model
/develop-generalizationTraining-TrainerController
GitHub 6 年前
当前提交
ab6eb8dc
共有 4 个文件被更改,包括 1022 次插入9 次删除
  1. 18
      config/trainer_config.yaml
  2. 5
      ml-agents/mlagents/trainers/models.py
  3. 1001
      UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.bytes
  4. 7
      UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.bytes.meta

18
config/trainer_config.yaml


normalize: false
StrikerBrain:
max_steps: 1.0e5
max_steps: 5.0e5
learning_rate: 1e-3
buffer_size: 2048
beta: 5.0e-3
num_epoch: 3
buffer_size: 2000
beta: 1.0e-2
hidden_units: 256
summary_freq: 2000
time_horizon: 128

GoalieBrain:
max_steps: 1.0e5
batch_size: 128
buffer_size: 2048
beta: 5.0e-3
max_steps: 5.0e5
learning_rate: 1e-3
batch_size: 320
num_epoch: 3
buffer_size: 2000
beta: 1.0e-2
hidden_units: 256
summary_freq: 2000
time_horizon: 128

5
ml-agents/mlagents/trainers/models.py


action_idx = [0] + list(np.cumsum(action_size))
branches_logits = [all_logits[:, action_idx[i]:action_idx[i + 1]] for i in range(len(action_size))]
branch_masks = [action_masks[:, action_idx[i]:action_idx[i + 1]] for i in range(len(action_size))]
raw_probs = [tf.multiply(tf.nn.softmax(branches_logits[k]), branch_masks[k]) + (1-branch_masks[k])*1.0e-10
raw_probs = [tf.multiply(tf.nn.softmax(branches_logits[k]), branch_masks[k]) + 1.0e-10
normalized_probs = [tf.divide(raw_probs[k], tf.reduce_sum(raw_probs[k], axis=1, keepdims=True))
normalized_probs = [
tf.divide(raw_probs[k], tf.reduce_sum(raw_probs[k] + 1.0e-10, axis=1, keepdims=True))
for k in range(len(action_size))]
output = tf.concat([tf.multinomial(tf.log(normalized_probs[k]), 1) for k in range(len(action_size))], axis=1)
return output, tf.concat([tf.log(normalized_probs[k]) for k in range(len(action_size))], axis=1)

1001
UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.bytes
文件差异内容过多而无法显示
查看文件

7
UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.bytes.meta


fileFormatVersion: 2
guid: 4856a334c6d4a4984ba1cc6610f31b20
TextScriptImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存