
Adding some changes

/exp-diverse-behavior
vincentpierre, 3 years ago
Commit: 8da21669
3 files changed, with 28 insertions and 19 deletions
  1. config/sac/Pyramids.yaml (17 changes)
  2. config/sac/Walker.yaml (2 changes)
  3. ml-agents/mlagents/trainers/sac/optimizer_torch.py (28 changes)

config/sac/Pyramids.yaml (17 changes)


hidden_units: 512
num_layers: 3
vis_encode_type: simple
goal_conditioning_type: none
-gail:
-gamma: 0.99
-strength: 0.01
-learning_rate: 0.0003
-use_actions: true
-use_vail: false
-demo_path: Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo
+# gail:
+# gamma: 0.99
+# strength: 0.01
+# learning_rate: 0.0003
+# use_actions: true
+# use_vail: false
+# demo_path: Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo
-max_steps: 3000000
+max_steps: 30000000
time_horizon: 128
summary_freq: 30000
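
In the hunk above, the active gail reward-signal block appears to be replaced by a commented-out copy, and max_steps is raised from 3000000 to 30000000. As a loose, purely illustrative sketch of how strength-weighted reward signals combine in spirit (every name below is hypothetical, not the trainer's code):

    from typing import Optional

    # Hypothetical sketch: the reward the policy optimizes is, in spirit, a
    # strength-weighted sum of the configured reward signals. With the gail
    # block commented out, only the extrinsic signal remains.
    def combined_reward(extrinsic: float,
                        gail: Optional[float] = None,
                        gail_strength: float = 0.01) -> float:
        total = 1.0 * extrinsic            # extrinsic strength assumed to be 1.0
        if gail is not None:               # gail disabled -> contributes nothing
            total += gail_strength * gail
        return total

    print(combined_reward(0.5))            # 0.5
    print(combined_reward(0.5, gail=2.0))  # 0.5 + 0.01 * 2.0 = 0.52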

config/sac/Walker.yaml (2 changes)


learning_rate: 0.0003
learning_rate_schedule: constant
batch_size: 1024
-buffer_size: 2000000
+buffer_size: 200000 #2000000
buffer_init_steps: 0
tau: 0.005
steps_per_update: 30.0
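
The single change here shrinks the SAC replay buffer tenfold, from 2000000 to 200000 transitions, keeping the old value in a trailing comment. A rough back-of-envelope estimate of what that means for memory; the observation and action sizes below are assumptions, not the Walker environment's real dimensions:

    # Back-of-envelope replay-buffer memory estimate (all sizes are assumptions).
    obs_dim = 243          # assumed vector-observation size
    act_dim = 39           # assumed continuous action size
    bytes_per_float = 4    # float32

    def buffer_bytes(capacity: int) -> int:
        # store (obs, next_obs, action, reward, done) per transition
        per_transition = (2 * obs_dim + act_dim + 2) * bytes_per_float
        return capacity * per_transition

    print(buffer_bytes(2_000_000) / 1e9, "GB")  # ~4.2 GB
    print(buffer_bytes(200_000) / 1e9, "GB")    # ~0.4 GB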

ml-agents/mlagents/trainers/sac/optimizer_torch.py (28 changes)


class DiverseNetwork(torch.nn.Module):
EPSILON = 1e-10
-STRENGTH = 0.1
+STRENGTH = 1.0
-state_encoder_settings = settings
+print("Settings : strength:", self.STRENGTH, " use_actions:", self._use_actions)
+# state_encoder_settings = settings
+state_encoder_settings = NetworkSettings(True)
if state_encoder_settings.memory is not None:
state_encoder_settings.memory = None
logger.warning(

for obs, spec in zip(obs_input, self._all_obs_specs)
if spec.observation_type != ObservationType.GOAL_SIGNAL
]
-action = self._action_flattener.forward(action_input)
+action = self._action_flattener.forward(action_input).reshape(-1, self._action_flattener.flattened_size)
if detach_action:
action = action.detach()
hidden, _ = self._encoder.forward(tensor_obs, action)
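
The added .reshape(-1, self._action_flattener.flattened_size) makes the flattened action explicitly two-dimensional, (batch, flattened_size), before it is concatenated with the encoded observations. Below is a hypothetical, self-contained illustration of flattening a hybrid action into such a tensor; flatten_action and all sizes are invented for the example and are not the ML-Agents ActionFlattener API:

    import torch
    import torch.nn.functional as F

    # Hypothetical sketch: one-hot the discrete branches, concatenate them with
    # the continuous part, and make the batch dimension explicit with reshape.
    def flatten_action(continuous: torch.Tensor,   # (batch, cont_size)
                       discrete: torch.Tensor,     # (batch, num_branches), int64
                       branch_sizes: list) -> torch.Tensor:
        one_hots = [
            F.one_hot(discrete[:, i], num_classes=size).float()
            for i, size in enumerate(branch_sizes)
        ]
        flat = torch.cat([continuous] + one_hots, dim=1)
        flattened_size = continuous.shape[1] + sum(branch_sizes)
        return flat.reshape(-1, flattened_size)  # mirrors the added .reshape(...)

    flat = flatten_action(torch.zeros(8, 2), torch.zeros(8, 1, dtype=torch.long), [3])
    assert flat.shape == (8, 5)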

return rewards
def loss(self, obs_input, action_input, masks, detach_action=True) -> torch.Tensor:
-return -ModelUtils.masked_mean(
-self.rewards(obs_input, action_input, detach_action), masks
+# print( ">>> ",obs_input[self._diverse_index][0],self.predict(obs_input, action_input, detach_action)[0], self.predict([x*0 for x in obs_input], action_input, detach_action * 0)[0] )
+return - ModelUtils.masked_mean(
+self.rewards(obs_input, action_input, detach_action) , masks
)
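
The loss is the negative masked mean of the network's per-step rewards, so minimizing it pushes the network to assign high reward on the visited states and actions. A standalone sketch of that pattern, with a simplified masked_mean written out (these names are placeholders, not ModelUtils):

    import torch

    def masked_mean(values: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        # Mean over only the valid (unmasked) timesteps.
        masks = masks.float()
        return (values * masks).sum() / masks.sum().clamp(min=1.0)

    def diversity_loss(rewards: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        # Mirrors `return -ModelUtils.masked_mean(self.rewards(...), masks)`:
        # maximizing the masked-mean reward equals minimizing its negation.
        return -masked_mean(rewards, masks)

    loss = diversity_loss(torch.tensor([0.2, 0.5, 0.9]), torch.tensor([1.0, 1.0, 0.0]))
    print(loss)  # -(0.2 + 0.5) / 2 = -0.35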

with torch.no_grad():
v_backup = min_policy_qs[name] - torch.mean(
branched_ent_bonus, axis=0
-)
+) + self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)
+print("The discrete case is much more complicated than that")
# Add continuous entropy bonus to minimum Q
if self._action_spec.continuous_size > 0:
v_backup += torch.sum(
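
The added term gives the value target a bonus of STRENGTH times the diversity network's reward on top of the usual entropy bonus, in the branch that handles discrete actions (the added print suggests this case is still being worked out). A standalone sketch of a target with that shape, using made-up tensors:

    import torch

    # Sketch only, not the ML-Agents code: a SAC-style value target that adds a
    # scaled diversity reward next to the entropy bonus, as the added
    # `+ self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)` term does.
    def value_backup(min_q: torch.Tensor,
                     entropy_bonus: torch.Tensor,
                     diversity_reward: torch.Tensor,
                     strength: float = 1.0) -> torch.Tensor:
        with torch.no_grad():
            return min_q - entropy_bonus + strength * diversity_reward

    backup = value_backup(torch.tensor([1.0]), torch.tensor([0.1]), torch.tensor([0.3]))
    print(backup)  # tensor([1.2000])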

cont_log_probs = log_probs.continuous_tensor
batch_policy_loss += torch.mean(
_cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
-) - self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)
+)
+batch_policy_loss += - self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
return policy_loss
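
The MEDE term is moved out of the torch.mean over action dimensions and added to batch_policy_loss on its own line, so the policy loss now subtracts STRENGTH times the diversity reward per sample before the masked mean. A self-contained sketch of a loss with that structure, using placeholder tensors:

    import torch

    def masked_mean(values: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        masks = masks.float()
        return (values * masks).sum() / masks.sum().clamp(min=1.0)

    # Sketch: SAC entropy term minus a scaled diversity reward, then a masked
    # mean over the batch. STRENGTH and all tensors are placeholders.
    def policy_loss(cont_log_probs, mean_q1, ent_coef, diversity_reward,
                    loss_masks, strength=1.0):
        batch_loss = torch.mean(ent_coef * cont_log_probs - mean_q1.unsqueeze(1), dim=1)
        batch_loss = batch_loss - strength * diversity_reward
        return masked_mean(batch_loss, loss_masks)

    loss = policy_loss(torch.zeros(4, 3), torch.ones(4), torch.tensor(0.2),
                       torch.full((4,), 0.5), torch.ones(4))
    print(loss)  # (-1.0) - 0.5 = -1.5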

with torch.no_grad():
cont_log_probs = log_probs.continuous_tensor
target_current_diff = torch.sum(
-cont_log_probs + self.target_entropy.continuous, dim=1
-)
+cont_log_probs, dim=1) + self.target_entropy.continuous
+# print(self.target_entropy.continuous, cont_log_probs, torch.sum(
+# cont_log_probs, dim=1) + self.target_entropy.continuous)
# We update all the _cont_ent_coef as one block
entropy_loss += -1 * ModelUtils.masked_mean(
_cont_ent_coef * target_current_diff, loss_masks
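
The removed lines summed cont_log_probs + self.target_entropy.continuous over the action dimension, which adds the target once per action dimension if it broadcasts; the new line sums the log-probs first and then adds the target a single time. A numeric illustration with made-up values:

    import torch

    cont_log_probs = torch.tensor([[-1.0, -2.0, -3.0]])  # one sample, 3 action dims
    target_entropy = -3.0                                 # assumed scalar target

    old = torch.sum(cont_log_probs + target_entropy, dim=1)   # target added per dim
    new = torch.sum(cont_log_probs, dim=1) + target_entropy   # target added once

    print(old)  # tensor([-15.])  ->  (-1-3) + (-2-3) + (-3-3)
    print(new)  # tensor([-9.])   ->  (-6) + (-3)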

total_value_loss.backward()
self.value_optimizer.step()
ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
self.entropy_optimizer.zero_grad()
entropy_loss.backward()

torch.exp(self._log_ent_coef.continuous)
).item(),
"Policy/Learning Rate": decay_lr,
"Policy/Entropy Loss": entropy_loss.item(),
"Policy/MEDE Loss": mede_loss.item(),
}

"Optimizer:value_optimizer": self.value_optimizer,
"Optimizer:entropy_optimizer": self.entropy_optimizer,
"Optimizer:mede_optimizer": self._mede_optimizer,
"Optimizer:mede_network": self._mede_network,
}
for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
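
Registering Optimizer:mede_optimizer and Optimizer:mede_network alongside the other named modules means they travel with whatever consumes this dictionary, for example checkpoint saving. A generic sketch of saving and restoring such a named-module dictionary (not the ML-Agents checkpoint code):

    import torch

    def save_modules(modules: dict, path: str) -> None:
        # Both nn.Module and torch.optim.Optimizer expose state_dict().
        torch.save({name: m.state_dict() for name, m in modules.items()}, path)

    def load_modules(modules: dict, path: str) -> None:
        state = torch.load(path)
        for name, m in modules.items():
            if name in state:
                m.load_state_dict(state[name])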