|
|
|
|
|
|
self.vector_encoders = nn.ModuleList(self.vector_encoders) |
|
|
|
self.visual_encoders = nn.ModuleList(self.visual_encoders) |
|
|
|
if use_lstm: |
|
|
|
self.lstm = nn.GRU(h_size, h_size, 1) |
|
|
|
|
|
|
|
def clear_memory(self, batch_size): |
|
|
|
self.memory = ( |
|
|
|
torch.zeros(1, batch_size, self.m_size), |
|
|
|
torch.zeros(1, batch_size, self.m_size), |
|
|
|
) |
|
|
|
self.lstm = nn.LSTM(h_size, m_size // 2, 1) |
|
|
|
|
|
|
|
def update_normalization(self, vec_inputs): |
|
|
|
if self.normalize: |
|
|
|
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
embedding = embedding.reshape([sequence_length, -1, self.h_size]) |
|
|
|
memories = torch.split(memories, self.m_size // 2, dim=-1) |
|
|
|
embedding = embedding.reshape([-1, self.h_size]) |
|
|
|
embedding = embedding.reshape([-1, self.m_size // 2]) |
|
|
|
memories = torch.cat(memories, dim=-1) |
|
|
|
return embedding, memories |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vis_encode_type, |
|
|
|
use_lstm, |
|
|
|
) |
|
|
|
if use_lstm: |
|
|
|
embedding_size = m_size // 2 |
|
|
|
else: |
|
|
|
embedding_size = h_size |
|
|
|
self.distribution = GaussianDistribution(h_size, act_size[0]) |
|
|
|
self.distribution = GaussianDistribution(embedding_size, act_size[0]) |
|
|
|
self.distribution = MultiCategoricalDistribution(h_size, act_size) |
|
|
|
self.distribution = MultiCategoricalDistribution(embedding_size, act_size) |
|
|
|
if separate_critic: |
|
|
|
self.critic = Critic( |
|
|
|
stream_names, |
|
|
|
|
|
|
) |
|
|
|
else: |
|
|
|
self.stream_names = stream_names |
|
|
|
self.value_heads = ValueHeads(stream_names, h_size) |
|
|
|
self.value_heads = ValueHeads(stream_names, embedding_size) |
|
|
|
|
|
|
|
def update_normalization(self, vector_obs): |
|
|
|
self.network_body.update_normalization(vector_obs) |
|
|
|
|
|
|
def critic_pass(self, vec_inputs, vis_inputs): |
|
|
|
def critic_pass(self, vec_inputs, vis_inputs, memories=None): |
|
|
|
embedding, _ = self.network_body(vec_inputs, vis_inputs) |
|
|
|
embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories) |
|
|
|
return self.value_heads(embedding) |
|
|
|
|
|
|
|
def sample_action(self, dists): |
|
|
|