浏览代码

Still somewhat broken but cleaner

/develop/critic-op-lstm-currentmem
Ervin Teng 3 年前
当前提交
2b0dd850
共有 1 个文件被更改,包括 80 次插入38 次删除
  1. 118
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

118
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from typing import Dict, Optional, Tuple, List
from mlagents.torch_utils import torch
import numpy as np
import math
from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
from mlagents.trainers.trajectory import ObsUtil

reward_signal, self.policy.behavior_spec, settings
)
def _evaluate_by_sequence(
self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
"""
Evaluate the batch sequence-by-sequence, assembling the result. This enables us to get the
intermediate memories for the critic.
"""
num_experiences = tensor_obs[0].shape[0]
all_next_memories = AgentBufferField()
# The 1st sequence are the ones that are padded. So if seq_len = 3 and
# trajectory is of length 10, the ist sequence is [pad,pad,obs].
# Compute the number of elements in this padded seq.
leftover = num_experiences % self.policy.sequence_length
first_seq_len = self.policy.sequence_length if leftover == 0 else leftover
for _ in range(first_seq_len):
all_next_memories.append(initial_memory.squeeze().detach().numpy())
# Compute values for the potentially truncated initial sequence
_mem = initial_memory
seq_obs = []
for _obs in tensor_obs:
if leftover > 0:
# Pad
# _obs will always be bigger than leftover
padding = torch.zeros_like(
_obs[0 : self.policy.sequence_length - leftover]
)
padded_obs = torch.cat([padding, _obs[0:leftover]])
else:
padded_obs = _obs[0 : self.policy.sequence_length]
seq_obs.append(padded_obs)
init_values, _mem = self.critic.critic_pass(
seq_obs, _mem, sequence_length=self.policy.sequence_length
)
# Trim out padded part
all_values = {
signal_name: [init_values[signal_name][leftover:]]
for signal_name in init_values.keys()
}
# Evaluate other trajectories
for seq_num in range(
1, math.ceil((num_experiences) / (self.policy.sequence_length))
):
seq_obs = []
for _obs in tensor_obs:
start = seq_num * self.policy.sequence_length - leftover
end = (seq_num + 1) * self.policy.sequence_length - leftover
seq_obs.append(_obs[start:end])
values, _mem = self.critic.critic_pass(
seq_obs, _mem, sequence_length=self.policy.sequence_length
)
for _ in range(self.policy.sequence_length):
all_next_memories.append(_mem.squeeze().detach().numpy())
for signal_name, _val in values.items():
all_values[signal_name].append(_val)
# Create one tensor per reward signal
all_value_tensors = {
signal_name: torch.cat(value_list, dim=0)
for signal_name, value_list in all_values.items()
}
next_mem = _mem
return all_value_tensors, all_next_memories, next_mem
def get_trajectory_value_estimates(
self,
batch: AgentBuffer,

else None
)
# If we're using LSTM, we want to get all the intermediate memories.
all_next_memories: Optional[AgentBufferField] = None
if self.policy.use_recurrent:
resequenced_buffer = AgentBuffer()
all_next_memories = AgentBufferField()
# The 1st sequence are the ones that are padded. So if seq_len = 3 and
# trajectory is of length 10, the ist sequence is [pad,pad,obs].
# Compute the number of elements in this padded seq.
leftover = batch.num_experiences % self.policy.sequence_length
first_seq_len = self.policy.sequence_length if leftover == 0 else leftover
for _ in range(first_seq_len):
all_next_memories.append(memory.squeeze().detach().numpy())
batch.resequence_and_append(
resequenced_buffer, training_length=self.policy.sequence_length
)
reseq_obs = ObsUtil.from_buffer(resequenced_buffer, n_obs)
reseq_obs = [ModelUtils.list_to_tensor(obs) for obs in reseq_obs]
# By now, the buffer should be of length seq_len * num_seq, padded
_mem = memory
for seq_num in range(
resequenced_buffer.num_experiences // self.policy.sequence_length - 1
):
seq_obs = []
for _obs in reseq_obs:
start = seq_num * self.policy.sequence_length
end = (seq_num + 1) * self.policy.sequence_length
seq_obs.append(_obs[start:end])
_, next_seq_mem = self.critic.critic_pass(
seq_obs, _mem, sequence_length=self.policy.sequence_length
)
for _ in range(self.policy.sequence_length):
all_next_memories.append(next_seq_mem.squeeze().detach().numpy())
# Convert to tensors
current_obs = ObsUtil.from_buffer(batch, n_obs)
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

value_estimates, next_memory = self.critic.critic_pass(
current_obs, memory, sequence_length=batch.num_experiences
)
# If we're using LSTM, we want to get all the intermediate memories.
all_next_memories: Optional[AgentBufferField] = None
if self.policy.use_recurrent:
(
value_estimates,
all_next_memories,
next_memory,
) = self._evaluate_by_sequence(current_obs, memory)
else:
value_estimates, next_memory = self.critic.critic_pass(
current_obs, memory, sequence_length=batch.num_experiences
)
# Store the memory for the next trajectory
self.critic_memory_dict[agent_id] = next_memory

next_value_estimate[k] = 0.0
if agent_id in self.critic_memory_dict:
self.critic_memory_dict.pop(agent_id)
assert len(value_estimates["extrinsic"]) == batch.num_experiences
return value_estimates, next_value_estimate, all_next_memories
正在加载...
取消
保存