|
|
|
|
|
|
val = node.attr[attr_name] |
|
|
|
|
|
|
|
if val.HasField("list"): |
|
|
|
return val.list.i |
|
|
|
if len(val.list.shape) > 0: |
|
|
|
return val.list.shape |
|
|
|
else: |
|
|
|
return val.list.i |
|
|
|
if val.HasField("b"): |
|
|
|
return val.b |
|
|
|
if val.HasField("i"): |
|
|
|
|
|
|
|
|
|
|
def get_layer_rank(layer): |
|
|
|
shape = get_attr(layer, "shape") |
|
|
|
if not shape: |
|
|
|
outputShapes = get_attr(layer, "_output_shapes") |
|
|
|
if outputShapes: |
|
|
|
shape = outputShapes[0] |
|
|
|
if not shape: |
|
|
|
return None |
|
|
|
if isinstance(shape, list): |
|
|
|
|
|
|
W = 2 |
|
|
|
C = 3 |
|
|
|
if axis < 0: |
|
|
|
axis = input_rank - axis |
|
|
|
axis = input_rank + axis |
|
|
|
assert axis >= 0 |
|
|
|
assert axis < input_rank |
|
|
|
if input_rank == 4: |
|
|
|