|
|
|
|
|
|
from mlagents_envs.rpc_utils import ( |
|
|
|
behavior_spec_from_proto, |
|
|
|
process_pixels, |
|
|
|
_process_visual_observation, |
|
|
|
_process_vector_observation, |
|
|
|
_process_maybe_compressed_observation, |
|
|
|
_process_rank_one_or_two_observation, |
|
|
|
steps_from_proto, |
|
|
|
) |
|
|
|
from PIL import Image |
|
|
|
|
|
|
shapes = [(3,), (4,)] |
|
|
|
list_proto = generate_list_agent_proto(n_agents, shapes) |
|
|
|
for obs_index, shape in enumerate(shapes): |
|
|
|
arr = _process_vector_observation(obs_index, shape, list_proto) |
|
|
|
arr = _process_rank_one_or_two_observation(obs_index, shape, list_proto) |
|
|
|
assert list(arr.shape) == ([n_agents] + list(shape)) |
|
|
|
assert np.allclose(arr, 0.1, atol=0.01) |
|
|
|
|
|
|
|
|
|
|
ap2 = AgentInfoProto() |
|
|
|
ap2.observations.extend([proto_obs_2]) |
|
|
|
ap_list = [ap1, ap2] |
|
|
|
arr = _process_visual_observation(0, (128, 64, 3), ap_list) |
|
|
|
arr = _process_maybe_compressed_observation(0, (128, 64, 3), ap_list) |
|
|
|
assert list(arr.shape) == [2, 128, 64, 3] |
|
|
|
assert np.allclose(arr[0, :, :, :], in_array_1, atol=0.01) |
|
|
|
assert np.allclose(arr[1, :, :, :], in_array_2, atol=0.01) |
|
|
|
|
|
|
ap2 = AgentInfoProto() |
|
|
|
ap2.observations.extend([proto_obs_2]) |
|
|
|
ap_list = [ap1, ap2] |
|
|
|
arr = _process_visual_observation(0, (128, 64, 1), ap_list) |
|
|
|
arr = _process_maybe_compressed_observation(0, (128, 64, 1), ap_list) |
|
|
|
assert list(arr.shape) == [2, 128, 64, 1] |
|
|
|
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01) |
|
|
|
assert np.allclose(arr[1, :, :, :], expected_out_array_2, atol=0.01) |
|
|
|
|
|
|
ap1 = AgentInfoProto() |
|
|
|
ap1.observations.extend([proto_obs_1]) |
|
|
|
ap_list = [ap1] |
|
|
|
arr = _process_visual_observation(0, (128, 64, 8), ap_list) |
|
|
|
arr = _process_maybe_compressed_observation(0, (128, 64, 8), ap_list) |
|
|
|
assert list(arr.shape) == [1, 128, 64, 8] |
|
|
|
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01) |
|
|
|
|
|
|
|
|
|
|
ap1.observations.extend([proto_obs_1]) |
|
|
|
ap_list = [ap1] |
|
|
|
with pytest.raises(UnityObservationException): |
|
|
|
_process_visual_observation(0, (128, 42, 3), ap_list) |
|
|
|
_process_maybe_compressed_observation(0, (128, 42, 3), ap_list) |
|
|
|
|
|
|
|
|
|
|
|
def test_batched_step_result_from_proto(): |
|
|
|