from typing import Dict, cast, List, Tuple, Optional
from collections import defaultdict
import math
import numpy as np
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import (
    AgentBufferField,
)
num_experiences = self_obs[0].shape[0]
all_next_value_mem = AgentBufferField()
all_next_baseline_mem = AgentBufferField()
# When using LSTM, we need to divide the trajectory into sequences of equal length.
# Sometimes that division isn't even, and we must pad the leftover sequence.
# In the buffer, the last sequence is the one that gets padded. So if seq_len = 3 and
# the trajectory is of length 10, the last sequence is [obs, pad, pad].
# Compute the number of elements in this padded sequence.
leftover_seq_len = num_experiences % self.policy.sequence_length
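# Worked example: with sequence_length = 3 and a 10-step trajectory, the full
# sequences cover steps [0:3], [3:6], [6:9], and leftover_seq_len = 10 % 3 = 1,
# so the final sequence holds one real observation followed by two padding steps.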

all_values: Dict[str, List[np.ndarray]] = defaultdict(list)
all_baseline: Dict[str, List[np.ndarray]] = defaultdict(list)
_baseline_mem = init_baseline_mem
_value_mem = init_value_mem
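# As each sequence is evaluated below, the memory at the start of that sequence is
# recorded once per step into all_next_value_mem / all_next_baseline_mem, so every
# step in the buffer can later be re-evaluated from a matching LSTM state.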
# Evaluate each full sequence, carrying the memories over from one sequence to the next.
for seq_num in range(num_experiences // self.policy.sequence_length):
    for _ in range(self.policy.sequence_length):
        all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
        all_next_baseline_mem.append(
            ModelUtils.to_numpy(_baseline_mem.squeeze())
        )

    start = seq_num * self.policy.sequence_length
    end = (seq_num + 1) * self.policy.sequence_length
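    # Steps [start, end) form the seq_num-th full sequence of length sequence_length.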
    self_seq_obs = []
    groupmate_seq_obs = []
    groupmate_seq_act = []
    seq_obs = []
    for _self_obs in self_obs:
        seq_obs.append(_self_obs[start:end])
    self_seq_obs.append(seq_obs)

    for groupmate_obs, groupmate_action in zip(obs, actions):
        seq_obs = []
        for _obs in groupmate_obs:
            sliced_seq_obs = _obs[start:end]
            seq_obs.append(sliced_seq_obs)
        groupmate_seq_obs.append(seq_obs)
        _act = groupmate_action.slice(start, end)
        groupmate_seq_act.append(_act)
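    # The centralized critic consumes this agent's observations together with all
    # groupmates' observations; the baseline additionally conditions on the
    # groupmates' actions (but not on this agent's own action).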
    all_seq_obs = self_seq_obs + groupmate_seq_obs
    values, _value_mem = self.critic.critic_pass(
        all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length
    )
    for signal_name, _val in values.items():
        all_values[signal_name].append(_val)

    groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
    baselines, _baseline_mem = self.critic.baseline(
        self_seq_obs[0],
        groupmate_obs_and_actions,
        _baseline_mem,
        sequence_length=self.policy.sequence_length,
    )
    for signal_name, _val in baselines.items():
        all_baseline[signal_name].append(_val)
# Compute values for the potentially truncated last sequence. Note that this
# sequence isn't padded yet, but will be when it is added to the buffer.
if leftover_seq_len > 0:
    self_seq_obs = []
    groupmate_seq_obs = []
    groupmate_seq_act = []
    seq_obs = []
    for _self_obs in self_obs:
        last_seq_obs = _self_obs[-leftover_seq_len:]
        seq_obs.append(last_seq_obs)
    self_seq_obs.append(seq_obs)

    for groupmate_obs, groupmate_action in zip(obs, actions):
        seq_obs = []
        for _obs in groupmate_obs:
            last_seq_obs = _obs[-leftover_seq_len:]
            seq_obs.append(last_seq_obs)
        groupmate_seq_obs.append(seq_obs)
        _act = groupmate_action.slice(len(_obs) - leftover_seq_len, len(_obs))
        groupmate_seq_act.append(_act)
    # For the last sequence, the initial memory should be the one at the
    # end of this trajectory.
    seq_obs = []
    for _ in range(leftover_seq_len):
        all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
        all_next_baseline_mem.append(
            ModelUtils.to_numpy(_baseline_mem.squeeze())
        )
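    # The leftover slice is evaluated with sequence_length=leftover_seq_len, so no
    # padding is fed to the networks here; padding is only added later in the buffer.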
    all_seq_obs = self_seq_obs + groupmate_seq_obs
    last_values, _value_mem = self.critic.critic_pass(
        all_seq_obs, _value_mem, sequence_length=leftover_seq_len
    )
    for signal_name, _val in last_values.items():
        all_values[signal_name].append(_val)

    groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
    last_baseline, _baseline_mem = self.critic.baseline(
        self_seq_obs[0],
        groupmate_obs_and_actions,
        _baseline_mem,
        sequence_length=leftover_seq_len,
    )
    for signal_name, _val in last_baseline.items():
        all_baseline[signal_name].append(_val)
# Create one tensor per reward signal
all_value_tensors = {
    signal_name: torch.cat(value_list, dim=0)