浏览代码

fixing more ci tests

/develop/magic-string
Andrew Cohen 5 年前
当前提交
3e76adbd
共有 2 个文件被更改,包括 7 次插入11 次删除
  1. 10
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 8
      ml-agents/mlagents/trainers/sac/trainer.py

10
ml-agents/mlagents/trainers/ppo/trainer.py


def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
if self.multi_gpu and len(get_devices()) > 1:
policy = MultiGpuPPOPolicy(
self.ppo_policy = MultiGpuPPOPolicy(
self.seed,
brain_parameters,
self.trainer_parameters,

else:
policy = PPOPolicy(
self.ppo_policy = PPOPolicy(
self.seed,
brain_parameters,
self.trainer_parameters,

for _reward_signal in policy.reward_signals.keys():
for _reward_signal in self.ppo_policy.reward_signals.keys():
self.ppo_policy = policy
return policy
return self.ppo_policy
def discount_rewards(r, gamma=0.99, value_next=0.0):

8
ml-agents/mlagents/trainers/sac/trainer.py


self.trainer_metrics.end_policy_update()
def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
policy = SACPolicy(
self.sac_policy = SACPolicy(
self.seed,
brain_parameters,
self.trainer_parameters,

for _reward_signal in policy.reward_signals.keys():
for _reward_signal in self.sac_policy.reward_signals.keys():
self.collected_rewards[_reward_signal] = {}
# Load the replay buffer if load

)
)
self.sac_policy = policy
return policy
return self.sac_policy
def update_sac_policy(self) -> None:
"""

正在加载...
取消
保存