浏览代码

default actions are np.array of shape (n_agents, 0)

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
4c45efcf
共有 3 个文件被更改,包括 29 次插入13 次删除
  1. 21
      ml-agents-envs/mlagents_envs/base_env.py
  2. 4
      ml-agents-envs/mlagents_envs/environment.py
  3. 17
      ml-agents/mlagents/trainers/env_manager.py

21
ml-agents-envs/mlagents_envs/base_env.py


respectively.
"""
def __init__(
self,
continuous: Optional[np.ndarray] = None,
discrete: Optional[np.ndarray] = None,
):
if continuous is not None and continuous.dtype != np.float32:
def __init__(self, continuous: np.ndarray, discrete: np.ndarray):
if continuous.dtype != np.float32:
if discrete is not None and discrete.dtype != np.int32:
if discrete.dtype != np.int32:
discrete = discrete.astype(np.int32, copy=False)
self._discrete = discrete

@property
def discrete(self) -> np.ndarray:
    """Return the array held in ``self._discrete``."""
    stored = self._discrete
    return stored
@staticmethod
def create_continuous(continuous: np.ndarray) -> "ActionTuple":
    """Build an ActionTuple from continuous actions alone.

    The discrete component is filled with an empty int32 array that has
    as many rows as ``continuous`` and zero columns.
    """
    n_rows = continuous.shape[0]
    empty_discrete = np.zeros((n_rows, 0), dtype=np.int32)
    return ActionTuple(continuous, empty_discrete)
@staticmethod
def create_discrete(discrete: np.ndarray) -> "ActionTuple":
    """Build an ActionTuple from discrete actions alone.

    The continuous component is filled with an empty float32 array that
    has as many rows as ``discrete`` and zero columns.
    """
    n_rows = discrete.shape[0]
    empty_continuous = np.zeros((n_rows, 0), dtype=np.float32)
    return ActionTuple(empty_continuous, discrete)
class ActionSpec(NamedTuple):

4
ml-agents-envs/mlagents_envs/environment.py


if n_agents == 0:
continue
for i in range(n_agents):
# TODO: extend to AgentBuffers
if vector_action[b].continuous is not None:
# TODO: This check will be removed when the proto supports hybrid actions
if vector_action[b].continuous.shape[1] > 0:
_act = vector_action[b].continuous[i]
else:
_act = vector_action[b].discrete[i]

17
ml-agents/mlagents/trainers/env_manager.py


from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.logging_util import get_logger
from mlagents_envs.exception import UnityActionException
AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]]
AllGroupSpec = Dict[BehaviorName, BehaviorSpec]

@staticmethod
def action_tuple_from_numpy_dict(action_dict: Dict[str, np.ndarray]) -> ActionTuple:
continuous: np.ndarray = None
discrete: np.ndarray = None
if "discrete_action" in action_dict:
if "discrete_action" in action_dict:
discrete = action_dict["discrete_action"]
action_tuple = ActionTuple(continuous, discrete)
else:
action_tuple = ActionTuple.create_continuous(continuous)
elif "discrete_action" in action_dict:
return ActionTuple(continuous, discrete)
action_tuple = ActionTuple.create_discrete(discrete)
else:
raise UnityActionException(
"The action dict must contain entries for either continuous_action or discrete_action."
)
return action_tuple
正在加载...
取消
保存