浏览代码

formatting

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
ab0dd150
共有 1 个文件被更改,包括 5 次插入1 次删除
  1. 6
      ml-agents/mlagents/trainers/torch/model_serialization.py

6
ml-agents/mlagents/trainers/torch/model_serialization.py


for shape in self.policy.behavior_spec.observation_shapes:
if len(shape) == 1:
vec_obs_size += shape[0]
num_vis_obs = sum(1 for shape in self.policy.behavior_spec.observation_shapes if len(shape) == 3)
num_vis_obs = sum(
1
for shape in self.policy.behavior_spec.observation_shapes
if len(shape) == 3
)
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])]
# create input shape of NCHW
# (It's NHWC in self.policy.behavior_spec.observation_shapes)

正在加载...
取消
保存