浏览代码

add comment

/develop/add-fire/export-discrete
Ruo-Ping Dong 4 年前
当前提交
964a2f76
共有 1 个文件被更改,包括 6 次插入0 次删除
  1. 6
      ml-agents/mlagents/trainers/torch/model_serialization.py

6
ml-agents/mlagents/trainers/torch/model_serialization.py


class ModelSerializer:
def __init__(self, policy):
# ONNX only support input in NCHW (channel first) format.
# Barracuda also expect to get data in NCHW.
# Any multi-dimentional input should follow that otherwise will
# cause problem to barracuda import.
# create input shape of NCHW
# (It's NHWC in self.policy.behavior_spec.observation_shapes)
dummy_vis_obs = [
torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
for shape in self.policy.behavior_spec.observation_shapes

正在加载...
取消
保存