import logging
from typing import Optional, Any, Dict, List

import numpy as np

from mlagents.tf_utils import tf
|
|
if self.policy.use_recurrent:
    self.m_size = self.policy.m_size
    # Memory input for the recurrent value estimator; the policy side defines
    # a matching placeholder of its own, named "recurrent_in".
    self.memory_in = tf.placeholder(
        shape=[None, self.m_size],
        dtype=tf.float32,
        name="recurrent_value_in",
    )


# Clamp to at least one hidden layer.
if num_layers < 1:
    num_layers = 1
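
# Hedged, self-contained sketch (illustrative only, not part of the source):
# how a [None, m_size] recurrent placeholder like "recurrent_value_in" above
# can be fed with a zeroed memory under the TF1-style session API this code
# already assumes. All "demo_*" names are hypothetical.
def _demo_feed_zero_memory(m_size: int = 4, num_sequences: int = 2) -> np.ndarray:
    demo_memory_in = tf.placeholder(
        shape=[None, m_size], dtype=tf.float32, name="demo_recurrent_in"
    )
    # A trivial op that consumes the memory, just to have something to run.
    demo_sum = tf.reduce_sum(demo_memory_in, axis=1)
    with tf.Session() as sess:
        zero_mem = np.zeros((num_sequences, m_size), dtype=np.float32)
        return sess.run(demo_sum, feed_dict={demo_memory_in: zero_mem})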
|
|
for i, _ in enumerate(self.policy.visual_in):
    feed_dict[self.policy.visual_in[i]] = mini_batch["visual_obs%d" % i]
if self.policy.use_recurrent:
    # Feed zeroed memories, one vector per sequence, into both the policy's
    # memory placeholder and the value estimator's (self.memory_in).
    feed_dict[self.policy.memory_in] = self._make_zero_mem(
        self.policy.m_size, mini_batch.num_experiences
    )
    feed_dict[self.memory_in] = self._make_zero_mem(
        self.m_size, mini_batch.num_experiences
    )
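
# Hedged sketch (illustrative only, not part of the source): the loop above
# relies on the mini-batch storing visual observations under keys
# "visual_obs0", "visual_obs1", ..., one per visual input placeholder.
# "_demo_visual_feed" is a hypothetical stand-alone version of that pairing.
def _demo_visual_feed(
    visual_placeholders: List[Any], mini_batch: Dict[str, np.ndarray]
) -> Dict[Any, np.ndarray]:
    feed: Dict[Any, np.ndarray] = {}
    for i, placeholder in enumerate(visual_placeholders):
        feed[placeholder] = mini_batch["visual_obs%d" % i]
    return feed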
|
|
def _make_zero_mem(self, m_size: int, length: int) -> List[np.ndarray]:
    # One zero memory vector per sequence in the mini-batch.
    return [
        np.zeros((m_size)) for i in range(0, length, self.policy.sequence_length)
    ]
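
# Hedged, stand-alone sketch (illustrative only, not part of the source): a
# module-level equivalent of _make_zero_mem, showing that it yields one zero
# memory vector per sequence in a mini-batch. "_demo_make_zero_mem" and the
# numbers below are hypothetical.
def _demo_make_zero_mem(
    m_size: int, length: int, sequence_length: int
) -> List[np.ndarray]:
    return [np.zeros((m_size)) for _ in range(0, length, sequence_length)]


# e.g. 128 experiences split into sequences of 64 steps -> 2 zero vectors of size 256
assert len(_demo_make_zero_mem(256, 128, 64)) == 2
assert _demo_make_zero_mem(256, 128, 64)[0].shape == (256,)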