
Typing for torch policy

/develop/add-fire/policy-tests
Ervin Teng, 4 years ago
Current commit: 116303f1
1 file changed, 18 insertions(+), 10 deletions(-)
ml-agents/mlagents/trainers/policy/torch_policy.py | 28 ++++++++++++++++++----------


- from typing import Any, Dict, List, Tuple
+ from typing import Any, Dict, List, Tuple, Optional
  import numpy as np
  import torch
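
The added Optional import supports the annotations below: under PEP 484, a parameter that defaults to None should be typed Optional[T] rather than bare T, and strict mypy settings reject the implicit form. A minimal standalone sketch of the pattern (the function and names here are illustrative, not part of this file):

    from typing import Optional
    import torch

    # Illustrative only: the Optional[...] = None pattern adopted in this commit.
    def apply_mask(logits: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if mask is None:       # no mask supplied: pass logits through unchanged
            return logits
        return logits * mask   # zero out masked entries

    apply_mask(torch.ones(3))                                 # mask omitted
    apply_mask(torch.ones(3), torch.tensor([1.0, 0.0, 1.0]))  # mask given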

  @timed
  def sample_actions(
      self,
-     vec_obs,
-     vis_obs,
-     masks=None,
-     memories=None,
-     seq_len=1,
-     all_log_probs=False,
- ):
+     vec_obs: List[torch.Tensor],
+     vis_obs: List[torch.Tensor],
+     masks: Optional[torch.Tensor] = None,
+     memories: Optional[torch.Tensor] = None,
+     seq_len: int = 1,
+     all_log_probs: bool = False,
+ ) -> Tuple[
+     torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
+ ]:
      """
      :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
      """

  )
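
The new return annotation makes the contract of sample_actions explicit: sampled actions, their log probabilities, entropies, a dict of value-head outputs, and memories. Per the docstring, all_log_probs means that for discrete actions the caller gets one log probability per possible action instead of only the sampled action's. A hedged standalone sketch of that distinction (illustrative, not the ml-agents implementation):

    import torch
    from torch import distributions

    # Sketch of the all_log_probs distinction for one discrete action branch.
    def sample_discrete(logits: torch.Tensor, all_log_probs: bool = False):
        dist = distributions.Categorical(logits=logits)
        action = dist.sample()
        if all_log_probs:
            log_probs = torch.log_softmax(logits, dim=-1)  # one per action: [batch, n_actions]
        else:
            log_probs = dist.log_prob(action)              # sampled action only: [batch]
        return action, log_probs, dist.entropy()

    logits = torch.randn(2, 4)
    print(sample_discrete(logits, all_log_probs=True)[1].shape)   # torch.Size([2, 4])
    print(sample_discrete(logits, all_log_probs=False)[1].shape)  # torch.Size([2])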
  def evaluate_actions(
-     self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
- ):
+     self,
+     vec_obs: torch.Tensor,
+     vis_obs: torch.Tensor,
+     actions: torch.Tensor,
+     masks: Optional[torch.Tensor] = None,
+     memories: Optional[torch.Tensor] = None,
+     seq_len: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
      dists, value_heads, _ = self.actor_critic.get_dist_and_value(
          vec_obs, vis_obs, masks, memories, seq_len
      )
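
Where sample_actions draws fresh actions, evaluate_actions scores actions that were already taken: it rebuilds the distributions and value heads via get_dist_and_value and returns log probs, entropies, and values for the stored actions, which is the shape of a PPO-style update. A minimal standalone sketch of the pattern (the "extrinsic" key is an illustrative assumption, not taken from this diff):

    import torch
    from torch import distributions

    # Sketch: score previously taken actions under the current policy.
    def evaluate_discrete(logits: torch.Tensor, actions: torch.Tensor, value: torch.Tensor):
        dist = distributions.Categorical(logits=logits)
        log_probs = dist.log_prob(actions)    # log pi(a_t | s_t) for the stored actions
        entropy = dist.entropy()              # per-sample policy entropy
        value_heads = {"extrinsic": value}    # hypothetical reward-stream name
        return log_probs, entropy, value_heads

    logits = torch.randn(3, 5)
    actions = torch.tensor([0, 2, 4])
    log_probs, entropy, heads = evaluate_discrete(logits, actions, torch.zeros(3))
    print(log_probs.shape, entropy.shape, list(heads))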
