
ignoring precommit, grabbing baseline/critic mems from buffer in trainer

/develop/action-slice
Andrew Cohen, 4 years ago
Current commit
8f799687
4 files changed, 39 insertions(+), 8 deletions(-)
  1. ml-agents/mlagents/trainers/buffer.py (1 change)
  2. ml-agents/mlagents/trainers/coma/optimizer_torch.py (38 changes)
  3. ml-agents/mlagents/trainers/coma/trainer.py (6 changes)
  4. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (2 changes)

1 change: ml-agents/mlagents/trainers/buffer.py


     MASKS = "masks"
     MEMORY = "memory"
     CRITIC_MEMORY = "critic_memory"
+    BASELINE_MEMORY = "coma_baseline_memory"
     PREV_ACTION = "prev_action"
     ADVANTAGES = "advantages"
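The only substantive change here is the new BASELINE_MEMORY key, giving trajectories a second memory field next to CRITIC_MEMORY. A minimal sketch of the idea, using a plain defaultdict as a stand-in for the ml-agents AgentBuffer (the real class is more involved):

import enum
from collections import defaultdict
from typing import DefaultDict, List


class BufferKey(enum.Enum):
    # Mirrors the two keys relevant to this commit; the real enum has more entries.
    CRITIC_MEMORY = "critic_memory"
    BASELINE_MEMORY = "coma_baseline_memory"


# Hypothetical stand-in buffer: one list of per-step entries per key.
buffer: DefaultDict[BufferKey, List] = defaultdict(list)
buffer[BufferKey.BASELINE_MEMORY].append([0.0] * 128)  # one memory vector per step
print(len(buffer[BufferKey.BASELINE_MEMORY]))  # -> 1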

38 changes: ml-agents/mlagents/trainers/coma/optimizer_torch.py


         ]
         if len(memories) > 0:
             memories = torch.stack(memories).unsqueeze(0)
+        value_memories = [
+            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
+            )
+        ]
+        baseline_memories = [
+            ModelUtils.list_to_tensor(batch[BufferKey.BASELINE_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.BASELINE_MEMORY]), self.policy.sequence_length
+            )
+        ]
+        if len(value_memories) > 0:
+            value_memories = torch.stack(value_memories).unsqueeze(0)
+            baseline_memories = torch.stack(baseline_memories).unsqueeze(0)
         log_probs, entropy = self.policy.evaluate_actions(
             current_obs,

         )
         all_obs = [current_obs] + group_obs
         values, _ = self.critic.critic_pass(
-            all_obs, memories=memories, sequence_length=self.policy.sequence_length
+            all_obs,
+            memories=value_memories,
+            sequence_length=self.policy.sequence_length,

-            memories=memories,
+            memories=baseline_memories,
             sequence_length=self.policy.sequence_length,
         )
         old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
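The comprehensions above keep only the memory saved at the first step of every sequence and stack the result into a (1, num_sequences, memory_size) tensor for the recurrent critic. A self-contained sketch of that slice-and-stack pattern with dummy data (assumes nothing beyond PyTorch; sequence_length and memory_size are illustrative):

import torch

sequence_length = 4
memory_size = 8
num_steps = 12  # 3 sequences of 4 steps

# One stored memory vector per experience step (here: dummy zeros).
stored_memories = [[0.0] * memory_size for _ in range(num_steps)]

# Keep only the memory at the start of each sequence ...
value_memories = [
    torch.tensor(stored_memories[i])
    for i in range(0, len(stored_memories), sequence_length)
]
# ... and stack into shape (1, num_sequences, memory_size) for the critic.
if len(value_memories) > 0:
    value_memories = torch.stack(value_memories).unsqueeze(0)
print(value_memories.shape)  # torch.Size([1, 3, 8])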

         for team_obs, team_action in zip(obs, actions):
             seq_obs = []
-            for (_obs,) in team_obs:
+            for _obs in team_obs:
                 first_seq_obs = _obs[0:first_seq_len]
                 seq_obs.append(first_seq_obs)
             team_seq_obs.append(seq_obs)
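The loop fix above iterates over each team member's observation tensors directly instead of unpacking them as 1-tuples, then slices out the leading first_seq_len steps. A small sketch of the same slicing with dummy tensors (the actions half of the zip is dropped; shapes and names are illustrative):

import torch

first_seq_len = 3
# Two team members, each with one observation tensor of shape (num_steps, obs_dim).
obs = [[torch.arange(10.0).reshape(5, 2)], [torch.ones(5, 2)]]

team_seq_obs = []
for team_obs in obs:
    seq_obs = []
    for _obs in team_obs:  # iterate tensors directly; no 1-tuple unpacking
        first_seq_obs = _obs[0:first_seq_len]  # leading first_seq_len steps
        seq_obs.append(first_seq_obs)
    team_seq_obs.append(seq_obs)

print(team_seq_obs[0][0].shape)  # torch.Size([3, 2])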

             _init_value_mem = self.value_memory_dict[agent_id]
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
-            memory = (
+            _init_value_mem = (
                 torch.zeros((1, 1, self.critic.memory_size))
                 if self.policy.use_recurrent
                 else None
             )
+            _init_baseline_mem = (
+                torch.zeros((1, 1, self.critic.memory_size))
+                if self.policy.use_recurrent
+                else None
+            )

         all_next_value_mem: Optional[AgentBufferField] = None
         all_next_baseline_mem: Optional[AgentBufferField] = None
         if self.policy.use_recurrent:
-            value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self.critic._evaluate_by_sequence_team(
+            value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self._evaluate_by_sequence_team(
-                all_obs, memory, sequence_length=batch.num_experiences
+                all_obs, _init_value_mem, sequence_length=batch.num_experiences
             )
             baseline_estimates, baseline_mem = self.critic.baseline(

-                memory,
+                _init_baseline_mem,
                 sequence_length=batch.num_experiences,
             )
         # Store the memory for the next trajectory
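Per agent, the optimizer now tracks two initial memories: one for the centralized value head and one for the baseline head, falling back to zeros when nothing has been stored yet for a recurrent policy. A sketch of that lookup-or-zeros bookkeeping (the dict names mirror the diff; everything else is illustrative):

from typing import Dict, Optional, Tuple

import torch

memory_size = 16
use_recurrent = True
value_memory_dict: Dict[str, torch.Tensor] = {}
baseline_memory_dict: Dict[str, torch.Tensor] = {}


def initial_memories(
    agent_id: str,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Reuse the last stored memories for this agent if we have them ...
    if agent_id in value_memory_dict:
        _init_value_mem = value_memory_dict[agent_id]
        _init_baseline_mem = baseline_memory_dict[agent_id]
    # ... otherwise start from zeros (recurrent) or no memory at all.
    else:
        _init_value_mem = (
            torch.zeros((1, 1, memory_size)) if use_recurrent else None
        )
        _init_baseline_mem = (
            torch.zeros((1, 1, memory_size)) if use_recurrent else None
        )
    return _init_value_mem, _init_baseline_mem


value_mem, baseline_mem = initial_memories("agent_0")
print(value_mem.shape)  # torch.Size([1, 1, 16])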

6 changes: ml-agents/mlagents/trainers/coma/trainer.py


             value_estimates,
             baseline_estimates,
             value_next,
+            value_memories,
+            baseline_memories,
         ) = self.optimizer.get_trajectory_and_baseline_value_estimates(
             agent_buffer_trajectory,
             trajectory.next_obs,

             and not trajectory.interrupted,
         )
+        if value_memories is not None:
+            agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)
+            agent_buffer_trajectory[BufferKey.BASELINE_MEMORY].set(baseline_memories)
         for name, v in value_estimates.items():
             agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend(
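This trainer-side change is what the commit message refers to: the per-step value and baseline memories returned by get_trajectory_and_baseline_value_estimates are written back into the trajectory buffer under the new keys, so update() can read them out per sequence. A sketch of that round trip, with a plain dict of lists standing in for the ml-agents AgentBuffer and its .set() method:

from typing import Dict, List

CRITIC_MEMORY = "critic_memory"
BASELINE_MEMORY = "coma_baseline_memory"

trajectory_buffer: Dict[str, List[List[float]]] = {}

# Trainer side: overwrite the memory fields with what value estimation returned.
value_memories = [[float(i)] * 4 for i in range(8)]      # 8 steps, memory size 4
baseline_memories = [[float(-i)] * 4 for i in range(8)]
if value_memories is not None:
    trajectory_buffer[CRITIC_MEMORY] = list(value_memories)        # ~ .set(...)
    trajectory_buffer[BASELINE_MEMORY] = list(baseline_memories)   # ~ .set(...)

# Optimizer side: take the memory stored at the first step of each sequence.
sequence_length = 4
first_mems = trajectory_buffer[CRITIC_MEMORY][::sequence_length]
print(len(first_mems))  # 2 sequences -> 2 initial memories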

2 changes: ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


     )
     def _evaluate_by_sequence(
-        self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
+        self, tensor_obs: List[torch.Tensor], initial_memory: torch.Tensor
     ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
         """
         Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
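The annotation fix records that callers already hand _evaluate_by_sequence a torch.Tensor (for example the zero memories created above), not a NumPy array. Illustrative signature only; AgentBufferField is replaced with a plain list so the sketch stands alone:

from typing import Dict, List, Tuple

import torch


def _evaluate_by_sequence(
    tensor_obs: List[torch.Tensor], initial_memory: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor], torch.Tensor]:
    """Placeholder body; the real method walks the trajectory sequence by sequence."""
    return {}, [], initial_memory


# A recurrent caller passes a (1, 1, memory_size) tensor, e.g. zeros.
_evaluate_by_sequence([torch.zeros(8, 4)], torch.zeros((1, 1, 16)))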
