|
|
|
|
|
|
|
|
|
|
class ModelSerializer: |
|
|
|
def __init__(self, policy): |
|
|
|
# ONNX only support input in NCHW (channel first) format. |
|
|
|
# Barracuda also expect to get data in NCHW. |
|
|
|
# Any multi-dimentional input should follow that otherwise will |
|
|
|
# cause problem to barracuda import. |
|
|
|
# create input shape of NCHW |
|
|
|
# (It's NHWC in self.policy.behavior_spec.observation_shapes) |
|
|
|
dummy_vis_obs = [ |
|
|
|
torch.zeros(batch_dim + [shape[2], shape[0], shape[1]]) |
|
|
|
for shape in self.policy.behavior_spec.observation_shapes |
|
|
|