浏览代码

cherry pick PR#3032 (#3066)

/tag-0.12.1
GitHub 5 年前
当前提交
681093cf
共有 2 个文件被更改,包括 11 次插入8 次删除
  1. 7
      ml-agents/mlagents/trainers/ppo/policy.py
  2. 12
      ml-agents/mlagents/trainers/tf_policy.py

7
ml-agents/mlagents/trainers/ppo/policy.py


]
if self.use_vec_obs:
feed_dict[self.model.vector_in] = [brain_info.vector_observations[idx]]
agent_id = brain_info.agents[idx]
feed_dict[self.model.memory_in] = self.retrieve_memories([idx])
feed_dict[self.model.memory_in] = self.retrieve_memories([agent_id])
feed_dict[self.model.prev_action] = self.retrieve_previous_action([idx])
feed_dict[self.model.prev_action] = self.retrieve_previous_action(
[agent_id]
)
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
value_estimates = {k: float(v) for k, v in value_estimates.items()}

12
ml-agents/mlagents/trainers/tf_policy.py


self.seed = seed
self.brain = brain
self.use_recurrent = trainer_parameters["use_recurrent"]
self.memory_dict: Dict[int, np.ndarray] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.previous_action_dict: Dict[int, np.array] = {}
self.previous_action_dict: Dict[str, np.array] = {}
self.normalize = trainer_parameters.get("normalize", False)
self.use_continuous_act = brain.vector_action_space_type == "continuous"
if self.use_continuous_act:

return np.zeros((num_agents, self.m_size), dtype=np.float)
def save_memories(
self, agent_ids: List[int], memory_matrix: Optional[np.ndarray]
self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
) -> None:
if memory_matrix is None:
return

def retrieve_memories(self, agent_ids: List[int]) -> np.ndarray:
def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.memory_dict:

return np.zeros((num_agents, self.num_branches), dtype=np.int)
def save_previous_action(
self, agent_ids: List[int], action_matrix: Optional[np.ndarray]
self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
) -> None:
if action_matrix is None:
return

def retrieve_previous_action(self, agent_ids: List[int]) -> np.ndarray:
def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_action_dict:

正在加载...
取消
保存