
Cleanup LSTM code

/develop/nopreviousactions
Ervin Teng, 5 years ago
Commit 7f53bf8b
2 changed files, with 18 additions and 13 deletions
  1. ml-agents/mlagents/trainers/common/nn_policy.py (6 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer.py (25 changes)

ml-agents/mlagents/trainers/common/nn_policy.py (6 changes)


         if self.use_continuous_act:
             self.inference_dict["pre_action"] = self.output_pre
         if self.use_recurrent:
-            self.inference_dict["policy_memory_out"] = self.memory_out
+            self.inference_dict["memory_out"] = self.memory_out
     @timed
     def evaluate(

name="lstm_policy",
)
self.memory_out = memory_policy_out
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
else:
hidden_policy = hidden_stream

name="lstm_policy",
)
self.memory_out = memory_policy_out
self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
else:
hidden_policy = hidden_stream
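
The substantive change in both hunks above is that the LSTM's memory output is wrapped in tf.identity with an explicit node name. A minimal standalone sketch of the effect, assuming the TF 1.x graph API that ml-agents' tf_utils exposes (the placeholder below is only a stand-in for memory_policy_out, not the commit's code):

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style graph mode

tf.disable_eager_execution()

# Stand-in for the LSTM memory output produced by the recurrent encoder.
memory_policy_out = tf.placeholder(shape=[None, 16], dtype=tf.float32)
# Naming the tensor pins a stable "recurrent_out" node into the graph, so it
# can be fetched by name, e.g. when the trained graph is exported for inference.
memory_out = tf.identity(memory_policy_out, name="recurrent_out")

print(memory_out.name)  # "recurrent_out:0"
assert tf.get_default_graph().get_tensor_by_name("recurrent_out:0") is memory_out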

ml-agents/mlagents/trainers/ppo/optimizer.py (25 changes)


 import logging
-from typing import Optional, Any, Dict
+from typing import Optional, Any, Dict, List
 import numpy as np
 from mlagents.tf_utils import tf

         if self.policy.use_recurrent:
             self.m_size = self.policy.m_size
             self.memory_in = tf.placeholder(
-                shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
+                shape=[None, self.m_size],
+                dtype=tf.float32,
+                name="recurrent_value_in",
             )
         if num_layers < 1:
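
This hunk renames the optimizer's own memory placeholder from "recurrent_in" to "recurrent_value_in" (and reflows its arguments). A plausible reason, sketched below as an assumption rather than anything stated in the commit: the policy graph already requests a placeholder named "recurrent_in", and TensorFlow uniquifies duplicate node names, which would leave the optimizer's placeholder with an auto-suffixed, harder-to-address name.

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style graph mode

tf.disable_eager_execution()

# Two placeholders in the same graph cannot share a requested name; a second
# "recurrent_in" would silently become "recurrent_in_1".
policy_memory_in = tf.placeholder(shape=[None, 16], dtype=tf.float32, name="recurrent_in")
value_memory_in = tf.placeholder(shape=[None, 16], dtype=tf.float32, name="recurrent_value_in")

print(policy_memory_in.name)  # recurrent_in:0
print(value_memory_in.name)   # recurrent_value_in:0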

         for i, _ in enumerate(self.policy.visual_in):
             feed_dict[self.policy.visual_in[i]] = mini_batch["visual_obs%d" % i]
         if self.policy.use_recurrent:
-            mem_in = [
-                np.zeros((self.policy.m_size))
-                for i in range(
-                    0, mini_batch.num_experiences, self.policy.sequence_length
-                )
-            ]
-            feed_dict[self.policy.memory_in] = mem_in
-            feed_dict[self.memory_in] = mem_in
+            feed_dict[self.policy.memory_in] = self._make_zero_mem(
+                self.policy.m_size, mini_batch.num_experiences
+            )
+            feed_dict[self.memory_in] = self._make_zero_mem(
+                self.m_size, mini_batch.num_experiences
+            )

+    def _make_zero_mem(self, m_size: int, length: int) -> List[np.ndarray]:
+        return [
+            np.zeros((m_size)) for i in range(0, length, self.policy.sequence_length)
+        ]
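
The new _make_zero_mem helper factors out the zero-memory construction used for both the policy's and the optimizer's memory placeholders. A standalone sketch of what it returns, using a hypothetical module-level make_zero_mem with sequence_length passed explicitly instead of read from self.policy:

from typing import List

import numpy as np


def make_zero_mem(m_size: int, length: int, sequence_length: int) -> List[np.ndarray]:
    # One zero-initialized memory vector per training sequence in the batch, so
    # each sequence in the update starts the LSTM from a blank hidden state.
    return [np.zeros((m_size,)) for _ in range(0, length, sequence_length)]


mem = make_zero_mem(m_size=16, length=128, sequence_length=32)
print(len(mem), mem[0].shape)  # 4 (16,)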