
Remove Q-net for perf

/develop/coma2/clip
Ervin Teng, 4 years ago
Current commit
3283b6a1
4 files changed, with 40 additions and 47 deletions
1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (40 changes)
2. ml-agents/mlagents/trainers/policy/torch_policy.py (4 changes)
3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changes)
4. ml-agents/mlagents/trainers/torch/networks.py (41 changes)

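At a glance, the commit drops the Q-network head from the critic's forward passes so that only baseline and value estimates are computed and threaded through the trainer. A compact summary of the interface change, reconstructed from the hunks below (only the names and return arities visible in the diff are taken from the source; call arguments are elided):

    # networks.py: critic_pass no longer returns per-stream Q estimates.
    #   before: q_out, baseline_outputs, memories_out = actor_critic.critic_pass(...)
    #   after:  baseline_outputs, memories_out = actor_critic.critic_pass(...)
    #
    # torch_policy.py / ppo/optimizer_torch.py: callers unpack one fewer element.
    #   before: log_probs, entropies, q_heads, baseline, values = policy.evaluate_actions(...)
    #   after:  log_probs, entropies, baseline, values = policy.evaluate_actions(...)
    #
    # optimizer/torch_optimizer.py: the trajectory helper now returns three dicts.
    #   after:  return (value_estimates, baseline_estimates, boot_value_estimates)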
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (40 changes)


  current_obs = ObsUtil.from_buffer(batch, n_obs)
  team_obs = TeamObsUtil.from_buffer(batch, n_obs)
- #next_obs = ObsUtil.from_buffer_next(batch, n_obs)
- #next_team_obs = TeamObsUtil.from_buffer_next(batch, n_obs)
+ # next_obs = ObsUtil.from_buffer_next(batch, n_obs)
+ # next_team_obs = TeamObsUtil.from_buffer_next(batch, n_obs)
  # Convert to tensors
  current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

  ]
- #next_team_obs = [
+ # next_team_obs = [
- #]
+ # ]
- #next_actions = AgentAction.from_dict_next(batch)
- #next_team_actions = AgentAction.from_team_dict_next(batch)
+ # next_actions = AgentAction.from_dict_next(batch)
+ # next_team_actions = AgentAction.from_team_dict_next(batch)
  next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
  next_obs = [obs.unsqueeze(0) for obs in next_obs]

  # ]
  next_critic_obs = [
      ModelUtils.list_to_tensor_list(_list_obs) for _list_obs in next_critic_obs
  ]
- q_estimates, baseline_estimates, mem = self.policy.actor_critic.critic_pass(
+ baseline_estimates, _ = self.policy.actor_critic.critic_pass(
      current_obs,
      actions,
      memory,

      team_obs=next_critic_obs,
  )
- #next_value_estimates, next_marg_val_estimates, next_mem = self.policy.actor_critic.target_critic_pass(
+ # next_value_estimates, next_marg_val_estimates, next_mem = self.policy.actor_critic.target_critic_pass(
  # next_obs,
  # next_actions,
  # memory,

- #)
+ # )
  # # Actions is a hack here, we need the next actions
  # next_value_estimate, next_marg_val_estimate, _ = self.policy.actor_critic.critic_pass(

  # print(baseline_estimates)
  # print(value_estimates)
  # print(boot_value_baseline[k][-1])
- #if done and not all_dones:
+ # if done and not all_dones:
- #elif all_dones:
+ # elif all_dones:
- #else:
+ # else:
- #print("final", boot_value_estimates)
- #print("value", value_estimates)
- #print("base", baseline_estimates)
+ # print("final", boot_value_estimates)
+ # print("value", value_estimates)
+ # print("base", baseline_estimates)
- return (
-     value_estimates,
-     baseline_estimates,
-     boot_value_estimates,
- )
+ return (value_estimates, baseline_estimates, boot_value_estimates)

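For orientation, a rough sketch of what the evaluation pass in torch_optimizer.py amounts to after this change: one critic_pass for the counterfactual baseline, a target-critic value pass for the values, and a three-element return. Only the names visible in the hunks (critic_pass, target_critic_value, baseline_estimates, value_estimates, boot_value_estimates, team_obs) come from the diff; the function name, argument wiring, and bootstrap handling below are assumptions.

    from typing import Dict, List, Tuple

    import torch

    ValueDict = Dict[str, torch.Tensor]  # one entry per reward stream (assumption)


    def trajectory_estimates_sketch(
        actor_critic,  # stands in for self.policy.actor_critic
        current_obs: List[torch.Tensor],
        actions,
        memory: torch.Tensor,
        team_obs: List[List[torch.Tensor]],
    ) -> Tuple[ValueDict, ValueDict, ValueDict]:
        # The Q estimates that used to be the first element of this result are
        # gone; only the baseline dictionary (and the memories) come back.
        baseline_estimates, _ = actor_critic.critic_pass(
            current_obs, actions, memory, team_obs=team_obs
        )
        # Values come from the target critic's value-only pass (see the
        # networks.py hunks below); the commented-out "next_*" computations in
        # the hunk above are no longer needed.
        value_estimates, _ = actor_critic.target_critic_value(
            current_obs, memories=memory, team_obs=team_obs
        )
        boot_value_estimates = value_estimates  # placeholder for the bootstrap step
        return (value_estimates, baseline_estimates, boot_value_estimates)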
ml-agents/mlagents/trainers/policy/torch_policy.py (4 changes)


      team_obs: Optional[List[List[torch.Tensor]]] = None,
      team_act: Optional[List[AgentAction]] = None,
  ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
- log_probs, entropies, q_heads, baseline, values = self.actor_critic.get_stats_and_value(
+ log_probs, entropies, baseline, values = self.actor_critic.get_stats_and_value(
- return log_probs, entropies, q_heads, baseline, values
+ return log_probs, entropies, baseline, values

  @timed
  def evaluate(

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changes)


  if len(memories) > 0:
      memories = torch.stack(memories).unsqueeze(0)
- log_probs, entropy, qs, baseline_vals, values = self.policy.evaluate_actions(
+ log_probs, entropy, baseline_vals, values = self.policy.evaluate_actions(
      current_obs,
      masks=act_masks,
      actions=actions,

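On the policy and PPO-optimizer side the change is just the narrower tuple. A minimal sketch of the call as it reads after the two hunks above; argument names are taken from the diff, while the memories keyword and the surrounding update logic are assumptions.

    from typing import List

    import torch


    def evaluate_actions_sketch(
        policy,  # stands in for self.policy in the PPO optimizer
        current_obs: List[torch.Tensor],
        act_masks: torch.Tensor,
        actions,
        memories: List[torch.Tensor],
    ):
        # Per-sequence memories are stacked into a single tensor, as in the hunk above.
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)
        # The q_heads element has been dropped: four values come back instead of five.
        log_probs, entropy, baseline_vals, values = policy.evaluate_actions(
            current_obs, masks=act_masks, actions=actions, memories=memories  # kwarg name assumed
        )
        return log_probs, entropy, baseline_vals, values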
ml-agents/mlagents/trainers/torch/networks.py (41 changes)


  )
  return encoding, memories

  def forward(
      self,
      f_enc: torch.Tensor,

      memories: Optional[torch.Tensor] = None,
      sequence_length: int = 1,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
-     encoding, memories = self.network_body.value(
-         obs, memories, sequence_length
-     )
+     encoding, memories = self.network_body.value(obs, memories, sequence_length)
      output = self.value_heads(encoding)
      return output, memories

      stream_names, sensor_specs, network_settings, action_spec=action_spec
  )

  @property
  def memory_size(self) -> int:
      return self.network_body.memory_size + self.critic.memory_size

  if team_obs is not None and team_obs:
      all_obs.extend(team_obs)
- value_outputs, _ = self.target.value(
-     all_obs,
-     memories=critic_mem,
-     sequence_length=sequence_length,
+ value_outputs, critic_mem_out = self.target.value(
+     all_obs, memories=critic_mem, sequence_length=sequence_length
  )
  # if mar_value_outputs is None:

  if team_obs is not None and team_obs:
      all_obs.extend(team_obs)
- value_outputs, _ = self.critic.value(
-     all_obs,
-     memories=critic_mem,
-     sequence_length=sequence_length,
+ value_outputs, critic_mem_out = self.critic.value(
+     all_obs, memories=critic_mem, sequence_length=sequence_length
  )
  # if mar_value_outputs is None:

      sequence_length: int = 1,
      team_obs: List[List[torch.Tensor]] = None,
      team_act: List[AgentAction] = None,
- ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]:
+ ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
      actor_mem, critic_mem = self._get_actor_critic_mem(memories)
      all_obs = [inputs]

  if team_act is not None and team_act:
      all_acts.extend(team_act)
- baseline_outputs, _ = self.critic.baseline(
+ baseline_outputs, critic_mem_out = self.critic.baseline(
      inputs,
      team_obs,
      team_act,

- q_out, critic_mem_out = self.critic.q_net(
-     all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length
- )
+ # q_out, critic_mem_out = self.critic.q_net(
+ #     all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length
+ # )
  # if mar_value_outputs is None:
  #     mar_value_outputs = value_outputs

      memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
  else:
      memories_out = None
- return q_out, baseline_outputs, memories_out
+ return baseline_outputs, memories_out

  def get_stats_and_value(
      self,

  )
  log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
- q_outputs, baseline_outputs, _ = self.critic_pass(
+ baseline_outputs, _ = self.critic_pass(
      inputs,
      actions,
      memories=critic_mem,

  )
- value_outputs, _ = self.target_critic_value(inputs, memories=critic_mem, sequence_length=sequence_length, team_obs=team_obs)
+ value_outputs, _ = self.target_critic_value(
+     inputs,
+     memories=critic_mem,
+     sequence_length=sequence_length,
+     team_obs=team_obs,
+ )
- return log_probs, entropies, q_outputs, baseline_outputs, value_outputs
+ return log_probs, entropies, baseline_outputs, value_outputs

  def get_action_stats(
      self,

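Finally, a sketch of how critic_pass reads after the change: with the q_net forward pass gone, the baseline call becomes the source of the critic memories, and only the baseline dictionary is returned. Names follow the networks.py hunks above; the keyword arguments to critic.baseline and the memory-presence check are assumptions.

    from typing import Dict, List, Optional, Tuple

    import torch


    def critic_pass_sketch(
        actor_critic,  # stands in for the actor-critic module's self
        inputs: List[torch.Tensor],
        actions,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        team_obs: Optional[List[List[torch.Tensor]]] = None,
        team_act=None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        # Split the stacked memories into actor and critic halves.
        actor_mem, critic_mem = actor_critic._get_actor_critic_mem(memories)
        # The baseline pass now also supplies critic_mem_out, since the q_net call
        # (the old source of the critic memories) has been removed.
        baseline_outputs, critic_mem_out = actor_critic.critic.baseline(
            inputs, team_obs, team_act, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:  # condition assumed; the diff only shows the two branches
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return baseline_outputs, memories_out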