Python Dataflow for Group Manager (#4926)
Squashed commit messages:

* Make buffer type-agnostic
* Edit types of Append method
* Change comment
* Collaborative walljump
* Make collab env harder
* Add group ID
* Add collab obs to trajectory
* Fix bug; add critic_obs to buffer
* Set group ids for some envs
* Pretty broken
* Less broken PPO
* Update SAC, fix PPO batching
* Fix SAC interrupted condition and typing
* Fix SAC interrupted again
* Remove erroneous file
* Fix multiple obs
* Update curiosity reward provider
* Update GAIL and BC
* Multi-input network
* Some minor tweaks but still broken
* Get next critic observations into value estimate
* Temporarily disable exporting
* Use Vince's ONNX export code
* Cleanup
* Add walljump collab YAML
* Lower max height
* Update prefab
* Update prefab
* Collaborative Hallway
* Set num teammates to 2
* Add config and group ids to HallwayCollab
* Fix bug with hallway collab
* E.../develop/gail-srl-hack
Committed by GitHub 4 years ago. Current commit: d36a5242
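At a high level, this PR makes the trainer-side dataflow group-aware: agent_processor.py and trajectory.py (the two largest diffs below) route each agent's experience together with its teammates' observations and actions, keyed by group ID. A minimal sketch of that bucketing idea follows; the names Experience, GroupStep, and route_by_group are entirely hypothetical, not ml-agents classes.

# Hypothetical sketch of the group dataflow this PR introduces; none of
# these names are actual ml-agents APIs.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Experience:
    agent_id: str
    group_id: int
    obs: List[float]


@dataclass
class GroupStep:
    # All experiences produced by one group at one timestep.
    members: Dict[str, Experience] = field(default_factory=dict)


def route_by_group(experiences: List[Experience]) -> Dict[int, GroupStep]:
    # Bucket incoming per-agent experiences by their group id, so each
    # agent's trajectory can also carry its teammates' data.
    steps: Dict[int, GroupStep] = defaultdict(GroupStep)
    for exp in experiences:
        steps[exp.group_id].members[exp.agent_id] = exp
    return steps


# Agent "a0" can then attach its teammates' observations to its own step:
steps = route_by_group(
    [Experience("a0", 1, [0.1]), Experience("a1", 1, [0.2]), Experience("b0", 2, [0.3])]
)
teammate_obs_for_a0 = [e.obs for aid, e in steps[1].members.items() if aid != "a0"]
assert teammate_obs_for_a0 == [[0.2]]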
14 files changed, with 832 insertions and 141 deletions:
- ml-agents-envs/mlagents_envs/base_env.py (3 lines changed)
- ml-agents/mlagents/trainers/agent_processor.py (211 lines changed)
- ml-agents/mlagents/trainers/behavior_id_utils.py (15 lines changed)
- ml-agents/mlagents/trainers/buffer.py (106 lines changed)
- ml-agents/mlagents/trainers/policy/policy.py (15 lines changed)
- ml-agents/mlagents/trainers/ppo/trainer.py (2 lines changed)
- ml-agents/mlagents/trainers/tests/mock_brain.py (20 lines changed)
- ml-agents/mlagents/trainers/tests/test_agent_processor.py (136 lines changed)
- ml-agents/mlagents/trainers/tests/test_buffer.py (80 lines changed)
- ml-agents/mlagents/trainers/tests/test_trajectory.py (54 lines changed)
- ml-agents/mlagents/trainers/torch/agent_action.py (86 lines changed)
- ml-agents/mlagents/trainers/torch/utils.py (12 lines changed)
- ml-agents/mlagents/trainers/trajectory.py (170 lines changed)
- ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (63 lines changed)
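The centerpiece of the new tests is test_agent_action.py, which exercises AgentAction.group_from_buffer and AgentAction.to_flat. The padding behavior the first test relies on can be shown in isolation. The sketch below is hypothetical (pad_group_actions is not an ml-agents API); it assumes zero-padding, which is what the test's asserts imply.

# Hypothetical illustration of the padding that group_from_buffer performs;
# pad_group_actions is a standalone sketch, not an ml-agents API.
from typing import List
import numpy as np


def pad_group_actions(steps: List[List[np.ndarray]], num_agents: int) -> List[np.ndarray]:
    # Stack each agent's actions across steps into (num_experiences, action_size),
    # substituting zeros on steps where that agent is absent (e.g. it has died).
    action_size = steps[0][0].shape[0]
    padded = []
    for agent_idx in range(num_agents):
        rows = [
            step[agent_idx]
            if agent_idx < len(step)
            else np.zeros(action_size, dtype=np.float32)
            for step in steps
        ]
        padded.append(np.stack(rows))
    return padded


# Mirrors the buffer built in the test below: 3 steps with all 3 agents alive,
# then 2 steps where only agent 0 remains.
steps = [3 * [np.ones(5, dtype=np.float32)] for _ in range(3)]
steps += [1 * [np.ones(5, dtype=np.float32)] for _ in range(2)]
per_agent = pad_group_actions(steps, num_agents=3)
assert per_agent[0].shape == (5, 5)      # agent 0: present in all 5 experiences
assert (per_agent[1][0:3] > 0).all()     # agent 1: real actions first...
assert (per_agent[1][3:] == 0).all()     # ...then zero padding after it dies

The new test file, ml-agents/mlagents/trainers/tests/torch/test_agent_action.py, in full: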
import numpy as np
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch.agent_action import AgentAction


def test_agent_action_group_from_buffer():
    buff = AgentBuffer()
    # Create some actions: 3 steps in which all 3 agents are present.
    for _ in range(3):
        buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
            3 * [np.ones((5,), dtype=np.float32)]
        )
        buff[BufferKey.GROUP_DISCRETE_ACTION].append(
            3 * [np.ones((4,), dtype=np.float32)]
        )
    # Some agents have died: 2 more steps with only one agent left.
    for _ in range(2):
        buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
            1 * [np.ones((5,), dtype=np.float32)]
        )
        buff[BufferKey.GROUP_DISCRETE_ACTION].append(
            1 * [np.ones((4,), dtype=np.float32)]
        )

    # Get the group actions, which will be a List of AgentAction, each the
    # same length as the AgentBuffer but containing only one agent's actions.
    # Steps where an agent is absent (dead) are padded with zeros, as the
    # asserts below check.
    gact = AgentAction.group_from_buffer(buff)
    # Agent 0 is present in every experience.
    agent_0_act = gact[0]
    assert agent_0_act.continuous_tensor.shape == (buff.num_experiences, 5)
    assert agent_0_act.discrete_tensor.shape == (buff.num_experiences, 4)

    # Agent 1 only exists in the first 3 experiences; the rest is padding.
    agent_1_act = gact[1]
    assert agent_1_act.continuous_tensor.shape == (buff.num_experiences, 5)
    assert agent_1_act.discrete_tensor.shape == (buff.num_experiences, 4)
    assert (agent_1_act.continuous_tensor[0:3] > 0).all()
    assert (agent_1_act.continuous_tensor[3:] == 0).all()
    assert (agent_1_act.discrete_tensor[0:3] > 0).all()
    assert (agent_1_act.discrete_tensor[3:] == 0).all()


def test_to_flat():
    # Both continuous and discrete: continuous values pass through, and each
    # discrete branch choice is one-hot encoded against its branch size.
    aa = AgentAction(
        torch.tensor([[1.0, 1.0, 1.0]]), [torch.tensor([2]), torch.tensor([1])]
    )
    flattened_actions = aa.to_flat([3, 3])
    assert torch.eq(
        flattened_actions, torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 0]])
    ).all()

    # Just continuous
    aa = AgentAction(torch.tensor([[1.0, 1.0, 1.0]]), None)
    flattened_actions = aa.to_flat([])
    assert torch.eq(flattened_actions, torch.tensor([1, 1, 1])).all()

    # Just discrete
    aa = AgentAction(torch.tensor([]), [torch.tensor([2]), torch.tensor([1])])
    flattened_actions = aa.to_flat([3, 3])
    assert torch.eq(flattened_actions, torch.tensor([0, 0, 1, 0, 1, 0])).all()
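The flattening that test_to_flat asserts, continuous values passed through and each discrete branch one-hot encoded against its branch size, can be reproduced with plain PyTorch. The following is a standalone sketch under that assumption; flatten_action is hypothetical, not the ml-agents implementation.

# Standalone sketch of the flattening checked by test_to_flat above;
# flatten_action is hypothetical, the real logic lives in AgentAction.to_flat.
from typing import List, Optional
import torch


def flatten_action(
    continuous: Optional[torch.Tensor],
    discrete: Optional[List[torch.Tensor]],
    discrete_branches: List[int],
) -> torch.Tensor:
    parts = []
    if continuous is not None and continuous.numel() > 0:
        # Continuous actions are passed through unchanged.
        parts.append(continuous.reshape(-1).float())
    if discrete is not None:
        for branch, size in zip(discrete, discrete_branches):
            # One-hot encode each discrete branch choice against its size.
            parts.append(
                torch.nn.functional.one_hot(branch.long(), size).reshape(-1).float()
            )
    return torch.cat(parts)


flat = flatten_action(
    torch.tensor([[1.0, 1.0, 1.0]]), [torch.tensor([2]), torch.tensor([1])], [3, 3]
)
assert torch.eq(flat, torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0])).all()

A fixed-width action vector like this is what modules such as the GAIL discriminator and the curiosity reward provider (both updated in the commit list above) typically consume, since they cannot ingest variable-length discrete choices directly.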