Update curiosity reward provider

/develop/unified-obs
Ervin Teng, 4 years ago
Current commit
8d29114d
1 file changed, with 7 insertions and 22 deletions
  1. 29
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


         """
         Extracts the current state embedding from a mini_batch.
         """
-        n_vis = len(self._state_encoder.visual_processors)
-            vec_inputs=[
-                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
-            ],
-            vis_inputs=[
-                ModelUtils.list_to_tensor(
-                    mini_batch["visual_obs%d" % i], dtype=torch.float
-                )
-                for i in range(n_vis)
-            ],
+            net_inputs=ModelUtils.list_to_tensor_list(
+                AgentBuffer.obs_list_to_obs_batch(mini_batch["obs"]), dtype=torch.float
+            )
         )
         return hidden

         """
-        n_vis = len(self._state_encoder.visual_processors)
-            vec_inputs=[
-                ModelUtils.list_to_tensor(
-                    mini_batch["next_vector_in"], dtype=torch.float
-                )
-            ],
-            vis_inputs=[
-                ModelUtils.list_to_tensor(
-                    mini_batch["next_visual_obs%d" % i], dtype=torch.float
-                )
-                for i in range(n_vis)
-            ],
+            net_inputs=ModelUtils.list_to_tensor_list(
+                AgentBuffer.obs_list_to_obs_batch(mini_batch["next_obs"]),
+                dtype=torch.float,
+            )
         )
         return hidden
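
Both hunks replace the separate vector and visual lookups with a single "obs" (or "next_obs") entry that is regrouped before being converted to tensors. The sketch below illustrates the kind of regrouping a helper named like obs_list_to_obs_batch presumably performs. It assumes each buffer entry is a list holding one observation per sensor. The function name group_obs_by_sensor and the plain-list sample data are hypothetical and are not taken from ml-agents.

```python
from typing import List, Sequence


def group_obs_by_sensor(obs_list: Sequence[Sequence[list]]) -> List[List[list]]:
    """Transpose a [step][sensor] nesting into a [sensor][step] nesting.

    Hypothetical stand-in for the batching step; the real helper may differ.
    """
    # zip(*obs_list) pairs up the i-th observation of every step.
    return [list(per_sensor) for per_sensor in zip(*obs_list)]


# Three steps, each with two sensors: a 2-value vector and a 1-value vector.
steps = [
    [[0.1, 0.2], [1.0]],
    [[0.3, 0.4], [2.0]],
    [[0.5, 0.6], [3.0]],
]

batches = group_obs_by_sensor(steps)
print(len(batches))  # one batch per sensor
print(batches[0])    # all steps for the first sensor
```

With the observations grouped per sensor, each group could then be turned into one tensor, which is roughly the role list_to_tensor_list appears to play in the added lines.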
