浏览代码

Fix test

/develop/trainerinterface
Ervin Teng 5 年前
当前提交
48793ec1
共有 4 个文件被更改,包括 0 次插入和 10 次删除
  1. 3
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 1
      ml-agents/mlagents/trainers/rl_trainer.py
  3. 3
      ml-agents/mlagents/trainers/sac/trainer.py
  4. 3
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py

3
ml-agents/mlagents/trainers/ppo/trainer.py


super()._process_trajectory(trajectory)
agent_id = trajectory.agent_id # All the agents should have the same ID
# Add to episode_steps
self.episode_steps[agent_id] += len(trajectory.steps)
agent_buffer_trajectory = trajectory.to_agentbuffer()
# Update the normalization
if self.is_training:

1
ml-agents/mlagents/trainers/rl_trainer.py


"environment": defaultdict(lambda: 0)
}
self.update_buffer: AgentBuffer = AgentBuffer()
self.episode_steps: Dict[str, int] = defaultdict(lambda: 0)
# Write hyperparameters to Tensorboard
if self.is_training:
self.write_tensorboard_text("Hyperparameters", self.trainer_parameters)

3
ml-agents/mlagents/trainers/sac/trainer.py


last_step = trajectory.steps[-1]
agent_id = trajectory.agent_id # All the agents should have the same ID
# Add to episode_steps
self.episode_steps[agent_id] += len(trajectory.steps)
agent_buffer_trajectory = trajectory.to_agentbuffer()
# Update the normalization

3
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


def test_rl_trainer():
    """Assert that the trainer's per-agent step counts and rewards are zero.

    NOTE(review): this test was recovered from a flattened diff view. The
    indentation is reconstructed and statements between the visible lines
    may have been elided -- confirm against the real test module.
    """
    rl_trainer = create_rl_trainer()
    agent_id = "0"
    rl_trainer.episode_steps[agent_id] = 3
    # NOTE(review): nothing visible here resets the counter set above before
    # it is asserted to be 0 -- presumably a reset happens in elided lines.
    for step_count in rl_trainer.episode_steps.values():
        assert step_count == 0
    for rewards_by_agent in rl_trainer.collected_rewards.values():
        for reward in rewards_by_agent.values():
            assert reward == 0

正在加载...
取消
保存