|
|
|
|
|
|
return img |
|
|
|
|
|
|
|
|
|
|
|
def _check_observations_match_spec(
    obs_index: int,
    observation_spec: ObservationSpec,
    agent_info_list: Collection[AgentInfoProto],
) -> None:
    """
    Verify that each agent's observation at ``obs_index`` has the shape
    declared by ``observation_spec``.

    Failing here produces a clearer message than the cryptic numpy error a
    shape mismatch would otherwise cause later on.

    :param obs_index: Index into each agent's ``observations`` sequence.
    :param observation_spec: Spec whose ``shape`` is the expected shape.
    :param agent_info_list: Agents whose observations are inspected.
    :raises UnityObservationException: For the first agent whose observation
        shape differs from the spec's shape.
    """
    # Normalise both sides to tuples so the comparison is by value, whatever
    # sequence type the spec or the observation uses to store its shape.
    expected_obs_shape = tuple(observation_spec.shape)
    for agent_info in agent_info_list:
        agent_obs_shape = tuple(agent_info.observations[obs_index].shape)
        if agent_obs_shape == expected_obs_shape:
            continue
        # Stop at the first offender; its id and both shapes go in the message.
        raise UnityObservationException(
            f"Observation at index={obs_index} for agent with "
            f"id={agent_info.id} didn't match the ObservationSpec. "
            f"Expected shape {expected_obs_shape} but got {agent_obs_shape}."
        )
|
|
|
|
|
|
|
|
|
|
|
@timed |
|
|
|
def _observation_to_np_array( |
|
|
|
obs: ObservationProto, expected_shape: Optional[Iterable[int]] = None |
|
|
|
|
|
|
@timed |
|
|
|
def _process_maybe_compressed_observation( |
|
|
|
obs_index: int, |
|
|
|
shape: Tuple[int, int, int], |
|
|
|
observation_spec: ObservationSpec, |
|
|
|
shape = cast(Tuple[int, int, int], observation_spec.shape) |
|
|
|
batched_visual = [ |
|
|
|
_observation_to_np_array(agent_obs.observations[obs_index], shape) |
|
|
|
for agent_obs in agent_info_list |
|
|
|
] |
|
|
|
try: |
|
|
|
batched_visual = [ |
|
|
|
_observation_to_np_array(agent_obs.observations[obs_index], shape) |
|
|
|
for agent_obs in agent_info_list |
|
|
|
] |
|
|
|
except ValueError: |
|
|
|
# Try to get a more useful error message |
|
|
|
_check_observations_match_spec(obs_index, observation_spec, agent_info_list) |
|
|
|
# If that didn't raise anything, raise the original error |
|
|
|
raise |
|
|
|
return np.array(batched_visual, dtype=np.float32) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@timed |
|
|
|
def _process_rank_one_or_two_observation( |
|
|
|
obs_index: int, shape: Tuple[int, ...], agent_info_list: Collection[AgentInfoProto] |
|
|
|
obs_index: int, |
|
|
|
observation_spec: ObservationSpec, |
|
|
|
agent_info_list: Collection[AgentInfoProto], |
|
|
|
return np.zeros((0,) + shape, dtype=np.float32) |
|
|
|
np_obs = np.array( |
|
|
|
[ |
|
|
|
agent_obs.observations[obs_index].float_data.data |
|
|
|
for agent_obs in agent_info_list |
|
|
|
], |
|
|
|
dtype=np.float32, |
|
|
|
).reshape((len(agent_info_list),) + shape) |
|
|
|
return np.zeros((0,) + observation_spec.shape, dtype=np.float32) |
|
|
|
try: |
|
|
|
np_obs = np.array( |
|
|
|
[ |
|
|
|
agent_obs.observations[obs_index].float_data.data |
|
|
|
for agent_obs in agent_info_list |
|
|
|
], |
|
|
|
dtype=np.float32, |
|
|
|
).reshape((len(agent_info_list),) + observation_spec.shape) |
|
|
|
except ValueError: |
|
|
|
# Try to get a more useful error message |
|
|
|
_check_observations_match_spec(obs_index, observation_spec, agent_info_list) |
|
|
|
# If that didn't raise anything, raise the original error |
|
|
|
raise |
|
|
|
_raise_on_nan_and_inf(np_obs, "observations") |
|
|
|
return np_obs |
|
|
|
|
|
|
|
|
|
|
] |
|
|
|
decision_obs_list: List[np.ndarray] = [] |
|
|
|
terminal_obs_list: List[np.ndarray] = [] |
|
|
|
for obs_index, observation_specs in enumerate(behavior_spec.observation_specs): |
|
|
|
is_visual = len(observation_specs.shape) == 3 |
|
|
|
for obs_index, observation_spec in enumerate(behavior_spec.observation_specs): |
|
|
|
is_visual = len(observation_spec.shape) == 3 |
|
|
|
obs_shape = cast(Tuple[int, int, int], observation_specs.shape) |
|
|
|
obs_index, obs_shape, decision_agent_info_list |
|
|
|
obs_index, observation_spec, decision_agent_info_list |
|
|
|
obs_index, obs_shape, terminal_agent_info_list |
|
|
|
obs_index, observation_spec, terminal_agent_info_list |
|
|
|
obs_index, observation_specs.shape, decision_agent_info_list |
|
|
|
obs_index, observation_spec, decision_agent_info_list |
|
|
|
obs_index, observation_specs.shape, terminal_agent_info_list |
|
|
|
obs_index, observation_spec, terminal_agent_info_list |
|
|
|
) |
|
|
|
) |
|
|
|
decision_rewards = np.array( |
|
|
|