
Migrate SAC

/develop-newnormalization
Ervin Teng, 5 years ago
Commit
28eba789
1 file changed, 13 insertions and 1 deletion
ml-agents/mlagents/trainers/sac/trainer.py (14 changed lines)

```diff
         if self.is_training:
             self.policy.update_normalization(agent_buffer_trajectory["vector_obs"])

         # Evaluate all reward functions for reporting purposes
         self.collected_rewards["environment"][agent_id] += np.sum(
             agent_buffer_trajectory["environment_rewards"]
         )
         for name, reward_signal in self.policy.reward_signals.items():
             evaluate_result = reward_signal.evaluate_batch(
                 agent_buffer_trajectory
             ).scaled_reward
             agent_buffer_trajectory["{}_rewards".format(name)].extend(evaluate_result)
             # Report the reward signals
             self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

-        vec_vis_obs = split_obs(last_step)
+        vec_vis_obs = split_obs(last_step.obs)
         for i, obs in enumerate(vec_vis_obs.visual_observations):
             agent_buffer_trajectory["next_visual_obs%d" % i][-1] = obs
         if vec_vis_obs.vector_observations.size > 1:
```
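The `update_normalization` call above feeds each trajectory's vector observations into the policy's running observation normalizer. As a rough sketch of what such an accumulator does (not ml-agents' actual implementation; the class and method names here are hypothetical), a batched running mean/variance can be maintained with the parallel variant of Welford's algorithm:

```python
import numpy as np

class RunningNormalizerSketch:
    """Hypothetical running normalizer illustrating what a call like
    policy.update_normalization(vector_obs) accumulates: a running
    mean and variance over every vector observation seen so far."""

    def __init__(self, size: int):
        self.steps = 0
        self.mean = np.zeros(size, dtype=np.float64)
        self.m2 = np.zeros(size, dtype=np.float64)  # sum of squared deviations

    def update(self, batch: np.ndarray) -> None:
        # Merge the batch's statistics into the running totals
        # (Chan et al.'s batched form of Welford's online algorithm).
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        delta = batch_mean - self.mean
        total = self.steps + n
        self.mean += delta * n / total
        self.m2 += batch.var(axis=0) * n + delta ** 2 * self.steps * n / total
        self.steps = total

    def normalize(self, obs: np.ndarray) -> np.ndarray:
        # Standardize using the accumulated statistics.
        var = self.m2 / max(self.steps, 1)
        return (obs - self.mean) / np.sqrt(var + 1e-8)
```

Updating per trajectory, as the diff does, keeps the normalizer in step with the data distribution the policy is actually trained on.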

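The one-line fix in the hunk passes `last_step.obs` (the step's observation list) rather than the step object itself into `split_obs`. A minimal sketch of what a `split_obs`-style helper does, under the assumption that visual observations are rank-3 arrays and vector observations are rank-1 (the container and function names below are hypothetical, not ml-agents' API):

```python
from typing import List, NamedTuple
import numpy as np

class VecVisObs(NamedTuple):
    # Hypothetical container mirroring the vec_vis_obs used in the diff
    vector_observations: np.ndarray
    visual_observations: List[np.ndarray]

def split_obs_sketch(obs: List[np.ndarray]) -> VecVisObs:
    """Split a per-step observation list into visual (rank-3) and
    vector (rank-1) parts, concatenating all vector pieces."""
    visual = [o for o in obs if o.ndim == 3]
    vectors = [o for o in obs if o.ndim == 1]
    vec = np.concatenate(vectors) if vectors else np.zeros(0, dtype=np.float32)
    return VecVisObs(vector_observations=vec, visual_observations=visual)

# The fixed call passes the observation list, not the step object:
step_obs = [
    np.zeros((84, 84, 3), dtype=np.float32),  # one camera observation
    np.arange(8, dtype=np.float32),           # one vector observation
]
result = split_obs_sketch(step_obs)
```

Passing the whole step object would hand the helper something that is not a list of arrays, which is why the call site needed `.obs`.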