浏览代码

Barracuda hotfix for LSTM and tests (#2352)

* Removed obsolete 'TestDstWrongShape' test as it does not reflect how Barracuda tensors work
* Added proper test cleanup, to avoid warning messages from finalizer thread.
* Hotfix for recurrent + continous action nets in ML Agents
/develop-generalizationTraining-TrainerController
Ervin T 5 年前
当前提交
7cfce1a9
共有 1 个文件被更改,包括 15 次插入5 次删除
  1. 20
      ml-agents/mlagents/trainers/tensorflow_to_barracuda.py

20
ml-agents/mlagents/trainers/tensorflow_to_barracuda.py


id=1,
rank=2,
out_shapes=lambda shapes: [
[shapes[0][0], 1, 1, shapes[0][1]], # W
[shapes[0][0], 1, 1, shapes[0][1]]
if len(shapes[0]) > 1
else [1, 1, 1, 1], # W
[1, 1, 1, shapes[-1][-1]], # B
],
patch_data=lambda data: [data[0], data[1]],

"ConcatV2",
"Identity",
]
): "BasicLSTM",
repr([re.compile("^lstm/"), "Reshape", "ConcatV2", "Identity"]): "BasicLSTM",
repr(["Reshape", re.compile("^lstm_[a-z]*/"), "Reshape", "ConcatV2"]): "BasicLSTM",
): "BasicLSTMReshapeOut",
repr(
[re.compile("^lstm/"), "Reshape", "ConcatV2", "Identity"]
): "BasicLSTMReshapeOut",
repr(
["Reshape", re.compile("^lstm_[a-z]*/"), "Reshape", "ConcatV2"]
): "BasicLSTMReshapeOut",
repr(["Reshape", re.compile("^lstm_[a-z]*/"), "ConcatV2"]): "BasicLSTMConcatOut",
repr(["Sigmoid", "Mul"]): "Swish",
repr(["Mul", "Abs", "Mul", "Add"]): "LeakyRelu",
repr(

"SquaredDifference": lambda nodes, inputs, tensors, _: sqr_diff(
nodes[-1].name, inputs[0], inputs[1]
),
"BasicLSTM": lambda nodes, inputs, tensors, context: basic_lstm(
"BasicLSTMReshapeOut": lambda nodes, inputs, tensors, context: basic_lstm(
),
"BasicLSTMConcatOut": lambda nodes, inputs, tensors, context: basic_lstm(
nodes, inputs, tensors, context, find_type="ConcatV2"
),
"Swish": lambda nodes, inputs, tensors, _: Struct(op="Swish", input=inputs),
"LeakyRelu": lambda nodes, inputs, tensors, _: Struct(op="LeakyRelu", input=inputs),

正在加载...
取消
保存