浏览代码

Fix the Python Tests (#1327)

/develop-generalizationTraining-TrainerController
GitHub 6 年前
当前提交
2b6b4570
共有 1 个文件被更改，包括 14 次插入和 9 次删除
  1. 23
      ml-agents/tests/trainers/test_ppo.py

23
ml-agents/tests/trainers/test_ppo.py


feed_dict = {model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
- [3, 4, 5, 3, 4, 5]])}
+ [3, 4, 5, 3, 4, 5]],),
+ model.epsilon: np.array([[0, 1], [2, 3]])}
sess.run(run_list, feed_dict=feed_dict)
env.close()

model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
[3, 4, 5, 3, 4, 5]]),
model.visual_in[0]: np.ones([2, 40, 30, 3]),
- model.visual_in[1]: np.ones([2, 40, 30, 3])}
+ model.visual_in[1]: np.ones([2, 40, 30, 3]),
+ model.epsilon: np.array([[0, 1], [2, 3]])}
sess.run(run_list, feed_dict=feed_dict)
env.close()

[3, 4, 5, 3, 4, 5]]),
model.visual_in[0]: np.ones([2, 40, 30, 3]),
model.visual_in[1]: np.ones([2, 40, 30, 3]),
- model.action_masks: np.ones([2,2])
+ model.action_masks: np.ones([2, 2],)
}
sess.run(run_list, feed_dict=feed_dict)
env.close()

model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
[3, 4, 5, 3, 4, 5]]),
- model.action_masks: np.ones([2,2])}
+ model.action_masks: np.ones([2, 2])}
sess.run(run_list, feed_dict=feed_dict)
env.close()

model.memory_in: np.zeros((1, memory_size)),
model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
[3, 4, 5, 3, 4, 5]]),
- model.action_masks: np.ones([1,2])}
+ model.action_masks: np.ones([1, 2])}
sess.run(run_list, feed_dict=feed_dict)
env.close()

model.sequence_length: 2,
model.memory_in: np.zeros((1, memory_size)),
model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
- [3, 4, 5, 3, 4, 5]])}
+ [3, 4, 5, 3, 4, 5]]),
+ model.epsilon: np.array([[0, 1]])}
sess.run(run_list, feed_dict=feed_dict)
env.close()

[3, 4, 5, 3, 4, 5]]),
model.next_vector_in: np.array([[1, 2, 3, 1, 2, 3],
[3, 4, 5, 3, 4, 5]]),
- model.output: [[0.0, 0.0], [0.0, 0.0]]}
+ model.output: [[0.0, 0.0], [0.0, 0.0]],
+ model.epsilon: np.array([[0, 1], [2, 3]])}
sess.run(run_list, feed_dict=feed_dict)
env.close()

model.visual_in[1]: np.ones([2, 40, 30, 3]),
model.next_visual_in[0]: np.ones([2, 40, 30, 3]),
model.next_visual_in[1]: np.ones([2, 40, 30, 3]),
- model.action_masks: np.ones([2,2])
+ model.action_masks: np.ones([2, 2])
}
sess.run(run_list, feed_dict=feed_dict)
env.close()

model.visual_in[0]: np.ones([2, 40, 30, 3]),
model.visual_in[1]: np.ones([2, 40, 30, 3]),
model.next_visual_in[0]: np.ones([2, 40, 30, 3]),
- model.next_visual_in[1]: np.ones([2, 40, 30, 3])
+ model.next_visual_in[1]: np.ones([2, 40, 30, 3]),
+ model.epsilon: np.array([[0, 1], [2, 3]])
}
sess.run(run_list, feed_dict=feed_dict)
env.close()

正在加载...
取消
保存