
Execute critic with LSTM

/MLA-1734-demo-provider
Ervin Teng 4 years ago
Current commit
b6095151
2 files changed, with 14 insertions and 4 deletions
  1. 12
      ml-agents/mlagents/trainers/policy/torch_policy.py
  2. 6
      ml-agents/mlagents/trainers/torch/networks.py

12
ml-agents/mlagents/trainers/policy/torch_policy.py


  """
  :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
  """
- dists, memories = self.actor_critic.get_dists(
-     vec_obs, vis_obs, masks, memories, seq_len
- )
+ if memories is None:
+     dists, memories = self.actor_critic.get_dists(
+         vec_obs, vis_obs, masks, memories, seq_len
+     )
+ else:
+     # If we're using LSTM. we need to execute the values to get the critic memories
+     dists, _, memories = self.actor_critic.get_dist_and_value(
+         vec_obs, vis_obs, masks, memories, seq_len
+     )
  action_list = self.actor_critic.sample_action(dists)
  log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
      action_list, dists
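
Reviewer note: the branch added above can be illustrated with a minimal, self-contained sketch. Everything here is a hypothetical stand-in and not the ml-agents implementation: `ToyActorCritic`, its two-slot memory layout `[actor_memory, critic_memory]`, and the `step` helper are assumptions made only to show why the critic might need to run when recurrent memories are present.

```python
from typing import Dict, List, Optional, Tuple


class ToyActorCritic:
    """Hypothetical stand-in; memory is modeled as [actor_memory, critic_memory]."""

    def __init__(self) -> None:
        self.critic_calls = 0

    def get_dists(
        self, obs: List[float], memories: Optional[List[int]]
    ) -> Tuple[List[str], Optional[List[int]]]:
        # Actor-only pass: in this toy model only the actor slot advances,
        # so the critic slot is returned unchanged (stale).
        if memories is None:
            return ["dist"], None
        return ["dist"], [memories[0] + 1, memories[1]]

    def get_dist_and_value(
        self, obs: List[float], memories: List[int]
    ) -> Tuple[List[str], Dict[str, float], List[int]]:
        # Actor and critic pass: both memory slots advance.
        self.critic_calls += 1
        return ["dist"], {"extrinsic": 0.0}, [memories[0] + 1, memories[1] + 1]


def step(
    model: ToyActorCritic, obs: List[float], memories: Optional[List[int]]
) -> Tuple[List[str], Optional[List[int]]]:
    """Mirror the branch in the diff: skip the critic unless memories exist."""
    if memories is None:
        dists, memories = model.get_dists(obs, memories)
    else:
        # Value estimates are discarded; only the updated memories are kept.
        dists, _, memories = model.get_dist_and_value(obs, memories)
    return dists, memories


model = ToyActorCritic()
_, no_memory = step(model, [0.0], None)
_, with_memory = step(model, [0.0], [0, 0])
```

In this sketch the non-recurrent path never touches the critic, while the recurrent path pays for one extra critic pass so that the memories carried to the next step include the critic's state.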

6
ml-agents/mlagents/trainers/torch/networks.py


      else 0
  )
- self.visual_processors, self.vector_processors, encoder_input_size = ModelUtils.create_input_processors(
+ (
+     self.visual_processors,
+     self.vector_processors,
+     encoder_input_size,
+ ) = ModelUtils.create_input_processors(
      observation_shapes,
      self.h_size,
      network_settings.vis_encode_type,
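
Reviewer note: the change in this file appears to be formatting only. A parenthesized, multi-line assignment target unpacks exactly like the single-line form. The sketch below uses a hypothetical `create_input_processors` stub (not the real `ModelUtils` method, and with made-up return values) to show that both spellings bind the same values.

```python
from typing import List, Tuple


def create_input_processors() -> Tuple[List[str], List[str], int]:
    # Hypothetical stub returning (visual processors, vector processors, size).
    return ["visual"], ["vector"], 128


# Single-line form, as on the removed line.
vis_a, vec_a, size_a = create_input_processors()

# Parenthesized form, as on the added lines; the trailing comma is allowed.
(
    vis_b,
    vec_b,
    size_b,
) = create_input_processors()

same_result = (vis_a, vec_a, size_a) == (vis_b, vec_b, size_b)
```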
