|
|
|
|
|
|
self._assert_behavior_exists(behavior_name) |
|
|
|
if behavior_name not in self._env_state: |
|
|
|
return |
|
|
|
spec = self._env_specs[behavior_name] |
|
|
|
expected_type = np.float32 if spec.action_spec.is_continuous() else np.int32 |
|
|
|
expected_shape = (len(self._env_state[behavior_name][0]), spec.action_spec.size) |
|
|
|
if action.shape != expected_shape: |
|
|
|
raise UnityActionException( |
|
|
|
f"The behavior {behavior_name} needs an input of dimension " |
|
|
|
f"{expected_shape} for (<number of agents>, <action size>) but " |
|
|
|
f"received input of dimension {action.shape}" |
|
|
|
) |
|
|
|
if action.dtype != expected_type: |
|
|
|
action = action.astype(expected_type) |
|
|
|
action_spec = self._env_specs[behavior_name].action_spec |
|
|
|
action_spec.validate_action_shape( |
|
|
|
action, len(self._env_state[behavior_name][0]), behavior_name |
|
|
|
) |
|
|
|
action = action_spec.validate_action_type(action) |
|
|
|
self._env_actions[behavior_name] = action |
|
|
|
|
|
|
|
def set_action_for_agent( |
|
|
|
|
|
|
if behavior_name not in self._env_state: |
|
|
|
return |
|
|
|
spec = self._env_specs[behavior_name] |
|
|
|
expected_shape = (spec.action_spec.size,) |
|
|
|
if action.shape != expected_shape: |
|
|
|
raise UnityActionException( |
|
|
|
f"The Agent {agent_id} with BehaviorName {behavior_name} needs " |
|
|
|
f"an input of dimension {expected_shape} but received input of " |
|
|
|
f"dimension {action.shape}" |
|
|
|
) |
|
|
|
expected_type = np.float32 if spec.action_spec.is_continuous() else np.int32 |
|
|
|
if action.dtype != expected_type: |
|
|
|
action = action.astype(expected_type) |
|
|
|
action_spec = self._env_specs[behavior_name].action_spec |
|
|
|
action_spec.validate_action_shape( |
|
|
|
action, len(self._env_state[behavior_name][0]), behavior_name |
|
|
|
) |
|
|
|
action = action_spec.validate_action_type(action) |
|
|
|
self._env_actions[behavior_name] = spec.action_spec.create_empty( |
|
|
|
self._env_actions[behavior_name] = action_spec.create_empty( |
|
|
|
len(self._env_state[behavior_name][0]) |
|
|
|
) |
|
|
|
try: |
|
|
|