浏览代码

Fix on GAIL Torch when using actions (#4407)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
12e15e29
共有 2 个文件被更改,包括 8 次插入、4 次删除
  1. 2
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
  2. 10
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

2
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py


buffer["vector_obs"].append(curr_split_obs.vector_observations)
buffer["next_vector_in"].append(next_split_obs.vector_observations)
buffer["actions"].append(action)
buffer["done"].append(np.zeros(1, dtype=np.float32))
buffer["done"] = np.zeros(number, dtype=np.float32)
return buffer

10
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


encoder_input = self.get_state_encoding(mini_batch)
if self._settings.use_actions:
actions = self.get_action_input(mini_batch)
dones = torch.as_tensor(mini_batch["done"], dtype=torch.float)
dones = torch.as_tensor(mini_batch["done"], dtype=torch.float).unsqueeze(1)
encoder_input = torch.cat([encoder_input, actions, dones], dim=1)
hidden = self.encoder(encoder_input)
z_mu: Optional[torch.Tensor] = None

policy_action = self.get_action_input(policy_batch)
expert_action = self.get_action_input(policy_batch)
action_epsilon = torch.rand(policy_action.shape)
policy_dones = torch.as_tensor(policy_batch["done"], dtype=torch.float)
expert_dones = torch.as_tensor(expert_batch["done"], dtype=torch.float)
policy_dones = torch.as_tensor(
policy_batch["done"], dtype=torch.float
).unsqueeze(1)
expert_dones = torch.as_tensor(
expert_batch["done"], dtype=torch.float
).unsqueeze(1)
dones_epsilon = torch.rand(policy_dones.shape)
encoder_input = torch.cat(
[

正在加载...
取消
保存