浏览代码

[bug-fix] Fix stats reporting for reward signals in SAC (#3606)

/bug-failed-api-check
GitHub 5 年前
当前提交
873ba7fd
共有 3 个文件被更改,包括 5 次插入5 次删除
  1. 1
      ml-agents/mlagents/trainers/policy/tf_policy.py
  2. 5
      ml-agents/mlagents/trainers/sac/trainer.py
  3. 4
      ml-agents/mlagents/trainers/tests/test_sac.py

1
ml-agents/mlagents/trainers/policy/tf_policy.py


self.use_recurrent = trainer_parameters["use_recurrent"]
self.memory_dict: Dict[str, np.ndarray] = {}
self.reward_signals: Dict[str, "RewardSignal"] = {}
self.num_branches = len(self.brain.vector_action_space_size)
self.previous_action_dict: Dict[str, np.array] = {}
self.normalize = trainer_parameters.get("normalize", False)

5
ml-agents/mlagents/trainers/sac/trainer.py


self.collected_rewards["environment"][agent_id] += np.sum(
agent_buffer_trajectory["environment_rewards"]
)
for name, reward_signal in self.policy.reward_signals.items():
for name, reward_signal in self.optimizer.reward_signals.items():
evaluate_result = reward_signal.evaluate_batch(
agent_buffer_trajectory
).scaled_reward

reparameterize=True,
create_tf_graph=False,
)
for _reward_signal in policy.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Load the replay buffer if load
if self.load and self.checkpoint_replay_buffer:
try:

4
ml-agents/mlagents/trainers/tests/test_sac.py


for agent in reward.values():
assert agent == 0
assert trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").num > 0
# Assert we're not just using the default values
assert (
trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").mean > 0
)
if __name__ == "__main__":
正在加载...
取消
保存