浏览代码

backport tf2bc changes from barracuda-release (#3341)

/asymm-envs
GitHub 5 年前
当前提交
63959125
共有 1 个文件被更改,包括 9 次插入2 次删除
  1. 11
      ml-agents/mlagents/trainers/tensorflow_to_barracuda.py

11
ml-agents/mlagents/trainers/tensorflow_to_barracuda.py


val = node.attr[attr_name]
if val.HasField("list"):
return val.list.i
if len(val.list.shape) > 0:
return val.list.shape
else:
return val.list.i
if val.HasField("b"):
return val.b
if val.HasField("i"):

def get_layer_rank(layer):
shape = get_attr(layer, "shape")
if not shape:
outputShapes = get_attr(layer, "_output_shapes")
if outputShapes:
shape = outputShapes[0]
if not shape:
return None
if isinstance(shape, list):

W = 2
C = 3
if axis < 0:
axis = input_rank - axis
axis = input_rank + axis
assert axis >= 0
assert axis < input_rank
if input_rank == 4:

正在加载...
取消
保存