
Proper initialization and SAC masking

/develop/add-fire/memoryclass
Ervin Teng, 4 years ago
Commit eeae6d97
3 changed files with 43 additions and 21 deletions
1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (27 changes)
2. ml-agents/mlagents/trainers/torch/layers.py (24 changes)
3. ml-agents/mlagents/trainers/torch/networks.py (13 changes)

ml-agents/mlagents/trainers/sac/optimizer_torch.py (27 changes)


                 * self.gammas[i]
                 * target_values[name]
             )
-            _q1_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
+            _q1_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
             )
-            _q2_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
+            _q2_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
             )
             q1_losses.append(_q1_loss)

                 v_backup = min_policy_qs[name] - torch.sum(
                     _ent_coef * log_probs, dim=1
                 )
-                value_loss = 0.5 * torch.mean(
-                    loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                 )
                 value_losses.append(value_loss)
         else:

                 v_backup = min_policy_qs[name] - torch.mean(
                     branched_ent_bonus, axis=0
                 )
-                value_loss = 0.5 * torch.mean(
-                    loss_masks
-                    * torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
+                    loss_masks,
                 )
                 value_losses.append(value_loss)
         value_loss = torch.mean(torch.stack(value_losses))

         if not discrete:
             mean_q1 = mean_q1.unsqueeze(1)
             batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
-            policy_loss = torch.mean(loss_masks * batch_policy_loss)
+            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
         else:
             action_probs = log_probs.exp()
             branched_per_action_ent = ModelUtils.break_into_branches(

             target_current_diff = torch.squeeze(
                 target_current_diff_branched, axis=2
             )
-            entropy_loss = -torch.mean(
-                loss_masks
-                * torch.mean(self._log_ent_coef * target_current_diff, axis=1)
+            entropy_loss = -1 * ModelUtils.masked_mean(
+                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
             )
         return entropy_loss

             memories=next_memories,
             sequence_length=self.policy.sequence_length,
         )
-        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
+        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
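Note: every masked loss in this file moves from torch.mean(loss_masks * loss) to ModelUtils.masked_mean(loss, loss_masks), and the masks themselves are now loaded as torch.bool. The helper's body is not part of this diff; a minimal sketch consistent with the call sites (tensor first, mask second) could be:

import torch

def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Sketch only, not necessarily the actual ModelUtils implementation.
    # torch.mean(masks * tensor) divides by the full batch size, so padded
    # (masked-out) steps silently shrink the loss; dividing by the number
    # of unmasked steps gives the mean over real steps only.
    return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)

loss = torch.tensor([1.0, 3.0, 9.0, 9.0])
masks = torch.tensor([True, True, False, False])  # two real steps, two padded
print(torch.mean(masks * loss))  # tensor(1.) -- biased low by the padded zeros
print(masked_mean(loss, masks))  # tensor(2.) -- mean over the real steps only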

ml-agents/mlagents/trainers/torch/layers.py (24 changes)


     layer.weight.data *= kernel_gain
     _init_methods[bias_init](layer.bias.data)
     return layer
+
+
+def lstm_layer(
+    input_size: int,
+    hidden_size: int,
+    num_layers: int = 1,
+    batch_first: bool = True,
+    forget_bias: float = 1.0,
+    kernel_init: Initialization = Initialization.XavierGlorotUniform,
+    bias_init: Initialization = Initialization.Zero,
+) -> torch.nn.Module:
+    """
+    Creates a torch.nn.LSTM and initializes its weights and biases. Provides a
+    forget_bias offset like is done in TensorFlow.
+    """
+    lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
+    # Add forget_bias to forget gate bias
+    for name, param in lstm.named_parameters():
+        if "weight" in name:
+            _init_methods[kernel_init](param.data)
+        elif "bias" in name:
+            _init_methods[bias_init](param.data)
+            param.data[hidden_size : 2 * hidden_size].add_(forget_bias)
+    return lstm
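The slice param.data[hidden_size : 2 * hidden_size] is the forget gate because PyTorch packs each LSTM bias vector as (b_ii | b_if | b_ig | b_io), i.e. input, forget, cell, and output gates in that order. Note that both bias_ih_l0 and bias_hh_l0 match "bias" in name, so each receives the offset. A quick check of the new layer (arbitrary sizes, assuming lstm_layer is importable from mlagents.trainers.torch.layers):

import torch
from mlagents.trainers.torch.layers import lstm_layer

hidden_size = 4
lstm = lstm_layer(input_size=3, hidden_size=hidden_size, forget_bias=1.0)
for name, param in lstm.named_parameters():
    if "bias" in name:
        # With the default bias_init=Zero, each row below is one gate's bias:
        # expect zeros for input/cell/output rows and ones for the forget row.
        print(name)
        print(param.data.view(4, hidden_size))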

ml-agents/mlagents/trainers/torch/networks.py (13 changes)


 from mlagents.trainers.settings import NetworkSettings
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
+from mlagents.trainers.torch.layers import lstm_layer

 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
 EncoderFunction = Callable[
         )
         if self.use_lstm:
-            self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1)
+            self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
         else:
             self.lstm = None

             raise Exception("No valid inputs to network.")

         if self.use_lstm:
-            encoding = encoding.view([sequence_length, -1, self.h_size])
-            encoding, memories = self.lstm(encoding, (memories[0], memories[1]))
-            encoding = encoding.view([-1, self.m_size // 2])
+            # Resize to (batch, sequence length, encoding size)
+            encoding = encoding.reshape([-1, sequence_length, self.h_size])
+            encoding, memories = self.lstm(
+                encoding.contiguous(),
+                (memories[0].contiguous(), memories[1].contiguous()),
+            )
+            encoding = encoding.reshape([-1, self.m_size // 2])
             memories = torch.cat(memories, dim=-1)
         return encoding, memories
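The reshape changes follow from constructing the LSTM with batch_first=True: the input is then (batch, sequence, feature) rather than the time-major (sequence, batch, feature) layout the old .view assumed, while the hidden and cell states keep the shape (num_layers, batch, hidden) either way. A shape walkthrough with made-up sizes:

import torch

batch, seq, h_size, m_size = 2, 3, 8, 16  # illustrative sizes only
lstm = torch.nn.LSTM(h_size, m_size // 2, 1, batch_first=True)

flat = torch.randn(batch * seq, h_size)   # encoder output, one row per step
x = flat.reshape([-1, seq, h_size])       # -> (2, 3, 8): batch-major input

# Initial states keep (num_layers, batch, hidden) regardless of batch_first.
h0 = torch.zeros(1, batch, m_size // 2)
c0 = torch.zeros(1, batch, m_size // 2)

out, (hn, cn) = lstm(x.contiguous(), (h0, c0))
print(out.shape)                              # torch.Size([2, 3, 8])
print(out.reshape([-1, m_size // 2]).shape)   # torch.Size([6, 8]): per-step rows again
print(torch.cat((hn, cn), dim=-1).shape)      # torch.Size([1, 2, 16]): packed memories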
