浏览代码

Fix Barracuda export for LSTM

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
1407db53
共有 1 个文件被更改,包括 5 次插入2 次删除
  1. 7
      ml-agents/mlagents/trainers/common/nn_policy.py

7
ml-agents/mlagents/trainers/common/nn_policy.py


self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
)
self.trainable_variables += tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm"
) # LSTMs need to be root scope for Barracuda export
self.inference_dict: Dict[str, tf.Tensor] = {
"action": self.output,

hidden_stream,
self.memory_in,
self.sequence_length_ph,
name="policy/lstm",
name="lstm_policy",
)
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")

hidden_policy,
self.memory_in,
self.sequence_length_ph,
name="policy/lstm",
name="lstm_policy",
)
self.memory_out = tf.identity(memory_policy_out, "recurrent_out")

正在加载...
取消
保存