浏览代码

Add test for lstm layer

/develop/add-fire/sac-lst
Ervin Teng 4 年前
当前提交
69bae3cc
共有 1 个文件被更改,包括 21 次插入1 次删除
  1. 22
      ml-agents/mlagents/trainers/tests/torch/test_layers.py

22
ml-agents/mlagents/trainers/tests/torch/test_layers.py


import torch
from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization
from mlagents.trainers.torch.layers import (
Swish,
linear_layer,
lstm_layer,
Initialization,
)
def test_swish():

)
assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data)))
assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data)))
def test_lstm_layer():
torch.manual_seed(0)
# Test zero for LSTM
layer = lstm_layer(
4, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
)
for name, param in layer.named_parameters():
if "weight" in name:
assert torch.all(torch.eq(param.data, torch.zeros_like(param.data)))
elif "bias" in name:
assert torch.all(
torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
)
正在加载...
取消
保存