浏览代码

renaming some methods and verifying imiation learning works as expected

/bullet-hell-barracuda-test-1.3.1
vincentpierre 4 年前
当前提交
b9991637
共有 2 个文件被更改,包括 15 次插入15 次删除
  1. 16
      ml-agents-envs/mlagents_envs/rpc_utils.py
  2. 14
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py

16
ml-agents-envs/mlagents_envs/rpc_utils.py


@timed
def observation_to_np_array(
def _observation_to_np_array(
obs: ObservationProto, expected_shape: Optional[Iterable[int]] = None
) -> np.ndarray:
"""

@timed
def _process_visual_observation(
def _process_maybe_compressed_observation(
obs_index: int,
shape: Tuple[int, int, int],
agent_info_list: Collection[AgentInfoProto],

batched_visual = [
observation_to_np_array(agent_obs.observations[obs_index], shape)
_observation_to_np_array(agent_obs.observations[obs_index], shape)
for agent_obs in agent_info_list
]
return np.array(batched_visual, dtype=np.float32)

@timed
def _process_vector_observation(
def _process_rank_one_or_two_observation(
obs_index: int, shape: Tuple[int, ...], agent_info_list: Collection[AgentInfoProto]
) -> np.ndarray:
if len(agent_info_list) == 0:

if is_visual:
obs_shape = cast(Tuple[int, int, int], observation_specs.shape)
decision_obs_list.append(
_process_visual_observation(
_process_maybe_compressed_observation(
_process_visual_observation(
_process_maybe_compressed_observation(
_process_vector_observation(
_process_rank_one_or_two_observation(
_process_vector_observation(
_process_rank_one_or_two_observation(
obs_index, observation_specs.shape, terminal_agent_info_list
)
)

14
ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py


from mlagents_envs.rpc_utils import (
behavior_spec_from_proto,
process_pixels,
_process_visual_observation,
_process_vector_observation,
_process_maybe_compressed_observation,
_process_rank_one_or_two_observation,
steps_from_proto,
)
from PIL import Image

shapes = [(3,), (4,)]
list_proto = generate_list_agent_proto(n_agents, shapes)
for obs_index, shape in enumerate(shapes):
arr = _process_vector_observation(obs_index, shape, list_proto)
arr = _process_rank_one_or_two_observation(obs_index, shape, list_proto)
assert list(arr.shape) == ([n_agents] + list(shape))
assert np.allclose(arr, 0.1, atol=0.01)

ap2 = AgentInfoProto()
ap2.observations.extend([proto_obs_2])
ap_list = [ap1, ap2]
arr = _process_visual_observation(0, (128, 64, 3), ap_list)
arr = _process_maybe_compressed_observation(0, (128, 64, 3), ap_list)
assert list(arr.shape) == [2, 128, 64, 3]
assert np.allclose(arr[0, :, :, :], in_array_1, atol=0.01)
assert np.allclose(arr[1, :, :, :], in_array_2, atol=0.01)

ap2 = AgentInfoProto()
ap2.observations.extend([proto_obs_2])
ap_list = [ap1, ap2]
arr = _process_visual_observation(0, (128, 64, 1), ap_list)
arr = _process_maybe_compressed_observation(0, (128, 64, 1), ap_list)
assert list(arr.shape) == [2, 128, 64, 1]
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)
assert np.allclose(arr[1, :, :, :], expected_out_array_2, atol=0.01)

ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap_list = [ap1]
arr = _process_visual_observation(0, (128, 64, 8), ap_list)
arr = _process_maybe_compressed_observation(0, (128, 64, 8), ap_list)
assert list(arr.shape) == [1, 128, 64, 8]
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)

ap1.observations.extend([proto_obs_1])
ap_list = [ap1]
with pytest.raises(UnityObservationException):
_process_visual_observation(0, (128, 42, 3), ap_list)
_process_maybe_compressed_observation(0, (128, 42, 3), ap_list)
def test_batched_step_result_from_proto():

正在加载...
取消
保存