浏览代码

Change AgentAction back to 0 pad and add tests

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
1fc3640e
共有 2 个文件被更改,包括 57 次插入3 次删除
  1. 8
      ml-agents/mlagents/trainers/torch/agent_action.py
  2. 52
      ml-agents/mlagents/trainers/tests/torch/test_agent_action.py

8
ml-agents/mlagents/trainers/torch/agent_action.py


agent_buffer_field: AgentBufferField, dtype: torch.dtype = torch.float32
) -> List[torch.Tensor]:
"""
Pad actions and convert to tensor. Pad the data with NaNs where there is no
data.
Pad actions and convert to tensor. Pad the data with 0s where there is no
data. 0 is used instead of NaN because NaN is not a valid entry for integer
tensors, as used for discrete actions.
"""
action_shape = None
for _action in agent_buffer_field:

if action_shape is None:
return []
# Convert to tensor while padding with 0's
*agent_buffer_field, fillvalue=np.full(action_shape, np.nan)
*agent_buffer_field, fillvalue=np.full(action_shape, 0)
),
)
)

52
ml-agents/mlagents/trainers/tests/torch/test_agent_action.py


import numpy as np
from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch.agent_action import AgentAction
def test_agent_action_group_from_buffer():
    """Check that ``AgentAction.group_from_buffer`` pads dead group agents with 0s.

    The buffer holds five steps: three steps with three group agents, followed
    by two steps with a single group agent. Every recorded action is all ones,
    so any zero in the result can only come from padding.
    """
    buff = AgentBuffer()
    # Create some actions: three steps where all three group agents are alive
    for _ in range(3):
        buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
            3 * [np.ones((5,), dtype=np.float32)]
        )
        buff[BufferKey.GROUP_DISCRETE_ACTION].append(
            3 * [np.ones((4,), dtype=np.float32)]
        )
    # Some agents have died: two more steps with only one group agent left
    for _ in range(2):
        buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
            1 * [np.ones((5,), dtype=np.float32)]
        )
        buff[BufferKey.GROUP_DISCRETE_ACTION].append(
            1 * [np.ones((4,), dtype=np.float32)]
        )

    # Get the group actions, indexed per group agent. Each entry is an
    # AgentAction that is the same length as the AgentBuffer but contains only
    # one agent's actions. Steps where that agent is dead are padded with 0s
    # (not NaNs, which are not valid entries for the integer tensors used for
    # discrete actions).
    gact = AgentAction.group_from_buffer(buff)

    # Agent 0 is full: it is present at every step, so nothing is padded
    agent_0_act = gact[0]
    assert agent_0_act.continuous_tensor.shape == (buff.num_experiences, 5)
    assert agent_0_act.discrete_tensor.shape == (buff.num_experiences, 4)
    assert (agent_0_act.continuous_tensor > 0).all()
    assert (agent_0_act.discrete_tensor > 0).all()

    # Agent 1 only exists for the first three steps; the last two are padding
    agent_1_act = gact[1]
    assert agent_1_act.continuous_tensor.shape == (buff.num_experiences, 5)
    assert agent_1_act.discrete_tensor.shape == (buff.num_experiences, 4)
    assert (agent_1_act.continuous_tensor[0:3] > 0).all()
    assert (agent_1_act.continuous_tensor[3:] == 0).all()
    assert (agent_1_act.discrete_tensor[0:3] > 0).all()
    assert (agent_1_act.discrete_tensor[3:] == 0).all()
def test_to_flat():
    """Check ``AgentAction.to_flat`` against a hand-written expected row.

    One continuous action of size 3 plus two discrete branches (chosen indices
    2 and 1, flattened with branch sizes ``[3, 3]``) is expected to yield the
    continuous values followed by ``0, 0, 1`` and ``0, 1, 0``.
    """
    continuous = torch.tensor([[1.0, 1.0, 1.0]])
    discrete_branches = [torch.tensor([2]), torch.tensor([1])]
    aa = AgentAction(continuous, discrete_branches)

    expected = torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 0]])
    flattened_actions = aa.to_flat([3, 3])

    matches = torch.eq(flattened_actions, expected)
    assert matches.all()
正在加载...
取消
保存