浏览代码

First version of MEDE, crawler does not seem to work properly, I suspect the actions make it distinguishable to the discriminator but not to the human eye

/exp-continuous-div
vincentpierre 3 年前
当前提交
bab3ecb7
共有 2 个文件被更改,包括 124 次插入6 次删除
  1. 3
      config/sac/Walker.yaml
  2. 127
      ml-agents/mlagents/trainers/sac/optimizer_torch.py

3
config/sac/Walker.yaml


hidden_units: 256
num_layers: 3
vis_encode_type: simple
goal_conditioning_type: none
strength: 1.0
strength: 0.1
keep_checkpoints: 5
max_steps: 15000000
time_horizon: 1000

127
ml-agents/mlagents/trainers/sac/optimizer_torch.py


logger = get_logger(__name__)
from mlagents.trainers.torch.action_flattener import ActionFlattener
from mlagents_envs.base_env import ObservationType
from mlagents.trainers.torch.networks import NetworkBody
from mlagents_envs.base_env import BehaviorSpec
class DiverseNetwork(torch.nn.Module):
EPSILON = 1e-10
STRENGTH = 0.1
def __init__(self, specs: BehaviorSpec, settings) -> None:
super().__init__()
self._use_actions = True
state_encoder_settings = settings
if state_encoder_settings.memory is not None:
state_encoder_settings.memory = None
logger.warning(
"memory was specified in network_settings but is not supported. It is being ignored."
)
self._action_flattener = ActionFlattener(specs.action_spec)
new_spec = [
spec
for spec in specs.observation_specs
if spec.observation_type != ObservationType.GOAL_SIGNAL
]
diverse_spec = [
spec
for spec in specs.observation_specs
if spec.observation_type == ObservationType.GOAL_SIGNAL
][0]
print(" > ", new_spec, "\n\n\n", " >> ", diverse_spec)
self._all_obs_specs = specs.observation_specs
self.diverse_size = diverse_spec.shape[0]
if self._use_actions:
self._encoder = NetworkBody(
new_spec, state_encoder_settings, self._action_flattener.flattened_size
)
else:
self._encoder = NetworkBody(new_spec, state_encoder_settings)
self._last_layer = torch.nn.Linear(
state_encoder_settings.hidden_units, self.diverse_size
)
self._diverse_index = -1
self._max_index = len(specs.observation_specs)
for i, spec in enumerate(specs.observation_specs):
if spec.observation_type == ObservationType.GOAL_SIGNAL:
self._diverse_index = i
def predict(self, obs_input, action_input, detach_action=False) -> torch.Tensor:
# Convert to tensors
tensor_obs = [
obs
for obs, spec in zip(obs_input, self._all_obs_specs)
if spec.observation_type != ObservationType.GOAL_SIGNAL
]
if self._use_actions:
action = self._action_flattener.forward(action_input)
if detach_action:
action = action.detach()
hidden, _ = self._encoder.forward(tensor_obs, action)
else:
hidden, _ = self._encoder.forward(tensor_obs)
# add a VAE (like in VAIL ?)
prediction = torch.softmax(self._last_layer(hidden), dim=1)
return prediction
def copy_normalization(self, thing):
self._encoder.processors[0].copy_normalization(thing.processors[1])
def rewards(self, obs_input, action_input, detach_action=False) -> torch.Tensor:
truth = obs_input[self._diverse_index]
prediction = self.predict(obs_input, action_input, detach_action)
rewards = torch.log(torch.sum((prediction * truth), dim=1) + self.EPSILON)
return rewards
def loss(self, obs_input, action_input, masks, detach_action=True) -> torch.Tensor:
return -ModelUtils.masked_mean(
self.rewards(obs_input, action_input, detach_action), masks
)
class TorchSACOptimizer(TorchOptimizer):
class PolicyValueNetwork(nn.Module):

self._critic.parameters()
)
self._mede_network = DiverseNetwork(
self.policy.behavior_spec, self.policy.network_settings
)
self._mede_optimizer = torch.optim.Adam(
list(self._mede_network.parameters()), lr=hyperparameters.learning_rate
)
logger.debug("value_vars")
for param in value_params:
logger.debug(param.shape)

q1p_out: Dict[str, torch.Tensor],
q2p_out: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
obs,
act,
) -> torch.Tensor:
min_policy_qs = {}
with torch.no_grad():

if self._action_spec.discrete_size <= 0:
for name in values.keys():
with torch.no_grad():
v_backup = min_policy_qs[name] - torch.sum(
_cont_ent_coef * log_probs.continuous_tensor, dim=1
v_backup = (
min_policy_qs[name]
- torch.sum(_cont_ent_coef * log_probs.continuous_tensor, dim=1)
+ self._mede_network.STRENGTH
* self._mede_network.rewards(obs, act)
)
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks

log_probs: ActionLogProbs,
q1p_outs: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
obs,
act,
) -> torch.Tensor:
_cont_ent_coef, _disc_ent_coef = (
self._log_ent_coef.continuous,

cont_log_probs = log_probs.continuous_tensor
batch_policy_loss += torch.mean(
_cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
)
) - self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
return policy_loss

self.target_network.network_body.copy_normalization(
self.policy.actor.network_body
)
self._mede_network.copy_normalization(self.policy.actor.network_body)
self._critic.network_body.copy_normalization(self.policy.actor.network_body)
sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats(
current_obs,

q1_stream, q2_stream, target_values, dones, rewards, masks
)
value_loss = self.sac_value_loss(
log_probs, value_estimates, q1p_out, q2p_out, masks
log_probs,
value_estimates,
q1p_out,
q2p_out,
masks,
current_obs,
sampled_actions,
)
policy_loss = self.sac_policy_loss(
log_probs, q1p_out, masks, current_obs, sampled_actions
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
entropy_loss = self.sac_entropy_loss(log_probs, masks)
total_value_loss = q1_loss + q2_loss

entropy_loss.backward()
self.entropy_optimizer.step()
mede_loss = self._mede_network.loss(current_obs, sampled_actions, masks)
ModelUtils.update_learning_rate(self._mede_optimizer, decay_lr)
self._mede_optimizer.zero_grad()
mede_loss.backward()
self._mede_optimizer.step()
# Update target network
ModelUtils.soft_update(self._critic, self.target_network, self.tau)
update_stats = {

torch.exp(self._log_ent_coef.continuous)
).item(),
"Policy/Learning Rate": decay_lr,
"Policy/MEDE Loss": mede_loss.item(),
}
return update_stats

"Optimizer:policy_optimizer": self.policy_optimizer,
"Optimizer:value_optimizer": self.value_optimizer,
"Optimizer:entropy_optimizer": self.entropy_optimizer,
"Optimizer:mede_optimizer": self._mede_optimizer,
}
for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
正在加载...
取消
保存