|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mlagents.trainers.trajectory import GroupObsUtil |
|
|
|
from mlagents.trainers.buffer import BufferKey, ObservationKeyPrefix |
|
|
|
from mlagents.trainers.buffer import AgentBuffer, BufferKey, ObservationKeyPrefix |
|
|
|
|
|
|
|
VEC_OBS_SIZE = 6 |
|
|
|
ACTION_SIZE = 4 |
|
|
|
|
|
|
for _key in wanted_group_keys: |
|
|
|
for step in agentbuffer[_key]: |
|
|
|
assert len(step) == 4 |
|
|
|
|
|
|
|
|
|
|
|
def test_obsutil_group_from_buffer(): |
|
|
|
buff = AgentBuffer() |
|
|
|
# Create some obs |
|
|
|
for _ in range(3): |
|
|
|
buff[GroupObsUtil.get_name_at(0)].append(3 * [np.ones((5,), dtype=np.float32)]) |
|
|
|
# Some agents have died |
|
|
|
for _ in range(2): |
|
|
|
buff[GroupObsUtil.get_name_at(0)].append(1 * [np.ones((5,), dtype=np.float32)]) |
|
|
|
|
|
|
|
# Get the group obs, which will be a List of Lists of np.ndarray, where each element is the same |
|
|
|
# length as the AgentBuffer but contains only one agent's obs. Dead agents are padded by |
|
|
|
# NaNs. |
|
|
|
gobs = GroupObsUtil.from_buffer(buff, 1) |
|
|
|
# Agent 0 is full |
|
|
|
agent_0_obs = gobs[0] |
|
|
|
for obs in agent_0_obs: |
|
|
|
assert obs.shape == (buff.num_experiences, 5) |
|
|
|
assert not np.isnan(obs).any() |
|
|
|
|
|
|
|
agent_1_obs = gobs[1] |
|
|
|
for obs in agent_1_obs: |
|
|
|
assert obs.shape == (buff.num_experiences, 5) |
|
|
|
for i, _exp_obs in enumerate(obs): |
|
|
|
if i >= 3: |
|
|
|
assert np.isnan(_exp_obs).all() |
|
|
|
else: |
|
|
|
assert not np.isnan(_exp_obs).any() |