浏览代码

_

/exp-robot
vincentpierre 4 年前
当前提交
9fbc2e0e
共有 2 个文件被更改,包括 43 次插入6 次删除
  1. 22
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  2. 27
      ml-agents/mlagents/trainers/torch/networks.py

22
ml-agents/mlagents/trainers/sac/optimizer_torch.py


policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
entropy_loss = self.sac_entropy_loss(log_probs, masks)
# Compute the surrogate loss for predicting the cube position:
l_1 = self.value_network.q1_network.network_body.get_surrogate_loss(current_obs)
l_2 = self.value_network.q2_network.network_body.get_surrogate_loss(current_obs)
l_v = self.target_network.network_body.get_surrogate_loss(current_obs)
surrogate_loss_v = (l_1 + l_2 + l_v) * 0.05
surrogate_loss_p = self.policy.actor_critic.network_body.get_surrogate_loss(current_obs) * 0.05
surrogate_loss = surrogate_loss_v + surrogate_loss_p
policy_loss.backward()
(policy_loss + surrogate_loss_p).backward()
total_value_loss.backward()
(total_value_loss + surrogate_loss_v).backward()
self.value_optimizer.step()
ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)

"Losses/Value Loss": value_loss.item(),
"Losses/Q1 Loss": q1_loss.item(),
"Losses/Q2 Loss": q2_loss.item(),
"Losses/Surrogate Loss": surrogate_loss.item(),
"Policy/Discrete Entropy Coeff": torch.mean(
torch.exp(self._log_ent_coef.discrete)
).item(),

27
ml-agents/mlagents/trainers/torch/networks.py


normalize=self.normalize,
)
total_enc_size = sum(self.embedding_sizes) + encoded_act_size
total_enc_size = sum(self.embedding_sizes) + encoded_act_size - 9
self.surrogate_predictor = torch.nn.Linear(self.h_size, 9)
self.linear_encoder = LinearEncoder(
total_enc_size, network_settings.num_layers, self.h_size
)

actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
retrun_target = False
processed_obs = processor(obs_input)
encodes.append(processed_obs)
if obs_input.shape[1] == 9:
target = obs_input
if retrun_target:
return target
else:
processed_obs = processor(obs_input)
encodes.append(processed_obs)
if len(encodes) == 0:
raise Exception("No valid inputs to network.")

encoding, memories = self.lstm(encoding, memories)
encoding = encoding.reshape([-1, self.m_size // 2])
return encoding, memories
def get_surrogate_loss(self, inputs: List[torch.Tensor]) -> torch.Tensor:
    """Return the auxiliary regression loss of the surrogate predictor.

    The body is run twice on the same ``inputs``: once normally to obtain
    the encoding, and once with ``retrun_target=True`` to obtain the
    regression target. The encoding is passed through
    ``self.surrogate_predictor`` and compared against that target.

    :param inputs: Observation tensors, forwarded unchanged to
        ``self.forward``.
    :return: Scalar tensor: the squared error summed over dim 1, then
        averaged over the remaining (presumably batch) dimension.
    """
    # The regular forward pass returns (encoding, memories); memories are unused.
    encoding, _ = self.forward(inputs)
    predicted = self.surrogate_predictor(encoding)
    # NOTE: "retrun_target" is spelled exactly as the keyword is passed
    # elsewhere in this file; it must match forward()'s parameter name.
    expected = self.forward(inputs, retrun_target=True)
    squared_error = (predicted - expected) ** 2
    return torch.mean(torch.sum(squared_error, dim=1))
class ValueNetwork(nn.Module):

正在加载...
取消
保存