浏览代码

Add team reward to buffer

/develop/centralizedcritic/counterfact
Ervin Teng 4 年前
当前提交
fdf97d99
共有 2 个文件被更改,包括 23 次插入和 7 次删除
  1. 27
      ml-agents/mlagents/trainers/agent_processor.py
  2. 3
      ml-agents/mlagents/trainers/trajectory.py

27
ml-agents/mlagents/trainers/agent_processor.py


self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
self.last_step_result: Dict[str, Tuple[DecisionStep, int]] = {}
# current_group_obs is used to collect the last seen obs of all the agents in the same group,
# and assemble the next_collab_obs.
# and assemble the collab_obs.
self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
lambda: defaultdict(list)
)

lambda: defaultdict(list)
)
# current_group_rewards is used to collect the last seen rewards of all the agents in the same group.
self.current_group_rewards: Dict[str, Dict[str, float]] = defaultdict(
lambda: defaultdict(float)
)
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).

global_id
] = stored_decision_step.obs
self.current_group_obs[step.team_manager_id][global_id] = step.obs
self.current_group_rewards[step.team_manager_id][
global_id
] = step.reward
for _manager_id, _team_group in self.current_group_obs.items():
self._safe_delete(_team_group, global_id)
if not _team_group: # if dict is empty
self._safe_delete(_team_group, _manager_id)
for _manager_id, _team_group in self.last_group_obs.items():
self._safe_delete(_team_group, global_id)
self._delete_in_nested_dict(self.current_group_obs, global_id)
self._delete_in_nested_dict(self.last_group_obs, global_id)
self._delete_in_nested_dict(self.current_group_rewards, global_id)
def _delete_in_nested_dict(self, nested_dict: Dict[str, Dict[str, Any]], key: str) -> None:
    """
    Remove `key` from every inner dict of `nested_dict`, pruning any inner
    dict that becomes empty as a result.

    :param nested_dict: Mapping of manager id -> (mapping of agent id -> value),
        e.g. current_group_obs / last_group_obs / current_group_rewards.
    :param key: The inner key (agent's global id) to delete wherever it appears.
    """
    # Snapshot the items: the fixed code below deletes entries from
    # nested_dict itself, which would raise RuntimeError if we iterated
    # the live view while mutating it.
    for _manager_id, _team_group in list(nested_dict.items()):
        self._safe_delete(_team_group, key)
        if not _team_group:  # inner dict is now empty — prune it
            # BUGFIX: the original deleted _manager_id from _team_group
            # (a guaranteed no-op, since _team_group is empty) instead of
            # from nested_dict, leaking empty inner dicts indefinitely.
            self._safe_delete(nested_dict, _manager_id)

for _id, _obs in self.last_group_obs[step.team_manager_id].items():
if _id != global_id:
collab_obs.append(_obs)
teammate_rewards = []
for _id, _rew in self.current_group_rewards[step.team_manager_id].items():
if _id != global_id:
teammate_rewards.append(_rew)
team_rewards=teammate_rewards,
done=done,
action=action_tuple,
action_probs=log_probs_tuple,

3
ml-agents/mlagents/trainers/trajectory.py


obs: List[np.ndarray]
collab_obs: List[List[np.ndarray]]
reward: float
team_rewards: List[float]
done: bool
action: ActionTuple
action_probs: LogProbsTuple

# Assume teammates have same obs space
ith_team_obs.append(_team_obs[i])
agent_buffer_trajectory[TeamObsUtil.get_name_at(i)].append(ith_team_obs)
agent_buffer_trajectory["team_rewards"].append(exp.team_rewards)
if exp.memory is not None:
agent_buffer_trajectory["memory"].append(exp.memory)

正在加载...
取消
保存