浏览代码

Add test for GroupObs

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
12cef7af
共有 2 个文件被更改,包括 36 次插入4 次删除
  1. 34
      ml-agents/mlagents/trainers/tests/test_trajectory.py
  2. 6
      ml-agents/mlagents/trainers/torch/agent_action.py

34
ml-agents/mlagents/trainers/tests/test_trajectory.py


import numpy as np
from mlagents.trainers.trajectory import GroupObsUtil
from mlagents.trainers.buffer import BufferKey, ObservationKeyPrefix
from mlagents.trainers.buffer import AgentBuffer, BufferKey, ObservationKeyPrefix
VEC_OBS_SIZE = 6
ACTION_SIZE = 4

for _key in wanted_group_keys:
for step in agentbuffer[_key]:
assert len(step) == 4
def test_obsutil_group_from_buffer():
    """
    Check that GroupObsUtil.from_buffer splits buffered group observations
    into per-agent observation lists, with NaN padding for the steps at which
    an agent is no longer present.
    """
    steps_with_all_agents = 3
    steps_with_one_agent = 2

    buff = AgentBuffer()
    # First, steps in which three groupmates each contribute an observation.
    for _ in range(steps_with_all_agents):
        full_step = 3 * [np.ones((5,), dtype=np.float32)]
        buff[GroupObsUtil.get_name_at(0)].append(full_step)
    # Then, steps in which only one groupmate is left (the others have died).
    for _ in range(steps_with_one_agent):
        reduced_step = 1 * [np.ones((5,), dtype=np.float32)]
        buff[GroupObsUtil.get_name_at(0)].append(reduced_step)

    # The result is a List of Lists of np.ndarray: one entry per agent, each
    # array as long as the AgentBuffer. Missing (dead) agents are NaN-padded.
    group_obs = GroupObsUtil.from_buffer(buff, 1)
    expected_shape = (buff.num_experiences, 5)

    # Agent 0 is present at every step, so none of its data may be NaN.
    for obs in group_obs[0]:
        assert obs.shape == expected_shape
        assert not np.isnan(obs).any()

    # Agent 1 only exists for the first steps; later rows must be all NaN.
    for obs in group_obs[1]:
        assert obs.shape == expected_shape
        for step_idx, step_obs in enumerate(obs):
            nan_mask = np.isnan(step_obs)
            if step_idx < steps_with_all_agents:
                assert not nan_mask.any()
            else:
                assert nan_mask.all()

6
ml-agents/mlagents/trainers/torch/agent_action.py


agent_buffer_field: AgentBufferField, dtype: torch.dtype = torch.float32
) -> List[torch.Tensor]:
"""
Pad actions and convert to tensor. Note that data is padded by 0's, not NaNs
as the observations are.
Pad actions and convert to tensor. Pad the data with NaNs where there is no
data.
"""
action_shape = None
for _action in agent_buffer_field:

map(
lambda x: ModelUtils.list_to_tensor(x, dtype=dtype),
itertools.zip_longest(
*agent_buffer_field, fillvalue=np.full(action_shape, 0)
*agent_buffer_field, fillvalue=np.full(action_shape, np.nan)
),
)
)

正在加载...
取消
保存