class NetworkBody(nn.Module):
    def __init__(
        self,
        h_size,
        m_size,
        normalize=False,
        use_lstm=False,
    ):
        super(NetworkBody, self).__init__()
        self.h_size = h_size
        self.m_size = m_size
        self.normalize = normalize
        self.use_lstm = use_lstm

        # Wrap the per-observation normalizers and encoders so their parameters
        # are registered with the module.
        self.vector_normalizers = nn.ModuleList(self.vector_normalizers)
        self.vector_encoders = nn.ModuleList(self.vector_encoders)
        self.visual_encoders = nn.ModuleList(self.visual_encoders)

        if use_lstm:
            self.lstm = nn.LSTM(h_size, m_size // 2, 1)

    def update_normalization(self, vec_inputs):
        # Push each vector observation through its normalizer's running statistics.
        for idx, vec_input in enumerate(vec_inputs):
            self.vector_normalizers[idx].update(vec_input)

    # Eager-mode signature:
    # def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1):
    #
    # The variant below replaces the None / int defaults with dummy tensors so that
    # every argument is a Tensor for TorchScript.
    def forward(
        self,
        vec_inputs,
        vis_inputs,
        memories=torch.tensor(1),
        sequence_length=torch.tensor(1),
    ):
        vec_embeds = []
        for idx, encoder in enumerate(self.vector_encoders):
            vec_input = vec_inputs[idx]
            if self.normalize:
                vec_input = self.vector_normalizers[idx](vec_input)
            hidden = encoder(vec_input)
            vec_embeds.append(hidden)

        vis_embeds = []
        for idx, encoder in enumerate(self.visual_encoders):
            hidden = encoder(vis_inputs[idx])
            vis_embeds.append(hidden)

        # Concatenate every observation embedding along the feature dimension.
        embedding = torch.cat(vec_embeds + vis_embeds, dim=-1)

        if self.use_lstm:
            # Memories are carried as a single tensor: split it into the LSTM's
            # hidden and cell halves, run the LSTM, then pack the halves back
            # into one tensor.
            embedding = embedding.reshape([sequence_length, -1, self.h_size])
            memories = torch.split(memories, self.m_size // 2, dim=-1)
            embedding, memories = self.lstm(embedding, memories)
            embedding = embedding.reshape([-1, self.m_size // 2])
            memories = torch.cat(memories, dim=-1)
            return embedding, memories
        # Without an LSTM there is no recurrent state; the embedding is returned in
        # both slots to keep a uniform (embedding, memories) signature.
        return embedding, embedding
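
    # A standalone sketch of the memory layout assumed above (inferred from the
    # m_size // 2 splits, not taken from this file): one tensor of width m_size
    # carries the LSTM hidden and cell states side by side.
    #
    #     m_size = 8
    #     memories = torch.zeros(1, 1, m_size)
    #     h, c = torch.split(memories, m_size // 2, dim=-1)  # two (1, 1, 4) halves
    #     repacked = torch.cat((h, c), dim=-1)               # (1, 1, 8) again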


class ActorCritic(nn.Module):
    def __init__(self, stream_names, embedding_size):
        super(ActorCritic, self).__init__()
        # The constructor also sets up self.network_body, self.critic,
        # self.distribution and self.separate_critic, which the methods below use.
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, embedding_size)

    # Eager-mode version:
    #
    # @torch.jit.ignore
    # def critic_pass(self, vec_inputs, vis_inputs, memories=None):
    #     if self.separate_critic:
    #         return self.critic(vec_inputs, vis_inputs)
    #     else:
    #         embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
    #         return self.value_heads(embedding)

    @torch.jit.export
    def critic_pass(self, vec_inputs, vis_inputs, memories=torch.tensor(1)):
        # if self.separate_critic:
        value, mean_value = self.critic(vec_inputs, vis_inputs)
        return {"extrinsic": value}, mean_value
        # else:
        #     embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
        #     return {"extrinsic": self.value_heads(embedding)}

    @torch.jit.ignore
    def sample_action(self, dists):
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions

    @torch.jit.ignore
    def get_probs_and_entropy(self, actions, dists):
        log_probs = []
        entropies = []
        for action, action_dist in zip(actions, dists):
            log_probs.append(action_dist.log_prob(action))
            entropies.append(action_dist.entropy())
        log_probs = torch.stack(log_probs, dim=-1)
        entropies = torch.stack(entropies, dim=-1)
        entropies = entropies.squeeze(-1)
        return log_probs, entropies
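
    # A quick sketch of how these two helpers are used, with torch.distributions
    # objects standing in for the policy's own distribution classes (the instance
    # name `policy` and the stand-in distribution are assumptions, not from this
    # file):
    #
    #     dists = [torch.distributions.Categorical(logits=torch.zeros(1, 3))]
    #     actions = policy.sample_action(dists)
    #     log_probs, entropies = policy.get_probs_and_entropy(actions, dists)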

    @torch.jit.ignore
    def evaluate(
        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
    ):
        embedding, memories = self.network_body(
            vec_inputs, vis_inputs, memories, sequence_length
        )
        dists = self.distribution(embedding, masks=masks)
        return dists, memories

    @torch.jit.export
    def jit_forward(
        self,
        vec_inputs,
        vis_inputs,
        masks=torch.tensor(1),
        memories=torch.tensor(1),
        sequence_length=torch.tensor(1),
    ):
        # Fork the shared body as an asynchronous TorchScript task, join it, then
        # run the critic on the same inputs.
        fut = torch.jit._fork(
            self.network_body, vec_inputs, vis_inputs, memories, sequence_length
        )
        embedding, memories = torch.jit._wait(fut)
        value_outputs = self.critic_pass(vec_inputs, vis_inputs, memories)
        return embedding, value_outputs, memories
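
    # torch.jit._fork schedules a callable as an asynchronous task and returns a
    # Future; torch.jit._wait blocks until its result is ready. A minimal
    # standalone sketch (not tied to this class):
    #
    #     def heavy(x):
    #         return x * 2
    #
    #     fut = torch.jit._fork(heavy, torch.ones(3))
    #     out = torch.jit._wait(fut)  # tensor([2., 2., 2.])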

    @torch.jit.ignore
    def forward(
        self,
        vec_inputs,
        vis_inputs,
        masks=torch.tensor(1),
        memories=torch.tensor(1),
        sequence_length=torch.tensor(1),
    ):
        # Eager-mode version took plain defaults and ran the body, critic and
        # distributions directly:
        #
        #     def forward(
        #         self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
        #     ):
        #         embedding, memories = self.network_body(
        #             vec_inputs, vis_inputs, memories, sequence_length
        #         )
        #         value_outputs = self.critic(vec_inputs, vis_inputs)
        #         dists = self.distribution(embedding, masks=masks)
        #
        # Here the exported jit_forward runs the body and critic as TorchScript,
        # while the distributions are still built in eager mode via get_dist.
        embedding, value_outputs, memories = self.jit_forward(
            vec_inputs, vis_inputs, masks, memories, sequence_length
        )
        dists = self.get_dist(embedding, masks)
        return dists, value_outputs, memories

    @torch.jit.ignore
    def get_dist(self, embedding, masks):
        return self.distribution(embedding, masks=masks)
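
    # Intended export path (an assumption based on the decorators above): the module
    # is compiled with torch.jit.script, the @torch.jit.export'ed entry points are
    # then callable from Python, and @torch.jit.ignore'd methods stay eager.
    #
    #     scripted = torch.jit.script(actor_critic)  # hypothetical instance
    #     values, mean_value = scripted.critic_pass(vec_obs, vis_obs)
    #     embedding, values, memories = scripted.jit_forward(vec_obs, vis_obs)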


class Critic(nn.Module):
    # Separate critic network: ActorCritic.critic_pass above calls it as
    # self.critic(vec_inputs, vis_inputs) and expects (value, mean_value) back.
    def __init__(
        self,
    ):
        super(Critic, self).__init__()


class Normalizer(nn.Module):
    def __init__(self, vec_obs_size):
        super(Normalizer, self).__init__()
        # Running statistics used to normalize vector observations online.
        self.normalization_steps = torch.tensor(1)
        self.running_mean = torch.zeros(vec_obs_size)
        self.running_variance = torch.ones(vec_obs_size)
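
    # Neither update() (called from NetworkBody.update_normalization) nor the
    # normalization applied in NetworkBody.forward is shown here. A minimal
    # sketch, assuming a standard batched running mean / variance update; not
    # necessarily the original implementation:
    def update(self, vector_input):
        steps_increment = vector_input.size()[0]
        total_new_steps = self.normalization_steps + steps_increment
        input_to_old_mean = vector_input - self.running_mean
        new_mean = self.running_mean + (input_to_old_mean / total_new_steps).sum(0)
        input_to_new_mean = vector_input - new_mean
        new_variance = self.running_variance + (
            input_to_new_mean * input_to_old_mean
        ).sum(0)
        self.running_mean = new_mean
        self.running_variance = new_variance
        self.normalization_steps = total_new_steps

    def forward(self, vector_input):
        # Standardize by the running statistics and clip to a fixed range.
        return torch.clamp(
            (vector_input - self.running_mean)
            / torch.sqrt(self.running_variance / self.normalization_steps),
            -5,
            5,
        )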


class ValueHeads(nn.Module):
    def __init__(self, stream_names, input_size):
        super(ValueHeads, self).__init__()
        self.value_heads = nn.ModuleDict({})
        for name in stream_names:
            value = nn.Linear(input_size, 1)
            self.value_heads[name] = value
            # Keep a direct handle to the head so the single-head return path in
            # forward() can use it without indexing the ModuleDict.
            self.value = value
        self.value_outputs = nn.ModuleDict({})

    def forward(self, hidden):
        value_outputs = {}
        for stream_name, _ in self.value_heads.items():
            value_outputs[stream_name] = self.value_heads[stream_name](hidden).squeeze(
                -1
            )
        return (
            value_outputs,
            torch.mean(torch.stack(list(value_outputs.values())), dim=0),
        )
        # Single-head variant:
        #
        # self.__delattr__
        # for stream_name, head in self.value_heads.items():
        #     self.value_outputs[stream_name] = head(hidden).squeeze(
        #         -1
        #     )
        # return (self.value(hidden).squeeze(-1), self.value(hidden).squeeze(-1))
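
    # Example of the (value_outputs, mean) return, assuming two reward streams
    # named "extrinsic" and "curiosity" (the second name is hypothetical; only
    # "extrinsic" appears in critic_pass above):
    #
    #     heads = ValueHeads(["extrinsic", "curiosity"], 128)
    #     values, mean_value = heads(torch.zeros(5, 128))
    #     # values["extrinsic"].shape == values["curiosity"].shape == (5,)
    #     # mean_value.shape == (5,)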


class VectorEncoder(nn.Module):