浏览代码

[Bug fix] Gym last reward before Done (#3471)

* Fixing #3460

* Addressing comments

* Added 2 tests

* encapsulate the agent mapping operations (#3481)

* encapsulate the agent mapping operations

* rename, linear time impl

* cleanup

* dict.popitem

* udpate comments

* Update gym-unity/gym_unity/tests/test_gym.py

Co-authored-by: Chris Elion <celion@gmail.com>
/asymm-envs
GitHub 4 年前
当前提交
a8c0564b
共有 2 个文件被更改,包括 198 次插入24 次删除
  1. 145
      gym-unity/gym_unity/envs/__init__.py
  2. 77
      gym-unity/gym_unity/tests/test_gym.py

145
gym-unity/gym_unity/envs/__init__.py


import logging
import itertools
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union, Set
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
from gym import error, spaces

self.visual_obs = None
self._n_agents = -1
self._done_agents: Set[int] = set()
self.agent_mapper = AgentIdIndexMapper()
# Save the step result from the last time all Agents requested decisions.
self._previous_step_result: BatchedStepResult = None
self._multiagent = multiagent

step_result = self._env.get_step_result(self.brain_name)
self._check_agents(step_result.n_agents())
self._previous_step_result = step_result
self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))
# Set observation and action spaces
if self.group_spec.is_action_discrete():

"The number of agents in the scene does not match the expected number."
)
# remove the done Agents
indices_to_keep: List[int] = []
for index, is_done in enumerate(step_result.done):
if not is_done:
indices_to_keep.append(index)
if step_result.n_agents() - sum(step_result.done) != self._n_agents:
raise UnityGymException(
"The number of agents in the scene does not match the expected number."
)
for index, agent_id in enumerate(step_result.agent_id):
if step_result.done[index]:
self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])
# Set the new AgentDone flags to True
# Note that the corresponding agent_id that gets marked done will be different

if not self._previous_step_result.contains_agent(agent_id):
step_result.done[index] = True
if agent_id in self._done_agents:
# Register this agent, and get the reward of the previous agent that
# was in its index, so that we can return it to the gym.
last_reward = self.agent_mapper.register_new_agent_id(agent_id)
self._done_agents = set()
step_result.reward[index] = last_reward
# Get a permutation of the agent IDs so that a given ID stays in the same
# index as where it was first seen.
new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))
_mask.append(step_result.action_mask[mask_index][indices_to_keep])
_mask.append(step_result.action_mask[mask_index][new_id_order])
new_obs.append(step_result.obs[obs_index][indices_to_keep])
new_obs.append(step_result.obs[obs_index][new_id_order])
reward=step_result.reward[indices_to_keep],
done=step_result.done[indices_to_keep],
max_step=step_result.max_step[indices_to_keep],
agent_id=step_result.agent_id[indices_to_keep],
reward=step_result.reward[new_id_order],
done=step_result.done[new_id_order],
max_step=step_result.max_step[new_id_order],
agent_id=step_result.agent_id[new_id_order],
if self._previous_step_result.n_agents() == self._n_agents:
return action
input_index = 0
for index in range(self._previous_step_result.n_agents()):
for index, agent_id in enumerate(self._previous_step_result.agent_id):
sanitized_action[index, :] = action[input_index, :]
input_index = input_index + 1
array_index = self.agent_mapper.get_gym_index(agent_id)
sanitized_action[index, :] = action[array_index, :]
return sanitized_action
def _step(self, needs_reset: bool = False) -> BatchedStepResult:

"The environment does not have the expected amount of agents."
+ "Some agents did not request decisions at the same time."
)
self._done_agents.update(list(info.agent_id))
for agent_id, reward in zip(info.agent_id, info.reward):
self.agent_mapper.mark_agent_done(agent_id, reward)
self._env.step()
info = self._env.get_step_result(self.brain_name)
return self._sanitize_info(info)

