
/exp-tanh
vincentpierre, 4 years ago
Current commit
52b011d6
8 files changed, 122 insertions and 63 deletions
  1. ml-agents/mlagents/trainers/policy/policy.py (2 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)
  3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (105 changes)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (1 change)
  5. ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (4 changes)
  6. ml-agents/mlagents/trainers/torch/action_model.py (8 changes)
  7. ml-agents/mlagents/trainers/torch/distributions.py (40 changes)
  8. ml-agents/mlagents/trainers/torch/networks.py (23 changes)

ml-agents/mlagents/trainers/policy/policy.py (2 changes)


         seed: int,
         behavior_spec: BehaviorSpec,
         trainer_settings: TrainerSettings,
-        tanh_squash: bool = False,
+        tanh_squash: bool = True,
         reparameterize: bool = False,
         condition_sigma_on_obs: bool = True,
     ):

ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)


         seed: int,
         behavior_spec: BehaviorSpec,
         trainer_settings: TrainerSettings,
-        tanh_squash: bool = False,
+        tanh_squash: bool = True,
         reparameterize: bool = False,
         separate_critic: bool = True,
         condition_sigma_on_obs: bool = True,
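
Both policy constructors now default tanh_squash to True, i.e. continuous actions are sampled from a Gaussian and squashed through tanh so they land in (-1, 1) without clipping. For reference, a minimal standalone sketch of that kind of distribution using torch.distributions (illustrative only, not code from this commit):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

mu, sigma = torch.zeros(3), torch.ones(3)
# Gaussian base distribution squashed through tanh, as tanh_squash=True implies.
squashed = TransformedDistribution(Normal(mu, sigma), [TanhTransform(cache_size=1)])

action = squashed.rsample()           # lies in (-1, 1), no clipping needed
log_prob = squashed.log_prob(action)  # includes the tanh Jacobian correction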

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (105 changes)


         :param num_sequences: Number of sequences to process.
         :return: Results of update.
         """
-        # Get decayed parameters
-        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
-        decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
-        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
-        returns = {}
-        old_values = {}
-        for name in self.reward_signals:
-            old_values[name] = ModelUtils.list_to_tensor(
-                batch[f"{name}_value_estimates"]
-            )
-            returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"])
+        with torch.autograd.detect_anomaly():
+            # Get decayed parameters
+            decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
+            decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
+            decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
+            returns = {}
+            old_values = {}
+            for name in self.reward_signals:
+                old_values[name] = ModelUtils.list_to_tensor(
+                    batch[f"{name}_value_estimates"]
+                )
+                returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"])

-        n_obs = len(self.policy.behavior_spec.observation_specs)
-        current_obs = ObsUtil.from_buffer(batch, n_obs)
-        # Convert to tensors
-        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
+            n_obs = len(self.policy.behavior_spec.observation_specs)
+            current_obs = ObsUtil.from_buffer(batch, n_obs)
+            # Convert to tensors
+            current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

-        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
-        actions = AgentAction.from_dict(batch)
+            act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
+            actions = AgentAction.from_dict(batch)

-        memories = [
-            ModelUtils.list_to_tensor(batch["memory"][i])
-            for i in range(0, len(batch["memory"]), self.policy.sequence_length)
-        ]
-        if len(memories) > 0:
-            memories = torch.stack(memories).unsqueeze(0)
+            memories = [
+                ModelUtils.list_to_tensor(batch["memory"][i])
+                for i in range(0, len(batch["memory"]), self.policy.sequence_length)
+            ]
+            if len(memories) > 0:
+                memories = torch.stack(memories).unsqueeze(0)

-        log_probs, entropy, values = self.policy.evaluate_actions(
-            current_obs,
-            masks=act_masks,
-            actions=actions,
-            memories=memories,
-            seq_len=self.policy.sequence_length,
-        )
-        old_log_probs = ActionLogProbs.from_dict(batch).flatten()
-        log_probs = log_probs.flatten()
-        loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
-        value_loss = self.ppo_value_loss(
-            values, old_values, returns, decay_eps, loss_masks
-        )
-        policy_loss = self.ppo_policy_loss(
-            ModelUtils.list_to_tensor(batch["advantages"]),
-            log_probs,
-            old_log_probs,
-            loss_masks,
-        )
-        loss = (
-            policy_loss
-            + 0.5 * value_loss
-            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
-        )
+            log_probs, entropy, values = self.policy.evaluate_actions(
+                current_obs,
+                masks=act_masks,
+                actions=actions,
+                memories=memories,
+                seq_len=self.policy.sequence_length,
+            )
+            old_log_probs = ActionLogProbs.from_dict(batch).flatten()
+            log_probs = log_probs.flatten()
+            loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
+            value_loss = self.ppo_value_loss(
+                values, old_values, returns, decay_eps, loss_masks
+            )
+            print(log_probs)
+            policy_loss = self.ppo_policy_loss(
+                ModelUtils.list_to_tensor(batch["advantages"]),
+                log_probs,
+                old_log_probs,
+                loss_masks,
+            )
+            loss = (
+                policy_loss
+                + 0.5 * value_loss
+                - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
+            )

-        # Set optimizer learning rate
-        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
-        self.optimizer.zero_grad()
-        loss.backward()
+            # Set optimizer learning rate
+            ModelUtils.update_learning_rate(self.optimizer, decay_lr)
+            self.optimizer.zero_grad()
+            with torch.autograd.detect_anomaly():
+                loss.backward()
         self.optimizer.step()
         update_stats = {
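
The new torch.autograd.detect_anomaly() context is a debugging aid: it makes autograd keep the forward-pass traceback for every operation and raise a RuntimeError naming the backward function that produced NaN, at the cost of much slower updates. A minimal standalone sketch of the pattern (illustrative only; the model and loss below are made up):

import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()

# Inside this context, a NaN produced during backward raises an error that
# points at the forward operation responsible, instead of silently propagating.
with torch.autograd.detect_anomaly():
    loss.backward()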

ml-agents/mlagents/trainers/ppo/trainer.py (1 change)


             behavior_spec,
             self.trainer_settings,
             condition_sigma_on_obs=False,  # Faster training for PPO
+            tanh_squash=True,
             separate_critic=True,  # Match network architecture with TF
         )
         return policy

ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (4 changes)


     env = SimpleEnvironment([BRAIN_NAME], action_sizes=action_size, step_size=0.8)
     new_network_settings = attr.evolve(PPO_TORCH_CONFIG.network_settings)
     new_hyperparams = attr.evolve(
-        PPO_TORCH_CONFIG.hyperparameters, batch_size=64, buffer_size=1024
+        PPO_TORCH_CONFIG.hyperparameters, batch_size=64, buffer_size=1024, learning_rate=1e-3
     )
     config = attr.evolve(
         PPO_TORCH_CONFIG,

     )
     check_environment_trains(
-        env, {BRAIN_NAME: config}, success_threshold=0.9, training_seed=1212
+        env, {BRAIN_NAME: config}, success_threshold=0.9
     )
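
The test tweaks go through attr.evolve, which returns a copy of an attrs instance with the given fields replaced, so the shared PPO_TORCH_CONFIG is never mutated. A small sketch with a made-up settings class (illustrative only, not the real ml-agents settings types):

import attr

@attr.s(auto_attribs=True)
class Hyperparameters:  # hypothetical stand-in for the real hyperparameter settings
    batch_size: int = 32
    buffer_size: int = 10240
    learning_rate: float = 3.0e-4

base = Hyperparameters()
tuned = attr.evolve(base, batch_size=64, buffer_size=1024, learning_rate=1e-3)
assert base.batch_size == 32 and tuned.batch_size == 64  # original left untouched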

ml-agents/mlagents/trainers/torch/action_model.py (8 changes)


         discrete_dist: Optional[List[DiscreteDistInstance]] = None
         # This checks None because mypy complains otherwise
         if self._continuous_distribution is not None:
+            if (torch.isnan(torch.mean(inputs))):
+                print("_get_dist inputs in action_model")
             continuous_dist = self._continuous_distribution(inputs)
         if self._discrete_distribution is not None:
             discrete_dist = self._discrete_distribution(inputs, masks)

         :params actions: The AgentAction
         :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
         """
+        if (torch.isnan(torch.mean(inputs))):
+            print("evaluate inputs in action_model")
         dists = self._get_dists(inputs, masks)
         log_probs, entropies = self._get_probs_and_entropy(actions, dists)
         # Use the sum of entropy across actions, not the mean

         :params masks: Action masks for discrete actions
         :return: A tuple of torch tensors corresponding to the inference output
         """
+        if (torch.isnan(torch.mean(inputs))):
+            print("get_action_out inputs in action_model")
         dists = self._get_dists(inputs, masks)
         continuous_out, discrete_out, action_out_deprecated = None, None, None
         if self.action_spec.continuous_size > 0 and dists.continuous is not None:

         :return: Given the input, an AgentAction of the actions generated by the policy and the corresponding
         ActionLogProbs and entropies.
         """
+        if (torch.isnan(torch.mean(inputs))):
+            print("forward inputs in action_model")
         dists = self._get_dists(inputs, masks)
         actions = self._sample_action(dists)
         log_probs, entropies = self._get_probs_and_entropy(actions, dists)
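
All four insertions above are the same NaN guard on the inputs tensor. One possible consolidation (a sketch only, not part of the branch; warn_if_nan is a hypothetical helper) factors the check out; note that torch.isnan(x).any() inspects every element, while checking the mean also works because a single NaN makes the mean NaN:

import torch

def warn_if_nan(tensor: torch.Tensor, where: str) -> None:
    # Direct elementwise check; equivalent in effect to testing the mean for NaN.
    if torch.isnan(tensor).any():
        print(f"NaN detected in {where}")

warn_if_nan(torch.tensor([1.0, float("nan")]), "action_model inputs")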

ml-agents/mlagents/trainers/torch/distributions.py (40 changes)


     def entropy(self):
         return torch.mean(
-            0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON),
+            0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON),
             dim=1,
             keepdim=True,
         )  # Use equivalent behavior to TF
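
The entropy fix replaces std with std ** 2, matching the closed-form entropy of a Gaussian, H = 0.5 * ln(2 * pi * e * sigma^2). A quick standalone check against torch.distributions (illustrative only, not code from this branch):

import math
import torch

std = torch.tensor([0.5, 1.0, 2.0])
fixed = 0.5 * torch.log(2 * math.pi * math.e * std ** 2)  # corrected formula, without the EPSILON guard
reference = torch.distributions.Normal(torch.zeros(3), std).entropy()
assert torch.allclose(fixed, reference)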

     def __init__(self, mean, std):
         super().__init__(mean, std)
         self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)
+        if torch.isnan(torch.mean(std)):
+            print("Nan in TanhGaussianDistInstance init, std")

     def sample(self):
         unsquashed_sample = super().sample()

     def _inverse_tanh(self, value):
-        capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
+        # capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
+        capped_value = (1 - EPSILON) * value

     def log_prob(self, value):
-        unsquashed = self.transform.inv(value)
-        return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
-            unsquashed, value
-        )
+        # capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
+        capped_value = 0.1 * value
+        unsquashed = self.transform.inv(capped_value)
+        # capped_unsqached = self.transform.inv(capped_value)
+        tmp = super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
+            unsquashed, None
+        )
+        print("tmp decomposition", value, capped_value, unsquashed, super().log_prob(unsquashed), self.transform.log_abs_det_jacobian(
+            unsquashed, None
+        ))
+        if torch.isnan(torch.mean(value)):
+            print("Nan in log_prob(self, value), value")
+        if torch.isnan(torch.mean(super().log_prob(unsquashed))):
+            print("Nan in log_prob(self, value), super().log_prob(unsquashed)")
+        if torch.isnan(torch.mean(self.transform.log_abs_det_jacobian(unsquashed, None))):
+            print("Nan in log_prob(self, value), log_abs_det_jacobian")
+        return tmp

 class CategoricalDistInstance(DiscreteDistInstance):

     def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
         mu = self.mu(inputs)
         if self.conditional_sigma:
+            if torch.isnan(torch.mean(inputs)):
+                print("GaussianDistribution conditional log sigma inputs")
-            # use this to replace torch.expand() becuase it is not supported in
+            # use this to replace torch.expand() because it is not supported in
+            if torch.isnan(torch.mean(self.log_sigma)):
+                print("GaussianDistribution self.log_sigma")
+        if torch.isnan(torch.mean(mu)):
+            print("GaussianDistribution mu")
+        if torch.isnan(torch.mean(inputs)):
+            print("GaussianDistribution inputs")
+        if torch.isnan(torch.mean(log_sigma)):
+            print("GaussianDistribution log sigma NaN")
         if self.tanh_squash:
             return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
         else:
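
The log_prob experiments above all target the change-of-variables term of a tanh-squashed Gaussian: for a = tanh(u) with u ~ N(mu, sigma), log p(a) = log N(u; mu, sigma) - log|d tanh(u)/du|, and the Jacobian term diverges as |a| approaches 1, which is where the NaNs come from. A standalone sketch of the same correction using torch.distributions (illustrative only, not code from this branch):

import torch
from torch.distributions import Normal
from torch.distributions.transforms import TanhTransform

base = Normal(torch.zeros(1), torch.ones(1))
tanh = TanhTransform(cache_size=1)

u = base.rsample()   # pre-squash sample
a = tanh(u)          # squashed action in (-1, 1)

# log p(a) = log p(u) - log|da/du|; the second term tends to -inf as |a| -> 1,
# which is why the branch experiments with clamping / rescaling before inverting.
log_prob = base.log_prob(u) - tanh.log_abs_det_jacobian(u, a)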

ml-agents/mlagents/trainers/torch/networks.py (23 changes)


     def forward(
         self,
-        inputs: List[torch.Tensor],
+        inputs_: List[torch.Tensor],
         actions: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,

-            obs_input = inputs[idx]
+            obs_input = inputs_[idx]
             processed_obs = processor(obs_input)
             encodes.append(processed_obs)

             encoding = encoding.reshape([-1, sequence_length, self.h_size])
             encoding, memories = self.lstm(encoding, memories)
             encoding = encoding.reshape([-1, self.m_size // 2])
+        if torch.isnan(torch.mean(encoding)):
+            print("NaN in Netowrk Body :")
+            print(torch.mean(inputs_[0]), torch.mean(self.processors[0](inputs_[0])))
+            print(self.processors[0].conv_layers[0].weight.data)
+            print(self.processors[0].conv_layers[2].weight.data)
+            print("\n\n\n\n\n")
+            raise _
         return encoding, memories

         encoding, memories = self.network_body(
             inputs, memories=memories, sequence_length=sequence_length
         )
+        if torch.isnan(torch.mean(encoding)):
+            print("SimpleActor encoding in get_action_stats")
         action, log_probs, entropies = self.action_model(encoding, masks)
         return action, log_probs, entropies, memories

         encoding, memories = self.network_body(
             inputs, memories=memories, sequence_length=sequence_length
         )
+        if torch.isnan(torch.mean(encoding)):
+            print("SharedActorCritic, get_stats_and_value, encoding")
         log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
         value_outputs = self.value_heads(encoding)
         return log_probs, entropies, value_outputs

         sequence_length: int = 1,
     ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
         actor_mem, critic_mem = self._get_actor_critic_mem(memories)
+        for i in inputs:
+            if torch.isnan(torch.mean(i)):
+                print("Nan input to network body in SeparateActorCritic")
+        if torch.isnan(torch.mean(encoding)):
+            print("SeparateActorCritic, get_stats_and_value, encoding")
         log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
         value_outputs, critic_mem_outs = self.critic(
             inputs, memories=critic_mem, sequence_length=sequence_length
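
An alternative to sprinkling isnan checks through the network classes would be a forward hook installed on every submodule; the sketch below is only an illustration of that idea (install_nan_hooks is hypothetical, not part of ml-agents):

import torch

def install_nan_hooks(model: torch.nn.Module) -> None:
    def check(module, inputs, output):
        outputs = output if isinstance(output, tuple) else (output,)
        for t in outputs:
            if isinstance(t, torch.Tensor) and torch.isnan(t).any():
                print(f"NaN leaving {module.__class__.__name__}")
    # Register the same check on every submodule of the model.
    for submodule in model.modules():
        submodule.register_forward_hook(check)

install_nan_hooks(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Tanh()))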
