浏览代码

[🐛 🔨 ] set_action_for_agent expects a ActionTuple with batch size 1. (#5208)

* [Bug Fix] set_action_for_agent expects a ActionTuple with batch size 1.

* moving a line around
/check-for-ModelOverriders
GitHub 3 年前
当前提交
66eb9432
共有 3 个文件被更改,包括 60 次插入13 次删除
  1. 14
      ml-agents-envs/mlagents_envs/base_env.py
  2. 4
      ml-agents-envs/mlagents_envs/environment.py
  3. 55
      ml-agents-envs/mlagents_envs/tests/test_set_action.py

14
ml-agents-envs/mlagents_envs/base_env.py


return ActionTuple(continuous=_continuous, discrete=_discrete)
def _validate_action(
self, actions: ActionTuple, n_agents: Optional[int], name: str
self, actions: ActionTuple, n_agents: int, name: str
_expected_shape = (
(n_agents, self.continuous_size)
if n_agents is not None
else (self.continuous_size,)
)
_expected_shape = (n_agents, self.continuous_size)
if actions.continuous.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a continuous input of dimension "

_expected_shape = (
(n_agents, self.discrete_size)
if n_agents is not None
else (self.discrete_size,)
)
_expected_shape = (n_agents, self.discrete_size)
if actions.discrete.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a discrete input of dimension "

4
ml-agents-envs/mlagents_envs/environment.py


if behavior_name not in self._env_state:
return
action_spec = self._env_specs[behavior_name].action_spec
num_agents = len(self._env_state[behavior_name][0])
action = action_spec._validate_action(action, None, behavior_name)
action = action_spec._validate_action(action, 1, behavior_name)
num_agents = len(self._env_state[behavior_name][0])
self._env_actions[behavior_name] = action_spec.empty_action(num_agents)
try:
index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][

55
ml-agents-envs/mlagents_envs/tests/test_set_action.py


from mlagents_envs.registry import default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
)
from mlagents_envs.base_env import ActionTuple
import numpy as np
BALL_ID = "3DBall"
def test_set_action_single_agent():
engine_config_channel = EngineConfigurationChannel()
env = default_registry[BALL_ID].make(
base_port=6000,
worker_id=0,
no_graphics=True,
side_channels=[engine_config_channel],
)
engine_config_channel.set_configuration_parameters(time_scale=100)
for _ in range(3):
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
d, t = env.get_steps(behavior_name)
for _ in range(50):
for agent_id in d.agent_id:
action = np.ones((1, 2))
action_tuple = ActionTuple()
action_tuple.add_continuous(action)
env.set_action_for_agent(behavior_name, agent_id, action_tuple)
env.step()
d, t = env.get_steps(behavior_name)
env.close()
def test_set_action_multi_agent():
engine_config_channel = EngineConfigurationChannel()
env = default_registry[BALL_ID].make(
base_port=6001,
worker_id=0,
no_graphics=True,
side_channels=[engine_config_channel],
)
engine_config_channel.set_configuration_parameters(time_scale=100)
for _ in range(3):
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
d, t = env.get_steps(behavior_name)
for _ in range(50):
action = np.ones((len(d), 2))
action_tuple = ActionTuple()
action_tuple.add_continuous(action)
env.set_actions(behavior_name, action_tuple)
env.step()
d, t = env.get_steps(behavior_name)
env.close()
正在加载...
取消
保存