浏览代码

no baseline

/develop/coma-noact
Andrew Cohen 4 年前
当前提交
511a9a7e
共有 2 个文件被更改,包括 16 次插入5 次删除
  1. 12
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 9
      ml-agents/mlagents/trainers/torch/networks.py

12
ml-agents/mlagents/trainers/ppo/trainer.py


#local_advantage = np.array(q_estimates) - np.array(
local_advantage = np.array(returns_v) - np.array(returns_b)
#local_advantage = np.array(returns_v) - baseline_estimates#np.array(returns_b)
local_advantage = get_gae(
rewards=local_rewards,
value_estimates=v_estimates,
value_next=value_next[name],
gamma=self.optimizer.reward_signals[name].gamma,
lambd=self.hyperparameters.lambd,
)
#self._stats_reporter.add_stat(
# f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} GAE Advantage Estimate",
# np.mean(gae_advantage),

local_return = local_advantage + baseline_estimates
# This is later use as target for the different value estimates
# agent_buffer_trajectory[f"{name}_returns"].set(local_return)
agent_buffer_trajectory[f"{name}_returns_b"].set(returns_b)
agent_buffer_trajectory[f"{name}_returns_b"].set(returns_v)
agent_buffer_trajectory[f"{name}_returns_v"].set(returns_v)
agent_buffer_trajectory[f"{name}_advantage"].set(local_advantage)
tmp_advantages.append(local_advantage)

9
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil

encoder_input_size = self.h_size
if baseline:
self.self_encoder = LinearEncoder(
obs_only_ent_size, 1, self.h_size
obs_only_ent_size,
1,
self.h_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.h_size) ** 0.5,
self.obs_encoder = EntityEmbedding(
self.h_size, obs_only_ent_size, None, self.h_size, concat_self=True

正在加载...
取消
保存