浏览代码

fixed tests/ -> single validate_action func

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
1bbe492c
共有 3 个文件被更改,包括 13 次插入22 次删除
  1. 12
      ml-agents-envs/mlagents_envs/base_env.py
  2. 17
      ml-agents-envs/mlagents_envs/environment.py
  3. 6
      ml-agents/mlagents/trainers/policy/policy.py

12
ml-agents-envs/mlagents_envs/base_env.py


)
return action
def validate_action_shape(
def validate_action(
) -> None:
) -> np.ndarray:
for the correct number of agents.
for the correct number of agents and ensures the type.
"""
if self.continuous_size > 0:
_size = self.continuous_size

f"{_expected_shape} for (<number of agents>, <action size>) but "
f"received input of dimension {actions.shape}"
)
def validate_action_type(self, actions: np.ndarray) -> np.ndarray:
"""
Checks action has the correct expected type and if not
casts it to the correct type.
"""
_expected_type = np.float32 if self.is_continuous() else np.int32
if actions.dtype != _expected_type:
actions = actions.astype(_expected_type)

17
ml-agents-envs/mlagents_envs/environment.py


if behavior_name not in self._env_state:
return
action_spec = self._env_specs[behavior_name].action_spec
action_spec.validate_action_shape(
action, len(self._env_state[behavior_name][0]), behavior_name
)
action = action_spec.validate_action_type(action)
num_agents = len(self._env_state[behavior_name][0])
action = action_spec.validate_action(action, num_agents, behavior_name)
self._env_actions[behavior_name] = action
def set_action_for_agent(

if behavior_name not in self._env_state:
return
action_spec = self._env_specs[behavior_name].action_spec
action_spec.validate_action_shape(
action, len(self._env_state[behavior_name][0]), behavior_name
)
action = action_spec.validate_action_type(action)
num_agents = len(self._env_state[behavior_name][0])
action = action_spec.validate_action(action, num_agents, behavior_name)
self._env_actions[behavior_name] = action_spec.create_empty(
len(self._env_state[behavior_name][0])
)
self._env_actions[behavior_name] = action_spec.create_empty(num_agents)
try:
index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][
0

6
ml-agents/mlagents/trainers/policy/policy.py


1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
self.num_branches = self.behavior_spec.action_spec.discrete_size
# This line will be removed in the ActionBuffer change
self.num_branches = (
self.behavior_spec.action_spec.continuous_size
+ self.behavior_spec.action_spec.discrete_size
)
self.previous_action_dict: Dict[str, np.array] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize

正在加载...
取消
保存