浏览代码

Buffer fixes

/develop/action-slice
Ervin Teng 4 年前
当前提交
b3958a8d
共有 2 个文件被更改,包括 7 次插入和 8 次删除
  1. 11
      ml-agents/mlagents/trainers/buffer.py
  2. 4
      ml-agents/mlagents/trainers/coma/trainer.py

11
ml-agents/mlagents/trainers/buffer.py


)
if batch_size * training_length > len(self):
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:], dtype=np.float32
)
return [padding] * (training_length - leftover) + self[:]
return np.array(
self[len(self) - batch_size * training_length :], dtype=np.float32
)
return self[len(self) - batch_size * training_length :]
else:
# The sequences will have overlapping elements
if batch_size is None:

tmp_list: List[np.ndarray] = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list, dtype=np.float32)
return tmp_list
def reset_field(self) -> None:
"""

4
ml-agents/mlagents/trainers/coma/trainer.py


int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
advantages = self.update_buffer[BufferKey.ADVANTAGES].get_batch()
advantages = np.array(
self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32
)
self.update_buffer[BufferKey.ADVANTAGES].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
)

正在加载...
取消
保存