浏览代码

Fix next_obs in get_trajectory_value_estimates

/develop/add-fire/sac-lst
Ervin Teng 4 年前
当前提交
fa0d3cb6
共有 1 个文件被更改，包括 10 次插入和 3 次删除
  1. 13
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

13
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

memory = torch.zeros([1, 1, self.policy.m_size])
next_obs = np.concatenate(next_obs, axis=-1)
next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
vec_vis_obs = SplitObservations.from_observations(next_obs)
next_vec_obs = [
ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
]
next_vis_obs = [
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
for _vis_ob in vec_vis_obs.visual_observations
]
value_estimates, next_memory = self.policy.actor_critic.critic_pass(
vector_obs, visual_obs, memory, sequence_length=batch.num_experiences

next_obs, next_obs, next_memory, sequence_length=1
next_vec_obs, next_vis_obs, next_memory, sequence_length=1
)
for name, estimate in value_estimates.items():

正在加载...
取消
保存