
fix torch policy tests

/develop/actionmodel-csharp
Andrew Cohen, 4 years ago
Current commit
60309d8f
2 files changed, 6 insertions and 11 deletions
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (14 changed lines)
  2. ml-agents/mlagents/trainers/tests/torch/test_policy.py (3 changed lines)

ml-agents/mlagents/trainers/policy/torch_policy.py (14 changed lines)


         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
-    ) -> Tuple[
-        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
-    ]:
+    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
         """
         :param vec_obs: List of vector observations.
         :param vis_obs: List of visual observations.

         :return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories.
         """
-        actions, log_probs, entropies, value_heads, memories = self.actor_critic.get_action_stats_and_value(
+        actions, log_probs, entropies, _, memories = self.actor_critic.get_action_stats_and_value(
-        return (actions, log_probs, entropies, value_heads, memories)
+        return (actions, log_probs, entropies, memories)

     def evaluate_actions(
         self,
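For readers tracking the interface change: sample_actions now returns a 4-tuple without the value-head dict, so any caller still unpacking five values fails. A toy illustration of that breakage (plain PyTorch stand-ins, not ml-agents code):

import torch

def sample_actions_new():
    # Stand-in returning the new 4-tuple shape: actions, log probs, entropies, memories.
    return torch.zeros(1, 2), torch.zeros(1, 2), torch.zeros(1), torch.zeros(1, 1, 8)

# New-style unpacking works.
actions, log_probs, entropies, memories = sample_actions_new()

# Old-style unpacking (expecting value_heads as well) now raises ValueError.
try:
    actions, log_probs, entropies, value_heads, memories = sample_actions_new()
except ValueError as err:
    print(err)  # "not enough values to unpack (expected 5, got 4)"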

         run_out = {}
         with torch.no_grad():
-            action, log_probs, entropy, value_heads, memories = self.sample_actions(
+            action, log_probs, entropy, memories = self.sample_actions(
                 vec_obs, vis_obs, masks=masks, memories=memories
             )
         action_dict = action.to_numpy_dict()

         )
         run_out["log_probs"] = log_probs.to_numpy_dict()
         run_out["entropy"] = ModelUtils.to_numpy(entropy)
-        run_out["value_heads"] = {
-            name: ModelUtils.to_numpy(t) for name, t in value_heads.items()
-        }
-        run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
         run_out["learning_rate"] = 0.0
         if self.use_recurrent:
             run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)

ml-agents/mlagents/trainers/tests/torch/test_policy.py (3 changed lines)


 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.tests import mock_brain as mb
 from mlagents.trainers.settings import TrainerSettings, NetworkSettings
-from mlagents.trainers.torch.utils import ModelUtils, AgentAction
+from mlagents.trainers.torch.utils import ModelUtils
+from mlagents.trainers.torch.agent_action import AgentAction
 VECTOR_ACTION_SPACE = 2
 VECTOR_OBS_SPACE = 8
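The test fix then amounts to asserting the new return arity. A minimal pytest-style sketch, using a hypothetical stub in place of a fully constructed TorchPolicy (whose setup helpers are outside this diff):

import torch

VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8

class StubPolicy:
    """Hypothetical stand-in for TorchPolicy with the updated sample_actions contract."""

    def sample_actions(self, vec_obs, vis_obs, masks=None, memories=None, seq_len=1):
        batch = vec_obs[0].shape[0]
        actions = torch.zeros(batch, VECTOR_ACTION_SPACE)    # stand-in for AgentAction
        log_probs = torch.zeros(batch, VECTOR_ACTION_SPACE)  # stand-in for ActionLogProbs
        entropies = torch.zeros(batch)
        memories = torch.zeros(1, batch, 8)
        return actions, log_probs, entropies, memories

def test_sample_actions_returns_four_values():
    policy = StubPolicy()
    vec_obs = [torch.ones(4, VECTOR_OBS_SPACE)]
    result = policy.sample_actions(vec_obs, [])
    assert len(result) == 4  # value heads are no longer part of the tuple
    _actions, _log_probs, entropies, memories = result
    assert entropies.shape == (4,)
    assert memories.shape == (1, 4, 8)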
