浏览代码

remove ActionType

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
dc89318d
共有 3 个文件被更改,包括 7 次插入和 19 次删除
  1. 6
      ml-agents-envs/mlagents_envs/base_env.py
  2. 4
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  3. 16
      ml-agents/mlagents/trainers/torch/networks.py

6
ml-agents-envs/mlagents_envs/base_env.py


Mapping as MappingType,
)
import numpy as np
from enum import Enum
AgentId = int
BehaviorName = str

interrupted=np.zeros(0, dtype=np.bool),
agent_id=np.zeros(0, dtype=np.int32),
)
class ActionType(Enum):
    """Discriminator for the kind of action space a behavior uses.

    DISCRETE marks branched integer action spaces; CONTINUOUS marks
    real-valued action vectors.

    NOTE(review): the commit shown here ("remove ActionType") deletes this
    enum — call sites are rewritten to query ActionSpec.is_continuous()
    directly instead of comparing against these members.
    """

    DISCRETE = 0
    CONTINUOUS = 1
class ActionSpec(NamedTuple):

4
ml-agents/mlagents/trainers/sac/optimizer_torch.py


from mlagents.torch_utils import torch, nn, default_device
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType, ActionSpec
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import NetworkSettings

super().__init__()
self.action_spec = action_spec
if self.action_spec.is_continuous():
self.act_type = ActionType.CONTINUOUS
self.act_type = ActionType.DISCRETE
self.act_size = self.action_spec.discrete_branches
num_value_outs = sum(self.act_size)
num_action_ins = 0

16
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.torch_utils import torch, nn
from mlagents_envs.base_env import ActionType, ActionSpec
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,

):
super().__init__()
self.action_spec = action_spec
if self.action_spec.is_continuous():
self.act_type = ActionType.CONTINUOUS
else:
self.act_type = ActionType.DISCRETE
torch.Tensor([int(self.act_type == ActionType.CONTINUOUS)])
torch.Tensor([int(self.action_spec.is_continuous())])
)
self.act_size_vector = torch.nn.Parameter(
torch.Tensor([self.action_spec.total_size]), requires_grad=False

else:
self.encoding_size = network_settings.hidden_units
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
self.distribution = GaussianDistribution(
self.encoding_size,
self.action_spec.continuous_size,

encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
dists = self.distribution(encoding)
else:
dists = self.distribution(encoding, masks)

Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
action_list = self.sample_action(dists)
action_out = torch.stack(action_list, dim=-1)
else:

encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
dists = self.distribution(encoding)
else:
dists = self.distribution(encoding, masks=masks)

正在加载...
取消
保存