浏览代码

Fix BC and Reward Signals

/develop/permutepytorch
Ervin Teng 4 年前
当前提交
43c41d66
共有 3 个文件被更改,包括 14 次插入6 次删除
  1. 2
      ml-agents/mlagents/trainers/torch/components/bc/module.py
  2. 12
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
  3. 6
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

2
ml-agents/mlagents/trainers/torch/components/bc/module.py


vis_ob = ModelUtils.list_to_tensor(
mini_batch_demo["visual_obs%d" % idx]
)
vis_obs.append(vis_ob)
vis_obs.append(ModelUtils.nhwc_to_nchw(vis_ob))
else:
vis_obs = []

12
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["visual_obs%d" % i], dtype=torch.float
ModelUtils.nhwc_to_nchw(
ModelUtils.list_to_tensor(
mini_batch["visual_obs%d" % i], dtype=torch.float
)
)
for i in range(n_vis)
],

)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["next_visual_obs%d" % i], dtype=torch.float
ModelUtils.nhwc_to_nchw(
ModelUtils.list_to_tensor(
mini_batch["next_visual_obs%d" % i], dtype=torch.float
)
)
for i in range(n_vis)
],

6
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


else []
)
vis_inputs = [
ModelUtils.list_to_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
ModelUtils.nhwc_to_nchw(
ModelUtils.list_to_tensor(
mini_batch["visual_obs%d" % i], dtype=torch.float
)
)
for i in range(n_vis)
]
return vec_inputs, vis_inputs

正在加载...
取消
保存