浏览代码

Make sure all tests pass on BC

/hotfix-v0.9.2a
Ervin Teng 5 年前
当前提交
c912d140
共有 2 个文件被更改,包括 4 次插入14 次删除
  1. 16
      ml-agents/mlagents/trainers/bc/policy.py
  2. 2
      ml-agents/mlagents/trainers/bc/trainer.py

16
ml-agents/mlagents/trainers/bc/policy.py


self.model.sequence_length: self.sequence_length,
}
if self.use_continuous_act:
feed_dict[self.model.true_action] = mini_batch["actions"].reshape(
[-1, self.brain.vector_action_space_size[0]]
)
feed_dict[self.model.true_action] = mini_batch["actions"]
feed_dict[self.model.true_action] = mini_batch["actions"].reshape(
[-1, len(self.brain.vector_action_space_size)]
)
feed_dict[self.model.true_action] = mini_batch["actions"]
apparent_obs_size = (
self.brain.vector_observation_space_size
* self.brain.num_stacked_vector_observations
)
feed_dict[self.model.vector_in] = mini_batch["vector_obs"].reshape(
[-1, apparent_obs_size]
)
feed_dict[self.model.vector_in] = mini_batch["vector_obs"]
for i, _ in enumerate(self.model.visual_in):
visual_obs = mini_batch["visual_obs%d" % i]
feed_dict[self.model.visual_in[i]] = visual_obs

2
ml-agents/mlagents/trainers/bc/trainer.py


"""
Updates the policy.
"""
self.demonstration_buffer.update_buffer.shuffle()
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
batch_losses = []
num_batches = min(
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,

正在加载...
取消
保存