浏览代码

Merge pull request #1427 from Unity-Technologies/coldfix-action-masking

Cold fix for action masking
/develop-generalizationTraining-TrainerController
GitHub 6 年前
当前提交
5b68086c
共有 1 个文件被更改,包括 4 次插入4 次删除
  1. 8
      ml-agents/mlagents/trainers/models.py

8
ml-agents/mlagents/trainers/models.py


action_idx = [0] + list(np.cumsum(action_size))
branches_logits = [all_logits[:, action_idx[i]:action_idx[i + 1]] for i in range(len(action_size))]
branch_masks = [action_masks[:, action_idx[i]:action_idx[i + 1]] for i in range(len(action_size))]
raw_probs = [tf.multiply(tf.nn.softmax(branches_logits[k]), branch_masks[k]) + 1.0e-10
raw_probs = [tf.multiply(tf.nn.softmax(branches_logits[k]) + 1.0e-10, branch_masks[k])
tf.divide(raw_probs[k], tf.reduce_sum(raw_probs[k] + 1.0e-10, axis=1, keepdims=True))
for k in range(len(action_size))]
tf.divide(raw_probs[k], tf.reduce_sum(raw_probs[k], axis=1, keepdims=True))
for k in range(len(action_size))]
return output, tf.concat([tf.log(normalized_probs[k]) for k in range(len(action_size))], axis=1)
return output, tf.concat([tf.log(normalized_probs[k] + 1.0e-10) for k in range(len(action_size))], axis=1)
def create_observation_streams(self, num_streams, h_size, num_layers):
"""

正在加载...
取消
保存