
Merge branch 'develop-fix-lstms' into develop-gru

Ervin Teng, 4 years ago
Commit f3a2a81f
4 changed files with 29 additions and 5 deletions
1. ml-agents/mlagents/trainers/agent_processor.py (2 changes)
2. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (10 changes)
3. ml-agents/mlagents/trainers/policy/policy.py (15 changes)
4. ml-agents/mlagents/trainers/torch/networks.py (7 changes)

ml-agents/mlagents/trainers/agent_processor.py (2 changes)


         if stored_decision_step is not None and stored_take_action_outputs is not None:
             obs = stored_decision_step.obs
             if self.policy.use_recurrent:
-                memory = self.policy.retrieve_memories([global_id])[0, :]
+                memory = self.policy.retrieve_previous_memories([global_id])[0, :]
             else:
                 memory = None
             done = terminated  # Since this is an ongoing step
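
For context on the [0, :] indexing: retrieve_previous_memories (added to policy.py further down) follows the same shape convention as retrieve_memories, returning one row of size m_size per requested agent id, so a single-agent lookup takes row 0. A standalone sketch of that convention, using a plain dict and a hypothetical m_size rather than the actual Policy class:

import numpy as np

m_size = 4  # hypothetical memory width
previous_memory_dict = {"agent-0": np.arange(m_size, dtype=np.float32)}

def retrieve_previous_memories(agent_ids):
    # One row per requested agent; zeros for agents with nothing stored yet.
    memory_matrix = np.zeros((len(agent_ids), m_size), dtype=np.float32)
    for index, agent_id in enumerate(agent_ids):
        if agent_id in previous_memory_dict:
            memory_matrix[index, :] = previous_memory_dict[agent_id]
    return memory_matrix

# Single-agent lookup, as in the hunk above: take row 0 of a (1, m_size) matrix.
memory = retrieve_previous_memories(["agent-0"])[0, :]
print(memory.shape)  # (4,)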

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (10 changes)


 from mlagents.torch_utils import torch
 import numpy as np
-from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.buffer import AgentBuffer, BufferKey
 from mlagents.trainers.trajectory import ObsUtil
 from mlagents.trainers.torch.components.bc.module import BCModule
 from mlagents.trainers.torch.components.reward_providers import create_reward_provider

         current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
         next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
-        memory = torch.zeros([1, 1, self.policy.m_size])
+        memory = (
+            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][0])
+            .unsqueeze(0)
+            .unsqueeze(0)
+            if self.policy.use_recurrent
+            else None
+        )
         next_obs = [obs.unsqueeze(0) for obs in next_obs]
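
A rough standalone sketch of what the new initialization does: instead of seeding value estimation with an all-zeros state, the first memory stored for the batch (batch[BufferKey.MEMORY][0] in the hunk above) is reshaped to the same (1, 1, m_size) layout as the zeros tensor it replaces. The m_size value and the stand-in list below are hypothetical:

import torch

m_size = 6                            # hypothetical memory width
use_recurrent = True
first_stored_memory = [0.1] * m_size  # stand-in for batch[BufferKey.MEMORY][0]

memory = (
    torch.tensor(first_stored_memory)
    .unsqueeze(0)  # (m_size,)   -> (1, m_size)
    .unsqueeze(0)  # (1, m_size) -> (1, 1, m_size), matching the zeros it replaces
    if use_recurrent
    else None
)
print(memory.shape)  # torch.Size([1, 1, 6])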

ml-agents/mlagents/trainers/policy/policy.py (15 changes)


         self.network_settings: NetworkSettings = trainer_settings.network_settings
         self.seed = seed
         self.previous_action_dict: Dict[str, np.ndarray] = {}
+        self.previous_memory_dict: Dict[str, np.ndarray] = {}
         self.memory_dict: Dict[str, np.ndarray] = {}
         self.normalize = trainer_settings.network_settings.normalize
         self.use_recurrent = self.network_settings.memory is not None

         if memory_matrix is None:
             return
+        # Pass old memories into previous_memory_dict
+        for agent_id in agent_ids:
+            if agent_id in self.memory_dict:
+                self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]
         for index, agent_id in enumerate(agent_ids):
             self.memory_dict[agent_id] = memory_matrix[index, :]

                 memory_matrix[index, :] = self.memory_dict[agent_id]
         return memory_matrix

+    def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
+        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
+        for index, agent_id in enumerate(agent_ids):
+            if agent_id in self.previous_memory_dict:
+                memory_matrix[index, :] = self.previous_memory_dict[agent_id]
+        return memory_matrix

+            if agent_id in self.previous_memory_dict:
+                self.previous_memory_dict.pop(agent_id)

     def make_empty_previous_action(self, num_agents: int) -> np.ndarray:
         """

ml-agents/mlagents/trainers/torch/networks.py (7 changes)


             inputs, masks=masks, memories=actor_mem, sequence_length=sequence_length
         )
         if critic_mem is not None:
-            # Make memories with the actor mem unchanged
-            memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1)
+            # Get value memories with the actor mem unchanged
+            _, critic_mem_outs = self.critic(
+                inputs, memories=critic_mem, sequence_length=sequence_length
+            )
+            memories_out = torch.cat([actor_mem_out, critic_mem_outs], dim=-1)
         else:
             memories_out = None
         return action, log_probs, entropies, memories_out
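
The shape effect of the added lines, sketched with hypothetical sizes: the critic is now run on the same inputs so its recurrent state advances alongside the actor's, and the memory handed back to the caller is the concatenation of the two along the last dimension:

import torch

actor_mem_out = torch.zeros(1, 1, 8)    # hypothetical actor memory after this step
critic_mem_outs = torch.zeros(1, 1, 8)  # hypothetical critic memory after this step

memories_out = torch.cat([actor_mem_out, critic_mem_outs], dim=-1)
print(memories_out.shape)  # torch.Size([1, 1, 16])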
