浏览代码

Address comments

/develop/lstm-burnin
Ervin Teng 4 年前
当前提交
9fd4a81e
共有 2 个文件被更改,包括 25 次插入28 次删除
  1. 27
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 26
      ml-agents/mlagents/trainers/poca/optimizer_torch.py

27
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from typing import Dict, Optional, Tuple, List
from mlagents.torch_utils import torch
import numpy as np
import math
from collections import defaultdict
from mlagents.trainers.buffer import AgentBuffer, AgentBufferField

"""
num_experiences = tensor_obs[0].shape[0]
all_next_memories = AgentBufferField()
# In the buffer, the last sequence is the one that is padded. So if seq_len = 3 and
# trajectory is of length 10, the last sequence is [obs,pad,pad].
# Compute the number of elements in this padded seq.
leftover = num_experiences % self.policy.sequence_length
# When using LSTM, we need to divide the trajectory into sequences of even length. Sometimes,
# that division isn't even, and we must pad the leftover sequence.
# When it is added to the buffer, the last sequence will be padded. So if seq_len = 3 and
# trajectory is of length 10, the last sequence is [obs,pad,pad] once it is added to the buffer.
# Compute the number of elements in this sequence that will end up being padded.
leftover_seq_len = num_experiences % self.policy.sequence_length
for seq_num in range(
0, math.floor((num_experiences) / (self.policy.sequence_length))
):
for seq_num in range(num_experiences // (self.policy.sequence_length)):
seq_obs = []
for _ in range(self.policy.sequence_length):
all_next_memories.append(ModelUtils.to_numpy(_mem.squeeze()))

for signal_name, _val in values.items():
all_values[signal_name].append(_val)
# Compute values for the potentially truncated last sequence
# Compute values for the potentially truncated last sequence. Note that this
# sequence isn't padded yet, but will be.
last_seq_len = leftover
if last_seq_len > 0:
if leftover_seq_len > 0:
last_seq_obs = _obs[-last_seq_len:]
last_seq_obs = _obs[-leftover_seq_len:]
for _ in range(last_seq_len):
for _ in range(leftover_seq_len):
seq_obs, _mem, sequence_length=last_seq_len
seq_obs, _mem, sequence_length=leftover_seq_len
)
for signal_name, _val in last_values.items():
all_values[signal_name].append(_val)

26
ml-agents/mlagents/trainers/poca/optimizer_torch.py


ExtrinsicRewardProvider,
)
import numpy as np
import math
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import (

num_experiences = self_obs[0].shape[0]
all_next_value_mem = AgentBufferField()
all_next_baseline_mem = AgentBufferField()
# When using LSTM, we need to divide the trajectory into sequences of even length. Sometimes,
# that division isn't even, and we must pad the leftover sequence.
leftover = num_experiences % self.policy.sequence_length
leftover_seq_len = num_experiences % self.policy.sequence_length
all_values: Dict[str, List[np.ndarray]] = defaultdict(list)
all_baseline: Dict[str, List[np.ndarray]] = defaultdict(list)

# Evaluate other trajectories, carrying over _mem after each
# trajectory
for seq_num in range(
0, math.floor((num_experiences) / (self.policy.sequence_length))
):
for seq_num in range(num_experiences // self.policy.sequence_length):
for _ in range(self.policy.sequence_length):
all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
all_next_baseline_mem.append(

all_baseline[signal_name].append(_val)
# Compute values for the potentially truncated last sequence
last_seq_len = leftover
if last_seq_len > 0:
if leftover_seq_len > 0:
last_seq_obs = _self_obs[-last_seq_len:]
last_seq_obs = _self_obs[-leftover_seq_len:]
seq_obs.append(last_seq_obs)
self_seq_obs.append(seq_obs)

last_seq_obs = _obs[-last_seq_len:]
last_seq_obs = _obs[-leftover_seq_len:]
_act = groupmate_action.slice(len(_obs) - last_seq_len, len(_obs))
_act = groupmate_action.slice(len(_obs) - leftover_seq_len, len(_obs))
last_seq_len = leftover
for _ in range(last_seq_len):
for _ in range(leftover_seq_len):
all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
all_next_baseline_mem.append(
ModelUtils.to_numpy(_baseline_mem.squeeze())

last_values, _value_mem = self.critic.critic_pass(
all_seq_obs, _value_mem, sequence_length=last_seq_len
all_seq_obs, _value_mem, sequence_length=leftover_seq_len
)
for signal_name, _val in last_values.items():
all_values[signal_name].append(_val)

groupmate_obs_and_actions,
_baseline_mem,
sequence_length=last_seq_len,
sequence_length=leftover_seq_len,
)
for signal_name, _val in last_baseline.items():
all_baseline[signal_name].append(_val)

正在加载...
取消
保存