
Fix a couple additional bugs

Branch: /develop/add-fire
Arthur Juliani · 5 years ago
Current commit: 8c6f4696
8 changed files with 32 additions and 27 deletions
1. ml-agents/mlagents/trainers/distributions_torch.py (13 changes)
2. ml-agents/mlagents/trainers/models.py (4 changes)
3. ml-agents/mlagents/trainers/models_torch.py (6 changes)
4. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (11 changes)
5. ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)
6. ml-agents/mlagents/trainers/ppo/optimizer_tf.py (2 changes)
7. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (19 changes)
8. ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)

ml-agents/mlagents/trainers/distributions_torch.py (13 changes)


  def __init__(self, hidden_size, num_outputs, **kwargs):
      super(GaussianDistribution, self).__init__(**kwargs)
      self.mu = nn.Linear(hidden_size, num_outputs)
-     self.log_sigma_sq = nn.Linear(hidden_size, num_outputs)
-     nn.init.xavier_uniform(self.mu.weight, gain=0.01)
-     nn.init.xavier_uniform(self.log_sigma_sq.weight, gain=0.01)
+     # self.log_sigma_sq = nn.Linear(hidden_size, num_outputs)
+     self.log_sigma = nn.Parameter(torch.zeros(1, num_outputs, requires_grad=True))
+     nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
+     # nn.init.xavier_uniform(self.log_sigma_sq.weight, gain=0.01)

-     log_sig = self.log_sigma_sq(inputs)
-     return [
-         distributions.normal.Normal(loc=mu, scale=torch.sqrt(torch.exp(log_sig)))
-     ]
+     # log_sig = torch.tanh(self.log_sigma_sq(inputs)) * 3.0
+     return [distributions.normal.Normal(loc=mu, scale=torch.exp(self.log_sigma))]

  class MultiCategoricalDistribution(nn.Module):
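Note: the change above swaps a state-dependent log_sigma_sq head for a single learnable log_sigma parameter, so the policy's standard deviation is shared across all states. A minimal, self-contained sketch of that pattern (the class name and sizes are illustrative, not the repository's code):

    import torch
    from torch import nn, distributions

    class StateIndependentGaussian(nn.Module):
        """Gaussian head whose std is a learned parameter, not a function of the input."""
        def __init__(self, hidden_size, num_outputs):
            super().__init__()
            self.mu = nn.Linear(hidden_size, num_outputs)
            nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
            # One log-std per action dimension, independent of the observation.
            self.log_sigma = nn.Parameter(torch.zeros(1, num_outputs))

        def forward(self, hidden):
            return distributions.Normal(loc=self.mu(hidden), scale=torch.exp(self.log_sigma))

    dist = StateIndependentGaussian(hidden_size=8, num_outputs=2)(torch.randn(4, 8))
    action = dist.sample()            # shape [4, 2]
    log_prob = dist.log_prob(action)  # shape [4, 2], usable as log_probs downstream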

ml-agents/mlagents/trainers/models.py (4 changes)


  :param action_masks: The mask for the logits. Must be of dimension [None x total_number_of_action]
  :param action_size: A list containing the number of possible actions for each branch
  :return: The action output dimension [batch_size, num_branches], the concatenated
-     normalized probs (after softmax)
-     and the concatenated normalized log probs
+     normalized log_probs (after softmax)
+     and the concatenated normalized log log_probs
  """
  branch_masks = ModelUtils.break_into_branches(action_masks, action_size)
  raw_probs = [
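The docstring above concerns masked, branched discrete actions. As a rough illustration of what splitting concatenated logits into branches and masking them before the softmax looks like (the helper below is a stand-in, not ModelUtils itself; shapes and mask values are made up):

    import torch

    def break_into_branches(concatenated, action_size):
        # action_size, e.g. [3, 2]: first branch has 3 actions, second has 2.
        offsets = [0]
        for size in action_size:
            offsets.append(offsets[-1] + size)
        return [concatenated[:, offsets[i]:offsets[i + 1]] for i in range(len(action_size))]

    logits = torch.randn(4, 5)   # batch of 4, two branches of size 3 + 2
    masks = torch.ones(4, 5)     # 1 = action allowed, 0 = masked out
    branch_logits = break_into_branches(logits, [3, 2])
    branch_masks = break_into_branches(masks, [3, 2])
    probs = [
        torch.softmax(lg.masked_fill(mask == 0, float("-inf")), dim=-1)
        for lg, mask in zip(branch_logits, branch_masks)
    ]
    log_probs = [torch.log(p + 1e-10) for p in probs]  # normalized log-probs per branch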

ml-agents/mlagents/trainers/models_torch.py (6 changes)


      self.running_variance = torch.ones(vec_obs_size)

  def forward(self, inputs):
-     inputs = torch.from_numpy(inputs)
      normalized_state = torch.clamp(
          (inputs - self.running_mean)
          / torch.sqrt(

      return normalized_state

  def update(self, vector_input):
-     vector_input = torch.from_numpy(vector_input)
      mean_current_observation = vector_input.mean(0).type(torch.float32)
      new_mean = self.running_mean + (
          mean_current_observation - self.running_mean

  def forward(self, hidden):
      value_outputs = {}
      for stream_name, _ in self.value_heads.items():
-         value_outputs[stream_name] = self.value_heads[stream_name](hidden)
+         value_outputs[stream_name] = self.value_heads[stream_name](hidden).squeeze(
+             -1
+         )
      return (
          value_outputs,
          torch.mean(torch.stack(list(value_outputs.values())), dim=0),
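The .squeeze(-1) added to the value heads matters because a head built from nn.Linear(hidden, 1) emits estimates shaped [batch, 1]; combining that with [batch]-shaped returns broadcasts silently instead of failing. A small sketch of the pitfall (shapes are illustrative):

    import torch

    returns = torch.arange(4.0)         # shape [4]
    head_out = torch.ones(4, 1)         # shape [4, 1], as an nn.Linear(hidden, 1) would emit
    bad = (returns - head_out) ** 2     # shape [4, 4]: unintended broadcast, no error raised
    good = (returns - head_out.squeeze(-1)) ** 2  # shape [4]: what the value loss expects
    print(bad.shape, good.shape)        # torch.Size([4, 4]) torch.Size([4])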

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (11 changes)


  next_value_estimate, next_value = self.policy.critic(next_obs, next_obs)
  for name, estimate in value_estimates.items():
-     value_estimates[name] = estimate.squeeze(-1).detach().numpy()
-     next_value_estimate[name] = (
-         next_value_estimate[name].squeeze(-1).detach().numpy()
-     )
+     value_estimates[name] = estimate.detach().numpy()
+     next_value_estimate[name] = next_value_estimate[name].detach().numpy()
  if done:
      for k in next_value_estimate:
          if self.reward_signals[k].use_terminal_states:
              next_value_estimate[k] = 0.0
  return value_estimates, next_value_estimate
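Zeroing next_value_estimate for terminal states keeps the critic from bootstrapping past the end of an episode. A tiny worked example (gamma and the numbers are arbitrary):

    gamma = 0.99
    last_reward = 1.0
    next_value_estimate = 0.7   # critic's guess for the state after the final step
    done = True

    bootstrap = 0.0 if done else next_value_estimate
    td_target = last_reward + gamma * bootstrap   # 1.0 when done, 1.693 if the episode continued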

ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)


  If this policy normalizes vector observations, this will update the norm values in the graph.
  :param vector_obs: The vector observations to add to the running estimate of the distribution.
  """
- vector_obs = np.array(vector_obs)
+ vector_obs = torch.Tensor(vector_obs)
  vector_obs = [vector_obs]
  if self.use_vec_obs and self.normalize:
      self.critic.network_body.update_normalization(vector_obs)
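The policy now hands the normalizer a torch.Tensor instead of a numpy array, which is why the torch.from_numpy calls in models_torch.py above were dropped. A rough sketch of that path (the clip range and running statistics below are placeholders, not the file's exact update rule):

    import torch

    running_mean = torch.zeros(3)
    running_variance = torch.ones(3)

    vector_obs = [[0.1, -2.0, 3.5], [0.0, 1.0, -1.0]]  # raw observations from the environment
    obs = torch.Tensor(vector_obs)                      # shape [2, 3], no numpy round-trip
    normalized = torch.clamp((obs - running_mean) / torch.sqrt(running_variance), -5, 5)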

ml-agents/mlagents/trainers/ppo/optimizer_tf.py (2 changes)


name="old_probabilities",
)
# Break old log probs into separate branches
# Break old log log_probs into separate branches
old_log_prob_branches = ModelUtils.break_into_branches(
self.all_old_log_probs, self.policy.act_size
)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (19 changes)


  value_losses = []
  for name, head in values.items():
-     old_val_tensor = torch.DoubleTensor(old_values[name])
-     returns_tensor = torch.DoubleTensor(returns[name])
+     old_val_tensor = torch.Tensor(old_values[name])
+     returns_tensor = torch.Tensor(returns[name])
-         torch.sum(head, dim=1) - old_val_tensor, -decay_epsilon, decay_epsilon
+         head - old_val_tensor, -decay_epsilon, decay_epsilon
-     v_opt_a = (returns_tensor - torch.sum(head, dim=1)) ** 2
+     v_opt_a = (returns_tensor - head) ** 2
      v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
      value_loss = torch.mean(torch.max(v_opt_a, v_opt_b))
      value_losses.append(value_loss)

- def ppo_policy_loss(self, advantages, probs, old_probs, masks):
+ def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks):
-     :param probs: Current policy probabilities
-     :param old_probs: Past policy probabilities
+     :param log_probs: Current policy probabilities
+     :param old_log_probs: Past policy probabilities
-     advantage = torch.from_numpy(np.expand_dims(advantages, -1))
+     advantage = torch.Tensor(advantages).unsqueeze(-1)
+     old_log_probs = torch.Tensor(old_log_probs)
-     r_theta = torch.exp(probs - torch.DoubleTensor(old_probs))
+     r_theta = torch.exp(log_probs - old_log_probs)
      p_opt_a = r_theta * advantage
      p_opt_b = (
          torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
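For reference, the two clipped losses being edited above in a stripped-down, self-contained form. The surrogate now works directly in log-prob space, so the probability ratio is exp(log_probs - old_log_probs). The epsilon value, tensor shapes, and the absence of masking are simplifications; treat this as a sketch, not the optimizer's real code:

    import torch

    def clipped_value_loss(head, old_values, returns, epsilon=0.2):
        # head, old_values, returns: all shaped [batch] once the heads are squeezed
        clipped = old_values + torch.clamp(head - old_values, -epsilon, epsilon)
        return torch.mean(torch.max((returns - head) ** 2, (returns - clipped) ** 2))

    def clipped_policy_loss(advantages, log_probs, old_log_probs, epsilon=0.2):
        advantage = advantages.unsqueeze(-1)
        r_theta = torch.exp(log_probs - old_log_probs)   # probability ratio from log-probs
        p_opt_a = r_theta * advantage
        p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
        return -torch.mean(torch.min(p_opt_a, p_opt_b))

    batch, act = 8, 2
    v_loss = clipped_value_loss(torch.randn(batch), torch.randn(batch), torch.randn(batch))
    p_loss = clipped_policy_loss(torch.randn(batch), torch.randn(batch, act), torch.randn(batch, act))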

ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)


  local_value_estimates = agent_buffer_trajectory[
      "{}_value_estimates".format(name)
  ].get_batch()
  local_advantage = get_gae(
      rewards=local_rewards,
      value_estimates=local_value_estimates,

  )
  local_return = local_advantage + local_value_estimates
+ # This is later use as target for the different value estimates
  agent_buffer_trajectory["{}_returns".format(name)].set(local_return)
  agent_buffer_trajectory["{}_advantage".format(name)].set(local_advantage)
