
Fix for discrete actions (#4181)

/develop/add-fire
GitHub · 5 years ago
Current commit
0d80d87a
4 changed files, with 5 additions and 6 deletions
  1. ml-agents/mlagents/trainers/distributions_torch.py (2 changed lines)
  2. ml-agents/mlagents/trainers/policy/policy.py (2 changed lines)
  3. ml-agents/mlagents/trainers/policy/torch_policy.py (5 changed lines)
  4. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changed lines)

ml-agents/mlagents/trainers/distributions_torch.py (2 changed lines)


         return torch.multinomial(self.probs, 1)

     def pdf(self, value):
-        return torch.diag(self.probs.T[value.flatten()])
+        return torch.diag(self.probs.T[value.flatten().long()])

     def log_prob(self, value):
         return torch.log(self.pdf(value))
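
The added .long() call matters because PyTorch only accepts long, byte, or bool tensors as indices; discrete actions that have passed through float-typed buffers arrive as float tensors, and indexing self.probs.T with them raises an IndexError. A minimal sketch of the behavior (shapes and values here are illustrative, not taken from this commit):

    import torch

    probs = torch.tensor([[0.1, 0.9],
                          [0.7, 0.3]])          # (batch=2, num_actions=2)
    value = torch.tensor([[1.0], [0.0]])        # float-typed action indices

    # probs.T[value.flatten()]                  # IndexError: tensors used as indices
                                                # must be long, byte or bool tensors
    pdf = torch.diag(probs.T[value.flatten().long()])
    print(pdf)                                  # tensor([0.9000, 0.7000])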

ml-agents/mlagents/trainers/policy/policy.py (2 changed lines)


         self.num_branches = len(self.brain.vector_action_space_size)
         self.previous_action_dict: Dict[str, np.array] = {}
         self.memory_dict: Dict[str, np.ndarray] = {}
-        self.normalize = trainer_settings
+        self.normalize = trainer_settings.network_settings.normalize
         self.use_recurrent = trainer_settings.network_settings.memory is not None
         self.model_path = trainer_settings.init_path
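
The corrected line reads the boolean flag off the nested network settings instead of binding the whole settings object to self.normalize. A rough sketch of the settings layout this assumes (field names follow mlagents.trainers.settings of that era but are simplified here):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class NetworkSettings:
        normalize: bool = False
        memory: Optional[object] = None      # populated only for recurrent policies

    @dataclass
    class TrainerSettings:
        network_settings: NetworkSettings = field(default_factory=NetworkSettings)
        init_path: Optional[str] = None

    trainer_settings = TrainerSettings(NetworkSettings(normalize=True))
    print(trainer_settings.network_settings.normalize)            # True, a bool
    print(trainer_settings.network_settings.memory is not None)   # False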

ml-agents/mlagents/trainers/policy/torch_policy.py (5 changed lines)


         actions = self.actor_critic.sample_action(dists)
         log_probs, entropies = self.actor_critic.get_probs_and_entropy(actions, dists)
-        if self.act_type == "continuous":
-            actions.squeeze_(-1)
+        actions = torch.squeeze(actions)
         return actions, log_probs, entropies, value_heads, memories
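
Dropping the continuous-only guard and switching to the out-of-place torch.squeeze removes the size-1 action dimension for discrete actions as well, which is the shape the rest of the pipeline expects. An illustrative check of the shape change (the (batch, 1) shape mirrors what torch.multinomial returns for a single action branch):

    import torch

    branch_probs = torch.tensor([[0.2, 0.8],
                                 [0.5, 0.5]])
    discrete_actions = torch.multinomial(branch_probs, 1)
    print(discrete_actions.shape)                  # torch.Size([2, 1])

    # torch.squeeze without a dim argument drops every size-1 dimension.
    print(torch.squeeze(discrete_actions).shape)   # torch.Size([2])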

         fake_vec_obs = [torch.zeros([1] + [self.brain.vector_observation_space_size])]
         fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
         fake_masks = torch.ones([1] + self.actor_critic.act_size)
-        fake_memories = torch.zeros([1] + [self.m_size])
+        # fake_memories = torch.zeros([1] + [self.m_size])
         export_path = "./model-" + str(step) + ".onnx"
         output_names = ["action", "action_probs"]
         input_names = ["vector_observation", "action_mask"]
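
These placeholders are the dummy inputs that typically get handed to torch.onnx.export, with fake_memories now left out of the export path. A self-contained sketch of the pattern, using a stand-in module rather than the actual actor-critic from this commit:

    import torch

    class TinyPolicy(torch.nn.Module):
        # Stand-in for the exported model: one linear layer plus an action mask input.
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 2)

        def forward(self, vector_observation, action_mask):
            logits = self.fc(vector_observation) + torch.log(action_mask)
            action_probs = torch.softmax(logits, dim=1)
            action = torch.argmax(action_probs, dim=1, keepdim=True)
            return action, action_probs

    fake_vec_obs = torch.zeros([1, 8])
    fake_masks = torch.ones([1, 2])
    torch.onnx.export(
        TinyPolicy(),
        (fake_vec_obs, fake_masks),
        "./model-0.onnx",
        input_names=["vector_observation", "action_mask"],
        output_names=["action", "action_probs"],
    )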

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changed lines)


         if self.policy.use_continuous_act:
             actions = torch.as_tensor(batch["actions"]).unsqueeze(-1)
         else:
-            actions = torch.as_tensor(batch["actions"])
+            actions = torch.as_tensor(batch["actions"], dtype=torch.long)
         memories = [
             torch.as_tensor(batch["memory"][i])
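
The dtype=torch.long cast mirrors the distributions fix above: the update buffer hands the optimizer float arrays, so discrete action indices have to be cast before they can be used as indices, for example in a gather over per-branch probabilities. Illustrative only, with a stand-in buffer column:

    import numpy as np
    import torch

    batch_actions = np.array([1.0, 0.0, 1.0], dtype=np.float32)   # stand-in buffer column
    actions = torch.as_tensor(batch_actions, dtype=torch.long)
    print(actions.dtype)                                          # torch.int64

    probs = torch.tensor([[0.2, 0.8],
                          [0.6, 0.4],
                          [0.3, 0.7]])
    chosen = probs.gather(1, actions.unsqueeze(-1)).squeeze(-1)   # requires long indices
    print(chosen)                                                 # tensor([0.8000, 0.6000, 0.7000])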
