
Fixes for recurrent

/develop-newnormalization
Ervin Teng, 5 years ago
Current commit
5ab2563b
2 files changed, with 2 insertions and 2 deletions
  1. ml-agents/mlagents/trainers/ppo/policy.py (2 changes)
  2. ml-agents/mlagents/trainers/trajectory.py (2 changes)

ml-agents/mlagents/trainers/ppo/policy.py (1 addition, 1 deletion)


 def get_batched_value_estimates(self, batch: AgentBuffer) -> Dict[str, np.ndarray]:
     feed_dict: Dict[tf.Tensor, Any] = {
         self.model.batch_size: batch.num_experiences,
-        self.model.sequence_length: self.sequence_length,
+        self.model.sequence_length: 1,  # We want to feed data in batch-wise, not time-wise.
     }
     if self.use_vec_obs:
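
Pinning sequence_length to 1 means the whole trajectory is fed as a batch of independent steps rather than as recurrent sequences when computing value estimates. A minimal sketch of the rationale, with hypothetical variable names and shapes (not the actual ML-Agents code): with a recurrent sequence length L, the buffer would have to be split into sequences of length L, which only works cleanly when the trajectory length is a multiple of L; with a length of 1, every experience stands alone.

import numpy as np

# Hypothetical illustration: a trajectory of 7 experiences and a trainer
# configured with a recurrent sequence length of 4.
num_experiences = 7
recurrent_seq_len = 4

obs = np.arange(num_experiences, dtype=np.float32).reshape(-1, 1)  # (7, 1)

# Time-wise feeding only divides evenly when the trajectory length is a
# multiple of the sequence length.
print(num_experiences % recurrent_seq_len == 0)   # False -> cannot reshape cleanly

# Batch-wise feeding (sequence_length = 1) is always valid: each experience
# becomes its own length-1 "sequence".
batched = obs.reshape(num_experiences, 1, -1)     # (batch, seq_len=1, obs_dim)
print(batched.shape)                              # (7, 1, 1)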

ml-agents/mlagents/trainers/trajectory.py (1 addition, 1 deletion)


agent_buffer_trajectory["next_vector_in"].append(
next_vec_vis_obs.vector_observations
)
if exp.memory:
if exp.memory is not None:
agent_buffer_trajectory["memory"].append(exp.memory)
agent_buffer_trajectory["masks"].append(1.0)
