浏览代码

fixed bug in discrete

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
dee6b805
共有 2 个文件被更改,包括 10 次插入7 次删除
  1. 2
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
  2. 15
      ml-agents/mlagents/trainers/torch/utils.py

2
ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py


check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("use_discrete", [True])
@pytest.mark.parametrize("use_discrete", [True, False])
def test_2d_sac(use_discrete):
env = SimpleEnvironment(
[BRAIN_NAME], use_discrete=use_discrete, action_size=2, step_size=0.8

15
ml-agents/mlagents/trainers/torch/utils.py


@property
def discrete_tensor(self):
return torch.cat([_disc.unsqueeze(-1) for _disc in self.discrete_list], dim=1)
return torch.stack(self.discrete_list, dim=-1)
def to_numpy_dict(self) -> Dict[str, np.ndarray]:
array_dict: Dict[str, np.ndarray] = {}

if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
tensor_list += (
self.discrete_list
) # Note this is different for ActionLogProbs
return tensor_list
@staticmethod

@property
def discrete_tensor(self):
return torch.stack(self.discrete_list, dim=-1)
return torch.cat([_disc.unsqueeze(-1) for _disc in self.discrete_list], dim=1)
@property

self.continuous_tensor
)
if self.discrete_list is not None:
array_dict["discrete_log_probs"] = ModelUtils.to_numpy(self.discrete_tensor)
return array_dict

tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
tensor_list.append(
self.discrete_tensor
) # Note this is different for AgentActions
return tensor_list
def flatten(self) -> torch.Tensor:

entropies = torch.stack(entropies_list, dim=-1)
if not all_probs_list:
entropies = entropies.squeeze(-1)
# all_probs = None
# else:
# all_probs = torch.cat(all_probs_list, dim=-1)
return log_probs_list, entropies, all_probs_list
@staticmethod

正在加载...
取消
保存