
Fix per-block lstm initialization

/develop/add-fire/sac-lst
Ervin Teng, 4 years ago
Commit ea7b85a7
1 changed file with 16 additions and 4 deletions
ml-agents/mlagents/trainers/torch/layers.py (+16 −4)


     lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
     # Add forget_bias to forget gate bias
     for name, param in lstm.named_parameters():
+        # Each weight and bias is a concatenation of 4 matrices
         if "weight" in name:
-            _init_methods[kernel_init](param.data)
-        elif "bias" in name:
-            _init_methods[bias_init](param.data)
-            param.data[hidden_size : 2 * hidden_size].add_(forget_bias)
+            for idx in range(4):
+                block_size = param.shape[0] // 4
+                _init_methods[kernel_init](
+                    param.data[idx * block_size : (idx + 1) * block_size]
+                )
+        if "bias" in name:
+            for idx in range(4):
+                block_size = param.shape[0] // 4
+                _init_methods[bias_init](
+                    param.data[idx * block_size : (idx + 1) * block_size]
+                )
+                if idx == 1:
+                    param.data[idx * block_size : (idx + 1) * block_size].add_(
+                        forget_bias
+                    )
     return lstm
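For context, the fix matters because PyTorch packs each LSTM weight and bias as a concatenation of the four gate blocks in (input, forget, cell, output) order, so an initializer applied to the whole tensor treats it as one matrix, and the forget-gate offset must go only to block index 1. Below is a minimal standalone sketch of the per-block scheme; it uses `torch.nn.init.xavier_uniform_` and `zero_` directly in place of the `_init_methods` lookup from ml-agents, which is not reproduced here.

```python
import torch

hidden_size = 8
forget_bias = 1.0
lstm = torch.nn.LSTM(input_size=4, hidden_size=hidden_size, num_layers=1)

for name, param in lstm.named_parameters():
    # Each weight/bias packs the 4 gate blocks along dim 0.
    block_size = param.shape[0] // 4
    if "weight" in name:
        for idx in range(4):
            # Initialize each gate's weight block independently
            # (stand-in for _init_methods[kernel_init]).
            torch.nn.init.xavier_uniform_(
                param.data[idx * block_size : (idx + 1) * block_size]
            )
    if "bias" in name:
        for idx in range(4):
            block = param.data[idx * block_size : (idx + 1) * block_size]
            block.zero_()  # stand-in for _init_methods[bias_init]
            if idx == 1:  # forget gate block
                block.add_(forget_bias)

# The forget-gate slice of each bias is forget_bias; the rest stay zero.
bias = dict(lstm.named_parameters())["bias_ih_l0"]
print(bool(torch.all(bias[hidden_size : 2 * hidden_size] == forget_bias)))
```

This mirrors the TensorFlow convention of a `forget_bias` offset, which helps gradients flow early in training by keeping the forget gate mostly open.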