
Try subtract marginalized value

/develop/centralizedcritic/counterfact
Ervin Teng, 4 years ago
Current commit
6b8b3db3
Showing 4 changed files with 60 additions and 10 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (10 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changes)
  3. ml-agents/mlagents/trainers/ppo/trainer.py (39 changes)
  4. ml-agents/mlagents/trainers/torch/networks.py (19 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (10 changes)


  memory = torch.zeros([1, 1, self.policy.m_size])
- value_estimates, next_memory = self.policy.actor_critic.critic_pass(
+ value_estimates, marg_val_estimates, next_memory = self.policy.actor_critic.critic_pass(
      current_obs,
      memory,
      sequence_length=batch.num_experiences,

- next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+ next_value_estimate, next_marg_val_estimate, _ = self.policy.actor_critic.critic_pass(
      next_obs, next_memory, sequence_length=1, critic_obs=next_critic_obs
  )

+ for name, estimate in marg_val_estimates.items():
+     marg_val_estimates[name] = ModelUtils.to_numpy(estimate)
+     next_marg_val_estimate[name] = ModelUtils.to_numpy(next_marg_val_estimate[name])
- return value_estimates, next_value_estimate
+ return value_estimates, marg_val_estimates, next_value_estimate, next_marg_val_estimate
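The new marginalized estimates get the same tensor-to-numpy treatment as the regular ones. A self-contained sketch of that pattern, with ModelUtils.to_numpy approximated by .detach().cpu().numpy() and placeholder tensors standing in for real critic outputs (none of this code is from the commit):

import torch

def to_numpy(t: torch.Tensor):
    # Approximation of ModelUtils.to_numpy; the real helper may differ in detail.
    return t.detach().cpu().numpy()

# Placeholder critic outputs: one entry per reward signal.
marg_val_estimates = {"extrinsic": torch.rand(5), "curiosity": torch.rand(5)}
next_marg_val_estimate = {"extrinsic": torch.rand(1), "curiosity": torch.rand(1)}

# Convert both the per-step marginalized estimates and their bootstrap values,
# mirroring what the hunk above does for the new return values.
for name, estimate in marg_val_estimates.items():
    marg_val_estimates[name] = to_numpy(estimate)
    next_marg_val_estimate[name] = to_numpy(next_marg_val_estimate[name])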

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changes)


  if len(memories) > 0:
      memories = torch.stack(memories).unsqueeze(0)
- log_probs, entropy, values = self.policy.evaluate_actions(
+ log_probs, entropy, values, marginalized_vals = self.policy.evaluate_actions(
      current_obs,
      masks=act_masks,
      actions=actions,
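evaluate_actions now hands back a fourth item, marginalized_vals. Both dicts are keyed per reward signal, so a caller can expect them to align step for step; a tiny sanity-check sketch with placeholder tensors (not from the commit):

import torch

# Placeholder outputs shaped like a PPO minibatch of 64 steps.
values = {"extrinsic": torch.rand(64), "curiosity": torch.rand(64)}
marginalized_vals = {"extrinsic": torch.rand(64), "curiosity": torch.rand(64)}

# One marginalized estimate per reward signal, matching the regular values.
assert values.keys() == marginalized_vals.keys()
for name in values:
    assert values[name].shape == marginalized_vals[name].shape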

ml-agents/mlagents/trainers/ppo/trainer.py (39 changes)


  self.policy.update_normalization(agent_buffer_trajectory)
  # Get all value estimates
- value_estimates, value_next = self.optimizer.get_trajectory_value_estimates(
+ (
+     value_estimates,
+     marginalized_value_estimates,
+     value_next,
+     marg_value_next,
+ ) = self.optimizer.get_trajectory_value_estimates(
      agent_buffer_trajectory,
      trajectory.next_obs,
      trajectory.next_collab_obs,

  for name, v in value_estimates.items():
      agent_buffer_trajectory[f"{name}_value_estimates"].extend(v)
+     agent_buffer_trajectory[f"{name}_marginalized_value_estimates"].extend(
+         marginalized_value_estimates[name]
+     )
      self._stats_reporter.add_stat(
          f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate",
          np.mean(v),

  local_value_estimates = agent_buffer_trajectory[
      f"{name}_value_estimates"
  ].get_batch()
+ m_value_estimates = agent_buffer_trajectory[
+     f"{name}_marginalized_value_estimates"
+ ].get_batch()
- local_advantage = get_gae(
+ local_advantage = get_team_gae(
+     marginalized_value_estimates=m_value_estimates,
      value_next=bootstrap_value,
      gamma=self.optimizer.reward_signals[name].gamma,
      lambd=self.hyperparameters.lambd,

      delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1]
      advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
      return advantage

+ def get_team_gae(
+     rewards,
+     value_estimates,
+     marginalized_value_estimates,
+     value_next=0.0,
+     gamma=0.99,
+     lambd=0.95,
+ ):
+     """
+     Computes a generalized advantage estimate for use in updating the policy,
+     subtracting the marginalized (counterfactual) value estimates as the baseline.
+     :param rewards: list of rewards for time-steps t to T.
+     :param value_next: Value estimate for time-step T+1.
+     :param value_estimates: list of value estimates for time-steps t to T.
+     :param marginalized_value_estimates: list of marginalized value estimates for time-steps t to T.
+     :param gamma: Discount factor.
+     :param lambd: GAE weighing factor.
+     :return: list of advantage estimates for time-steps t to T.
+     """
+     value_estimates = np.append(value_estimates, value_next)
+     delta_t = rewards + gamma * value_estimates[1:] - marginalized_value_estimates
+     advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
+     return advantage
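The only difference from get_gae is the baseline in the TD error: delta_t = r_t + gamma * V(s_{t+1}) - V_marg(s_t) instead of subtracting V(s_t). A self-contained numeric sketch comparing the two estimators, with discount_rewards re-implemented here as the usual backward discounted sum (the real ml-agents helper may differ in detail), and all input values invented for illustration:

import numpy as np

def discount_rewards(r, gamma=0.99, value_next=0.0):
    # Backward discounted cumulative sum, seeded with the bootstrap value.
    discounted = np.zeros_like(r)
    running = value_next
    for t in reversed(range(len(r))):
        running = running * gamma + r[t]
        discounted[t] = running
    return discounted

gamma, lambd = 0.99, 0.95
rewards = np.array([0.0, 0.0, 1.0])
values = np.array([0.5, 0.6, 0.8])       # V(s_t) from the regular critic pass
marg_values = np.array([0.4, 0.5, 0.7])  # marginalized / counterfactual baseline
bootstrap = 0.0                          # V(s_{T+1})

v = np.append(values, bootstrap)

# Standard GAE: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
delta_gae = rewards + gamma * v[1:] - v[:-1]

# Team GAE from this commit: subtract the marginalized baseline instead.
delta_team = rewards + gamma * v[1:] - marg_values

adv_gae = discount_rewards(delta_gae, gamma=gamma * lambd)
adv_team = discount_rewards(delta_team, gamma=gamma * lambd)
print(adv_gae)   # approx [0.451, 0.380, 0.200]
print(adv_team)  # approx [0.734, 0.574, 0.300]; larger because the baseline is lower here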

ml-agents/mlagents/trainers/torch/networks.py (19 changes)


  ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
      actor_mem, critic_mem = self._get_actor_critic_mem(memories)
      all_net_inputs = [inputs]
-     if critic_obs is not None:
+     if critic_obs is not None and critic_obs:
+         mar_value_outputs, _ = self.critic(
+             critic_obs, memories=critic_mem, sequence_length=sequence_length
+         )
+     else:
+         mar_value_outputs = None
+     if mar_value_outputs is None:
+         mar_value_outputs = value_outputs
-     return value_outputs, memories_out
+     return value_outputs, mar_value_outputs, memories_out

  def get_stats_and_value(
      self,

  )
      log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
      all_net_inputs = [inputs]
-     if critic_obs is not None:
+     if critic_obs is not None and critic_obs:
+         mar_value_outputs, _ = self.critic(
+             critic_obs, memories=critic_mem, sequence_length=sequence_length
+         )
-     return log_probs, entropies, value_outputs
+     return log_probs, entropies, value_outputs, mar_value_outputs

  def get_action_stats(
      self,
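Both hunks follow the same shape: when collaborator observations (critic_obs) are present, the critic is also run on them alone to produce the marginalized value, and the regular value stands in otherwise. A toy, self-contained sketch of that control flow; the real critic consumes observation lists plus memories, so the mean-pooling ToyCritic and every name below are stand-ins rather than ml-agents code, and including critic_obs in the regular pass is an assumption not shown in the hunk:

import torch
import torch.nn as nn

OBS_SIZE = 6

class ToyCritic(nn.Module):
    # Stand-in for the critic: mean-pools a list of per-agent observations,
    # then regresses a single value per batch element.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(OBS_SIZE, 32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, obs_list):
        pooled = torch.stack(obs_list, dim=0).mean(dim=0)  # (batch, OBS_SIZE)
        return self.net(pooled)

def critic_pass_sketch(critic, inputs, critic_obs=None):
    # Regular value estimate over the agent's own inputs plus any collaborator obs.
    value_outputs = critic([inputs] + (critic_obs or []))
    # Marginalized estimate: the critic run on collaborator observations only,
    # mirroring `self.critic(critic_obs, ...)` in the hunk above.
    if critic_obs is not None and critic_obs:
        mar_value_outputs = critic(critic_obs)
    else:
        mar_value_outputs = None
    # Fall back to the regular estimate when there is nothing to marginalize over,
    # as the `if mar_value_outputs is None` guard does.
    if mar_value_outputs is None:
        mar_value_outputs = value_outputs
    return value_outputs, mar_value_outputs

critic = ToyCritic()
own_obs = torch.rand(3, OBS_SIZE)       # batch of 3 steps
collab_obs = [torch.rand(3, OBS_SIZE)]  # one collaborator
value, marg_value = critic_pass_sketch(critic, own_obs, collab_obs)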
