
Use same network

/develop/coma2/samenet
Ervin Teng, 4 years ago
Current commit
2be83146
2 files changed, with 10 additions and 10 deletions
  1. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (8 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (12 changes)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (8 changes)


    def coma_regularizer_loss(
        self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]
    ):
        reg_losses = []
        for name, head in values.items():
-           reg_loss = (baseline_values[name] - head) ** 2
+           reg_loss = torch.nn.functional.mse_loss(head, baseline_values[name])
            reg_losses.append(reg_loss)
        value_loss = torch.mean(torch.stack(reg_losses))
        return value_loss
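For context, the regularizer averages a per-reward-stream MSE between the centralized value heads and the counterfactual baseline heads. A standalone sketch of the same computation (the stream name and sample numbers below are illustrative, not from the commit):

    import torch

    def regularizer_loss(values, baseline_values):
        # Mean over reward streams of the MSE between value and baseline heads.
        reg_losses = [
            torch.nn.functional.mse_loss(head, baseline_values[name])
            for name, head in values.items()
        ]
        return torch.mean(torch.stack(reg_losses))

    # One hypothetical "extrinsic" stream with a batch of four estimates.
    values = {"extrinsic": torch.tensor([0.5, 1.0, 0.2, 0.8])}
    baselines = {"extrinsic": torch.tensor([0.4, 1.1, 0.3, 0.7])}
    print(regularizer_loss(values, baselines))  # tensor(0.0100)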

            torch.min(p_opt_a, p_opt_b), loss_masks
        )
        return policy_loss

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

            values, old_values, returns_v, decay_eps, loss_masks
        )
        # Regularizer loss reduces bias between the baseline and values. Other
        # regularizers are possible here.
        regularizer_loss = self.coma_regularizer_loss(values, baseline_vals)
        policy_loss = self.ppo_policy_loss(

        loss = (
            policy_loss
            + 0.25 * (value_loss + baseline_loss)
-           + 1.0 * regularizer_loss
+           + 0.25 * regularizer_loss
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )
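Reading each changed pair as removed then added, the commit drops the regularizer's weight from 1.0 to 0.25 so it matches the coefficient on the value and baseline losses. A small numeric sketch of the combined objective (only the structure and the 0.25 coefficients come from the diff; the sample numbers are made up):

    # Illustrative values only.
    policy_loss = 0.02
    value_loss, baseline_loss = 0.30, 0.28
    regularizer_loss = 0.05
    decay_bet, mean_entropy = 0.005, 1.6

    loss = (
        policy_loss
        + 0.25 * (value_loss + baseline_loss)  # value and baseline losses
        + 0.25 * regularizer_loss              # regularizer, down-weighted from 1.0
        - decay_bet * mean_entropy             # decayed entropy bonus
    )
    print(loss)  # ~0.1695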

ml-agents/mlagents/trainers/torch/networks.py (12 changes)


        self.critic = CentralizedValueNetwork(
            stream_names, sensor_specs, network_settings, action_spec=action_spec
        )
-       self.target = CentralizedValueNetwork(
-           stream_names, sensor_specs, network_settings, action_spec=action_spec
-       )
+       # self.target = CentralizedValueNetwork(
+       #     stream_names, sensor_specs, network_settings, action_spec=action_spec
+       # )

    @property
    def memory_size(self) -> int:
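The hunk above is the heart of the commit: the separate target copy of the centralized critic is commented out, so value, baseline, and Q estimates all come from the one remaining CentralizedValueNetwork. A toy sketch of that shared-parameters idea (the class below is illustrative, not the real CentralizedValueNetwork):

    import torch
    import torch.nn as nn

    class ToySharedCritic(nn.Module):
        def __init__(self, obs_size: int, hidden: int = 64):
            super().__init__()
            # One shared trunk...
            self.body = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU())
            # ...feeding both a value head and a baseline head.
            self.value_head = nn.Linear(hidden, 1)
            self.baseline_head = nn.Linear(hidden, 1)

        def value(self, obs: torch.Tensor) -> torch.Tensor:
            return self.value_head(self.body(obs))

        def baseline(self, obs: torch.Tensor) -> torch.Tensor:
            return self.baseline_head(self.body(obs))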

        if team_obs is not None and team_obs:
            all_obs.extend(team_obs)
-       value_outputs, critic_mem_out = self.target.value(
+       value_outputs, critic_mem_out = self.critic.value(
            all_obs, memories=critic_mem, sequence_length=sequence_length
        )

        if team_act is not None and team_act:
            all_acts.extend(team_act)
-       baseline_outputs, _ = self.target.baseline(
+       baseline_outputs, _ = self.critic.baseline(
            inputs,
            team_obs,
            team_act,

-       value_outputs, critic_mem_out = self.target.q_net(
+       value_outputs, critic_mem_out = self.critic.q_net(
            all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length
        )
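These three call-site changes are what the optimizer-side regularizer relies on: values and baseline_vals are now produced by the same critic parameters, and the 0.25-weighted regularizer penalizes disagreement between them. Continuing the ToySharedCritic sketch from above (stream name, shapes, and variable names here are illustrative):

    # Both estimates come from the single shared critic, as in the hunks above.
    critic = ToySharedCritic(obs_size=8)
    obs = torch.randn(4, 8)  # batch of four illustrative observations

    values = {"extrinsic": critic.value(obs).squeeze(-1)}
    baseline_vals = {"extrinsic": critic.baseline(obs).squeeze(-1)}

    # Same reduction the optimizer's coma_regularizer_loss performs.
    regularizer = torch.mean(
        torch.stack(
            [
                torch.nn.functional.mse_loss(head, baseline_vals[name])
                for name, head in values.items()
            ]
        )
    )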
