浏览代码

fixing tests

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
a482a47c
共有 1 个文件被更改,包括 5 次插入5 次删除
  1. 10
      ml-agents/mlagents/trainers/torch/model_serialization.py

10
ml-agents/mlagents/trainers/torch/model_serialization.py


batch_dim = [1]
seq_len_dim = [1]
vec_obs_size = 0
for shape in self.policy.behavior_spec.observation_shapes:
if len(shape) == 1:
vec_obs_size += shape[0]
for sens_spec in self.policy.behavior_spec.sensor_specs:
if len(sens_spec.shape) == 1:
vec_obs_size += sens_spec.shape[0]
for shape in self.policy.behavior_spec.observation_shapes
if len(shape) == 3
for sens_spec in self.policy.behavior_spec.sensor_specs
if len(sens_spec.shape) == 3
)
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])]
# create input shape of NCHW

正在加载...
取消
保存