浏览代码

Fix for visual observations (#847)

/develop-generalizationTraining-TrainerController
GitHub 6 年前
当前提交
8526dcfc
共有 1 个文件被更改,包括 11 次插入5 次删除
  1. 16
      python/unitytrainers/ppo/trainer.py

16
python/unitytrainers/ppo/trainer.py


feed_dict = {self.model.batch_size: 1, self.model.sequence_length: 1}
if self.use_visual_obs:
for i in range(len(brain_info.visual_observations)):
feed_dict[self.model.visual_in[i]] = brain_info.visual_observations[i][idx]
feed_dict[self.model.visual_in[i]] = [brain_info.visual_observations[i][idx]]
if self.use_vector_obs:
feed_dict[self.model.vector_in] = [brain_info.vector_observations[idx]]
if self.use_recurrent:

if self.use_visual_obs:
for i, _ in enumerate(self.model.visual_in):
_obs = np.array(buffer['visual_obs%d' % i][start:end])
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
if self.sequence_length > 1 and self.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
else:
feed_dict[self.model.visual_in[i]] = _obs
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
if self.sequence_length > 1 and self.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
else:
feed_dict[self.model.next_visual_in[i]] = _obs
if self.use_recurrent:
mem_in = np.array(buffer['memory'][start:end])[:, 0, :]
feed_dict[self.model.memory_in] = mem_in

正在加载...
取消
保存