:return: The List containing the branched actions.
"""
return self.action_lookup[action]
class AgentIdIndexMapper:
def __init__(self) -> None:
self._agent_id_to_gym_index: Dict[int, int] = {}
self._done_agents_index_to_last_reward: Dict[int, float] = {}
def set_initial_agents(self, agent_ids: List[int]) -> None:
"""
Provide the initial list of agent ids for the mapper
"""
for idx, agent_id in enumerate(agent_ids):
self._agent_id_to_gym_index[agent_id] = idx
def mark_agent_done(self, agent_id: int, reward: float) -> None:
"""
Declare the agent done with the corresponding final reward.
"""
gym_index = self._agent_id_to_gym_index.pop(agent_id)
self._done_agents_index_to_last_reward[gym_index] = reward
def register_new_agent_id(self, agent_id: int) -> float:
"""
Adds the new agent ID and returns the reward to use for the previous agent in this index
"""
# Any free index is OK here.
free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
self._agent_id_to_gym_index[agent_id] = free_index
return last_reward
def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
"""
Get the permutation from new agent ids to the order that preserves the positions of previous agents.
The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
"""
# Map the new agent ids to the their index
new_agent_ids_to_index = {
agent_id: idx for idx, agent_id in enumerate(agent_ids)
}
# Make the output list. We don't write to it sequentially, so start with dummy values.
new_permutation = [-1] * len(agent_ids)
# For each agent ID, find the new index of the agent, and write it in the original index.
for agent_id, original_index in self._agent_id_to_gym_index.items():
new_permutation[original_index] = new_agent_ids_to_index[agent_id]
return new_permutation
def get_gym_index(self, agent_id: int) -> int:
"""
Get the gym index for the current agent.
"""
return self._agent_id_to_gym_index[agent_id]
class AgentIdIndexMapperSlow:
"""
Reference implementation of AgentIdIndexMapper.
The operations are O(N^2) so it shouldn't be used for large numbers of agents.
See AgentIdIndexMapper for method descriptions
"""
def __init__(self) -> None:
self._gym_id_order: List[int] = []
self._done_agents_index_to_last_reward: Dict[int, float] = {}
def set_initial_agents(self, agent_ids: List[int]) -> None:
self._gym_id_order = list(agent_ids)
def mark_agent_done(self, agent_id: int, reward: float) -> None:
gym_index = self._gym_id_order.index(agent_id)
self._done_agents_index_to_last_reward[gym_index] = reward
self._gym_id_order[gym_index] = -1
def register_new_agent_id(self, agent_id: int) -> float:
original_index = self._gym_id_order.index(-1)
self._gym_id_order[original_index] = agent_id
reward = self._done_agents_index_to_last_reward.pop(original_index)
return reward
def get_id_permutation(self, agent_ids):
new_id_order = []
for agent_id in self._gym_id_order:
new_id_order.append(agent_ids.index(agent_id))
return new_id_order
def get_gym_index(self, agent_id: int) -> int:
return self._gym_id_order.index(agent_id)

77
gym-unity/gym_unity/tests/test_gym.py


import numpy as np
from gym import spaces
from gym_unity.envs import UnityEnv, UnityGymException
from gym_unity.envs import (
UnityEnv,
UnityGymException,
AgentIdIndexMapper,
AgentIdIndexMapperSlow,
)
from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult

assert isinstance(info, dict)
@mock.patch("gym_unity.envs.UnityEnvironment")
def test_sanitize_action_shuffled_id(mock_env):
mock_spec = create_mock_group_spec(
vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
)
mock_step = create_mock_vector_step_result(num_agents=5)
mock_step.agent_id = np.array(range(5))
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
env = UnityEnv(" ", use_visual=False, multiagent=True)
shuffled_step_result = create_mock_vector_step_result(num_agents=5)
shuffled_order = [4, 2, 3, 1, 0]
shuffled_step_result.reward = np.array(shuffled_order)
shuffled_step_result.agent_id = np.array(shuffled_order)
sanitized_result = env._sanitize_info(shuffled_step_result)
for expected_reward, reward in zip(range(5), sanitized_result.reward):
assert expected_reward == reward
for expected_agent_id, agent_id in zip(range(5), sanitized_result.agent_id):
assert expected_agent_id == agent_id
@mock.patch("gym_unity.envs.UnityEnvironment")
def test_sanitize_action_one_agent_done(mock_env):
mock_spec = create_mock_group_spec(
vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
)
mock_step = create_mock_vector_step_result(num_agents=5)
mock_step.agent_id = np.array(range(5))
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
env = UnityEnv(" ", use_visual=False, multiagent=True)
received_step_result = create_mock_vector_step_result(num_agents=6)
received_step_result.agent_id = np.array(range(6))
# agent #3 (id = 2) is Done
received_step_result.done = np.array([False] * 2 + [True] + [False] * 3)
sanitized_result = env._sanitize_info(received_step_result)
for expected_agent_id, agent_id in zip([0, 1, 5, 3, 4], sanitized_result.agent_id):
assert expected_agent_id == agent_id
# Helper methods

mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
mock_env.return_value.get_agent_group_spec.return_value = mock_spec
mock_env.return_value.get_step_result.return_value = mock_result
@pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow])
def test_agent_id_index_mapper(mapper_cls):
mapper = mapper_cls()
initial_agent_ids = [1001, 1002, 1003, 1004]
mapper.set_initial_agents(initial_agent_ids)
# Mark some agents as done with their last rewards.
mapper.mark_agent_done(1001, 42.0)
mapper.mark_agent_done(1004, 1337.0)
# Now add new agents, and get the rewards of the agent they replaced.
old_reward1 = mapper.register_new_agent_id(2001)
old_reward2 = mapper.register_new_agent_id(2002)
# Order of the rewards don't matter
assert {old_reward1, old_reward2} == {42.0, 1337.0}
new_agent_ids = [1002, 1003, 2001, 2002]
permutation = mapper.get_id_permutation(new_agent_ids)
# Make sure it's actually a permutation - needs to contain 0..N-1 with no repeats.
assert set(permutation) == set(range(0, 4))
# For initial agents that were in the initial group, they need to be in the same slot.
# Agents that were added later can appear in any free slot.
permuted_ids = [new_agent_ids[i] for i in permutation]
for idx, agent_id in enumerate(initial_agent_ids):
if agent_id in permuted_ids:
assert permuted_ids[idx] == agent_id
正在加载...
取消
保存