
Don't run value during inference

/MLA-1734-demo-provider
Ervin Teng, 4 years ago
Current commit
7754ad7b
5 files changed, with 12 insertions and 35 deletions
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (20 lines changed)
  2. ml-agents/mlagents/trainers/sac/optimizer_torch.py (11 lines changed)
  3. ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (4 lines changed)
  4. ml-agents/mlagents/trainers/tests/torch/test_policy.py (10 lines changed)
  5. ml-agents/mlagents/trainers/torch/components/bc/module.py (2 lines changed)

ml-agents/mlagents/trainers/policy/torch_policy.py (20 lines changed)


         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
         all_log_probs: bool = False,
-    ) -> Tuple[
-        torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
-    ]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        dists, value_heads, memories = self.actor_critic.get_dist_and_value(
+        dists, memories = self.actor_critic.get_dists(
             vec_obs, vis_obs, masks, memories, seq_len
         )
         action_list = self.actor_critic.sample_action(dists)

         else:
             actions = actions[:, 0, :]
-        return (
-            actions,
-            all_logs if all_log_probs else log_probs,
-            entropies,
-            value_heads,
-            memories,
-        )
+        return (actions, all_logs if all_log_probs else log_probs, entropies, memories)

     def evaluate_actions(
         self,

         run_out = {}
         with torch.no_grad():
-            action, log_probs, entropy, value_heads, memories = self.sample_actions(
+            action, log_probs, entropy, memories = self.sample_actions(
                 vec_obs, vis_obs, masks=masks, memories=memories
             )
         run_out["action"] = ModelUtils.to_numpy(action)

         run_out["entropy"] = ModelUtils.to_numpy(entropy)
-        run_out["value_heads"] = {
-            name: ModelUtils.to_numpy(t) for name, t in value_heads.items()
-        }
-        run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
         run_out["learning_rate"] = 0.0
         if self.use_recurrent:
             run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)

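For reference, a minimal sketch of the inference path after this hunk. It is a hypothetical wrapper, not code from the commit: the detach().cpu().numpy() conversions stand in for ModelUtils.to_numpy, and only the run_out keys visible in this diff appear (entries from the elided context are unchanged and omitted).

from typing import Any, Dict, List, Optional
import torch

def evaluate_sketch(
    policy,  # a TorchPolicy as in the diff above (duck-typed for this sketch)
    vec_obs: List[torch.Tensor],
    vis_obs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
    run_out: Dict[str, Any] = {}
    with torch.no_grad():
        # sample_actions now returns four values; the critic is never run here.
        action, log_probs, entropy, memories = policy.sample_actions(
            vec_obs, vis_obs, masks=masks, memories=memories
        )
        # log_probs is still returned; its run_out entry sits in the elided context.
    run_out["action"] = action.detach().cpu().numpy()
    run_out["entropy"] = entropy.detach().cpu().numpy()
    # "value_heads" and "value" are no longer populated on the inference path.
    run_out["learning_rate"] = 0.0
    if policy.use_recurrent:
        run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0)
    return run_out
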
ml-agents/mlagents/trainers/sac/optimizer_torch.py (11 lines changed)


         self.target_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        (
-            sampled_actions,
-            log_probs,
-            entropies,
-            sampled_values,
-            _,
-        ) = self.policy.sample_actions(
+        (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,

         )
+        sampled_values, _ = self.policy.actor_critic.critic_pass(
+            vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
+        )
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)

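To make the new division of labor explicit, here is a minimal sketch of how the SAC update now obtains actions and values, assuming the signatures shown in the hunk above. It is a hypothetical helper: the memories keyword passed to sample_actions is an assumption, since the commit elides the remaining arguments of that call.

from typing import Dict, List, Optional, Tuple
import torch

def sac_update_inputs(
    policy,  # a TorchPolicy as in the diff above (duck-typed for this sketch)
    vec_obs: List[torch.Tensor],
    vis_obs: List[torch.Tensor],
    act_masks: Optional[torch.Tensor],
    memories: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
    # Actions and log-probs come from the actor alone; no value heads returned.
    sampled_actions, log_probs, _, _ = policy.sample_actions(
        vec_obs, vis_obs, masks=act_masks, memories=memories  # memories kwarg assumed
    )
    # Value estimates come from an explicit critic pass, run only during the
    # trainer update rather than on every inference step.
    sampled_values, _ = policy.actor_critic.critic_pass(
        vec_obs, vis_obs, memories, sequence_length=policy.sequence_length
    )
    return sampled_actions, log_probs, sampled_values
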
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (4 lines changed)


         ).unsqueeze(0)
     with torch.no_grad():
-        _, log_probs1, _, _, _ = policy1.sample_actions(
+        _, log_probs1, _, _ = policy1.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
         )
-        _, log_probs2, _, _, _ = policy2.sample_actions(
+        _, log_probs2, _, _ = policy2.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
         )

ml-agents/mlagents/trainers/tests/torch/test_policy.py (10 lines changed)


     if len(memories) > 0:
         memories = torch.stack(memories).unsqueeze(0)
-    (
-        sampled_actions,
-        log_probs,
-        entropies,
-        sampled_values,
-        memories,
-    ) = policy.sample_actions(
+    (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
         vec_obs,
         vis_obs,
         masks=act_masks,

     else:
         assert log_probs.shape == (64, policy.behavior_spec.action_shape)
     assert entropies.shape == (64, policy.behavior_spec.action_size)
-    for val in sampled_values.values():
-        assert val.shape == (64,)
     if rnn:
         assert memories.shape == (1, 1, policy.m_size)

ml-agents/mlagents/trainers/torch/components/bc/module.py (2 lines changed)


         else:
             vis_obs = []
-        selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
+        selected_actions, all_log_probs, _, _ = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,
