
I think it's running

/develop/centralizedcritic/counterfact
Ervin Teng, 4 years ago
Current commit
457b2630
6 changed files with 67 additions and 27 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (21 changed lines)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (4 changed lines)
  3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (31 changed lines)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (2 changed lines)
  5. ml-agents/mlagents/trainers/torch/agent_action.py (1 changed line)
  6. ml-agents/mlagents/trainers/torch/networks.py (35 changed lines)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (21 changed lines)


             critic_obs=critic_obs,
         )
-        # Actions is a hack here, we need the next actions
-        next_value_estimate, next_marg_val_estimate, _ = self.policy.actor_critic.critic_pass(
-            next_obs, actions, next_memory, sequence_length=1, critic_obs=next_critic_obs
-        )
+        # # Actions is a hack here, we need the next actions
+        # next_value_estimate, next_marg_val_estimate, _ = self.policy.actor_critic.critic_pass(
+        #     next_obs, actions, next_memory, sequence_length=1, critic_obs=next_critic_obs
+        # )
+        # These aren't used in COMAttention
+        next_value_estimate, next_marg_val_estimate = {}, {}
-            next_value_estimate[name] = ModelUtils.to_numpy(next_value_estimate[name])
+            next_value_estimate[name] = 0.0
-            next_marg_val_estimate[name] = ModelUtils.to_numpy(next_marg_val_estimate[name])
+            next_marg_val_estimate[name] = 0.0
         if done:
             for k in next_value_estimate:

-        return value_estimates, marg_val_estimates, next_value_estimate, next_marg_val_estimate
+        return (
+            value_estimates,
+            marg_val_estimates,
+            next_value_estimate,
+            next_marg_val_estimate,
+        )
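In the hunk above, the second critic pass is commented out and the per-reward-signal next-step estimates are stubbed to 0.0, since they are not used by the COMAttention path; the return value is also reformatted into a multi-line tuple. A minimal standalone sketch of that stubbing pattern (plain Python, illustrative names, not the ml-agents API):

# Hedged sketch: build zeroed next-step estimates, one entry per reward signal,
# instead of running a second critic pass.
from typing import Dict, List, Tuple


def stub_next_estimates(
    value_estimates: Dict[str, List[float]]
) -> Tuple[Dict[str, float], Dict[str, float]]:
    next_value_estimate = {name: 0.0 for name in value_estimates}
    next_marg_val_estimate = {name: 0.0 for name in value_estimates}
    return next_value_estimate, next_marg_val_estimate


print(stub_next_estimates({"extrinsic": [0.2, 0.4], "curiosity": [0.1, 0.0]}))
# ({'extrinsic': 0.0, 'curiosity': 0.0}, {'extrinsic': 0.0, 'curiosity': 0.0})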

ml-agents/mlagents/trainers/policy/torch_policy.py (4 changed lines)


         seq_len: int = 1,
         critic_obs: Optional[List[List[torch.Tensor]]] = None,
     ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
-        log_probs, entropies, value_heads = self.actor_critic.get_stats_and_value(
+        log_probs, entropies, value_heads, marg_vals = self.actor_critic.get_stats_and_value(
-        return log_probs, entropies, value_heads
+        return log_probs, entropies, value_heads, marg_vals
     @timed
     def evaluate(

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (31 changed lines)


 from typing import Dict, cast
 import itertools
 import numpy as np
 from mlagents.torch_utils import torch
 from mlagents.trainers.buffer import AgentBuffer

         decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
         returns = {}
         old_values = {}
         old_marg_values = {}
             )
             old_marg_values[name] = ModelUtils.list_to_tensor(
                 batch[f"{name}_marginalized_value_estimates"]
             )
-        padded_team_rewards = list(
-            map(
-                lambda x: np.asanyarray(x),
-                itertools.zip_longest(*batch["team_rewards"], fillvalue=np.nan),
-            )
-        )
+        padded_team_rewards = torch.tensor(
+            np.array(
+                list(itertools.zip_longest(*batch["team_rewards"], fillvalue=np.nan))
+            )
+        )
         # Average team rewards
         if "extrinsic" in returns:
             all_rewards = torch.cat(
                 [torch.unsqueeze(returns["extrinsic"], 0), padded_team_rewards], dim=0
             )
             returns["extrinsic"] = torch.mean(
                 all_rewards[~torch.isnan(all_rewards)], dim=0
             )
         n_obs = len(self.policy.behavior_spec.sensor_specs)
         current_obs = ObsUtil.from_buffer(batch, n_obs)
         # Convert to tensors

         value_loss = self.ppo_value_loss(
             values, old_values, returns, decay_eps, loss_masks
         )
         marg_value_loss = self.ppo_value_loss(
             marginalized_vals, old_marg_values, returns, decay_eps, loss_masks
         )
         policy_loss = self.ppo_policy_loss(
             ModelUtils.list_to_tensor(batch["advantages"]),
             log_probs,

         loss = (
             policy_loss
-            + 0.5 * value_loss
+            + 0.5 * (value_loss + marg_value_loss)
             - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
         )
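The padding change above swaps a list of NumPy rows for a single tensor built from the zip_longest output, and the loss change adds the marginalized-value head's regression loss next to the ordinary value loss under the same 0.5 coefficient. A self-contained sketch of the padding-and-NaN-masked-average idea (assumed shapes and toy data, not the ml-agents update() code):

# Hedged sketch: pad ragged per-step teammate rewards with NaN, stack them with
# the agent's own returns, and average per time-step while ignoring the padding.
import itertools
import numpy as np
import torch

team_rewards = [[1.0, 0.5], [0.0], [2.0, 1.0, 1.0]]   # teammate rewards per time-step
own_returns = torch.tensor([1.0, 2.0, 3.0])            # this agent's returns per step

# zip_longest transposes to one row per teammate slot, NaN where a step is short.
padded_team_rewards = torch.tensor(
    np.array(list(itertools.zip_longest(*team_rewards, fillvalue=np.nan))),
    dtype=torch.float32,
)  # shape: (max_teammates, num_steps)

all_rewards = torch.cat([own_returns.unsqueeze(0), padded_team_rewards], dim=0)
valid = ~torch.isnan(all_rewards)
# NaN-aware mean over the agent axis, one value per time-step.
summed = torch.where(valid, all_rewards, torch.zeros_like(all_rewards)).sum(dim=0)
averaged = summed / valid.sum(dim=0)
print(averaged)  # tensor([0.8333, 1.0000, 1.7500])

This mirrors the `~torch.isnan(all_rewards)` filtering in the hunk, which keeps the NaN padding out of the averaged extrinsic return.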

ml-agents/mlagents/trainers/ppo/trainer.py (2 changed lines)


     :param lambd: GAE weighing factor.
     :return: list of advantage estimates for time-steps t to T.
     """
-    advantage = value_estimates - marginalized_value_estimates
+    advantage = np.array(value_estimates) - np.array(marginalized_value_estimates)
     return advantage
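The single functional change in this file wraps both estimate lists in np.array so the subtraction is element-wise, giving a counterfactual-style advantage: the value of the taken joint action minus the marginalized baseline. A short worked example under that reading:

# Hedged sketch: element-wise counterfactual advantage from two per-step lists.
import numpy as np

value_estimates = [1.0, 0.8, 1.2]               # critic estimate per step
marginalized_value_estimates = [0.7, 0.9, 1.0]  # counterfactual baseline per step

# Plain Python lists do not support "-"; converting to arrays makes it element-wise.
advantage = np.array(value_estimates) - np.array(marginalized_value_estimates)
print(advantage)  # approximately [ 0.3 -0.1  0.2]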

ml-agents/mlagents/trainers/torch/agent_action.py (1 changed line)


         discrete_oh = ModelUtils.actions_to_onehot(
             self.discrete_tensor, discrete_branches
         )
         discrete_oh = torch.cat(discrete_oh, dim=1)
         return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
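The fragment above flattens a hybrid action into a single vector: each discrete branch is one-hot encoded, the one-hot blocks are concatenated, and the result is appended to the continuous actions. A self-contained sketch of that flattening (plain torch with an illustrative helper, not the AgentAction class):

# Hedged sketch: continuous values plus one-hot encoded discrete branches,
# concatenated along the last dimension.
from typing import List
import torch
import torch.nn.functional as F


def flatten_actions(
    continuous: torch.Tensor,        # (batch, num_continuous)
    discrete: torch.Tensor,          # (batch, num_branches), integer action indices
    discrete_branches: List[int],    # size of each discrete branch
) -> torch.Tensor:
    onehots = [
        F.one_hot(discrete[:, i], num_classes=size).float()
        for i, size in enumerate(discrete_branches)
    ]
    return torch.cat([continuous] + onehots, dim=-1)


flat = flatten_actions(
    continuous=torch.tensor([[0.1, -0.4]]),
    discrete=torch.tensor([[2, 0]]),
    discrete_branches=[3, 2],
)
print(flat.shape)  # torch.Size([1, 7]): 2 continuous + 3 + 2 one-hot slots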

ml-agents/mlagents/trainers/torch/networks.py (35 changed lines)


         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Tensors that go into ResidualSelfAttention
+        self_attn_inputs = []
+        self_attn_masks = []
         # Get the self encoding separately, but keep it in the entities
         concat_enc_q_obs = []
         for inputs, actions in zip(q_inputs, q_actions):

                 torch.cat(encodes, dim=-1),
                 actions.to_flat(self.action_spec.discrete_branches),
             ]
-            concat_enc_q_obs.append(torch.cat(cat_encodes, dim=-1))
-        q_input_concat = torch.stack(concat_enc_q_obs, dim=1)
+            concat_enc_q_obs.append(torch.cat(cat_encodes, dim=1))
+        if concat_enc_q_obs:
+            q_input_concat = torch.stack(concat_enc_q_obs, dim=1)
+            self_attn_masks.append(self._get_masks_from_nans(q_inputs))
+            encoded_obs_action = self.obs_action_encoder(None, q_input_concat)
+            self_attn_inputs.append(encoded_obs_action)
         # Get the self encoding separately, but keep it in the entities
         concat_encoded_obs = []

             processed_obs = processor(obs_input)
             encodes.append(processed_obs)
             concat_encoded_obs.append(torch.cat(encodes, dim=-1))
-        value_input_concat = torch.stack(concat_encoded_obs, dim=1)
-        # Get the mask from nans
-        value_masks = self._get_masks_from_nans(value_inputs)
-        q_masks = self._get_masks_from_nans(q_inputs)
-        encoded_obs = self.obs_encoder(None, value_input_concat)
-        encoded_obs_action = self.obs_action_encoder(None, q_input_concat)
-        encoded_entity = torch.cat([encoded_obs, encoded_obs_action], dim=1)
-        encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks])
+        if concat_encoded_obs:
+            value_input_concat = torch.stack(concat_encoded_obs, dim=1)
+            # Get the mask from nans
+            self_attn_masks.append(self._get_masks_from_nans(value_inputs))
+            encoded_obs = self.obs_encoder(None, value_input_concat)
+            self_attn_inputs.append(encoded_obs)
+        if len(concat_encoded_obs) == 0:
+            raise Exception("No valid inputs to network.")
+        encoded_entity = torch.cat(self_attn_inputs, dim=1)
+        encoded_state = self.self_attn(encoded_entity, self_attn_masks)
         inputs = encoded_state
         encoding = self.linear_encoder(inputs)

         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
         critic_obs: List[List[torch.Tensor]] = None,
-    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]:
         actor_mem, critic_mem = self._get_actor_critic_mem(memories)
         all_net_inputs = [inputs]
         if critic_obs is not None and critic_obs:
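In networks.py, the encoded observation entities and observation-action entities are now collected into shared self_attn_inputs / self_attn_masks lists, each guarded by an if block so either group may be empty, before a single ResidualSelfAttention call attends over whatever is present; the critic method's return type also gains a second dictionary for the marginalized values. The masks appear to come from NaN padding of absent agents. A simplified sketch of that NaN-to-mask idea (assumed shapes and padding convention, not the ml-agents implementation):

# Hedged sketch: derive an attention padding mask from NaN-padded agent slots
# and zero the padded slots before any encoder or attention layer sees them.
import torch


def masks_from_nans(obs: torch.Tensor) -> torch.Tensor:
    # obs: (batch, num_agents, obs_size); an absent agent's slot is all-NaN.
    # Returns a (batch, num_agents) boolean mask that is True for padded slots.
    return torch.isnan(obs).all(dim=-1)


batch = torch.randn(2, 3, 4)
batch[1, 2] = float("nan")          # second sample only has two real agents
key_padding_mask = masks_from_nans(batch)
print(key_padding_mask)
# tensor([[False, False, False],
#         [False, False,  True]])

# Padded slots must not propagate NaNs downstream, so they are zeroed out.
clean = torch.where(key_padding_mask.unsqueeze(-1), torch.zeros_like(batch), batch)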
