浏览代码

Bug fixes

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
a9116382
共有 3 个文件被更改,包括 7 次插入和 7 次删除
  1. 10
      ml-agents/mlagents/trainers/agent_processor.py
  2. 2
      ml-agents/mlagents/trainers/ppo/trainer.py
  3. 2
      ml-agents/mlagents/trainers/torch/agent_action.py

10
ml-agents/mlagents/trainers/agent_processor.py


stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
if stored_decision_step is not None and stored_take_action_outputs is not None:
- if step.team_manager_id > 0:
+ if step.group_id > 0:
stored_actions = stored_take_action_outputs["action"]
action_tuple = ActionTuple(
continuous=stored_actions.continuous[idx],

action=action_tuple,
done=isinstance(step, TerminalStep),
)
- self.group_status[step.team_manager_id][global_id] = group_status
- self.current_group_obs[step.team_manager_id][global_id] = step.obs
+ self.group_status[step.group_id][global_id] = group_status
+ self.current_group_obs[step.group_id][global_id] = step.obs
def _clear_teammate_obs(self, global_id: str) -> None:
self._delete_in_nested_dict(self.current_group_obs, global_id)

# Assemble teammate_obs. If none saved, then it will be an empty list.
group_statuses = []
- for _id, _obs in self.group_status[step.team_manager_id].items():
+ for _id, _obs in self.group_status[step.group_id].items():
if _id != global_id:
group_statuses.append(_obs)

):
next_obs = step.obs
next_group_obs = []
- for _id, _exp in self.current_group_obs[step.team_manager_id].items():
+ for _id, _exp in self.current_group_obs[step.group_id].items():
if _id != global_id:
next_group_obs.append(_exp)

2
ml-agents/mlagents/trainers/ppo/trainer.py


int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
- advantages = self.update_buffer["advantages"].get_batch()
+ advantages = np.array(self.update_buffer["advantages"].get_batch())
self.update_buffer["advantages"].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
)

2
ml-agents/mlagents/trainers/torch/agent_action.py


@staticmethod
def _from_team_dict(
buff: Dict[str, np.ndarray], cont_action_key: str, disc_action_key: str
- ):
+ ) -> List["AgentAction"]:
continuous_tensors: List[torch.Tensor] = []
discrete_tensors: List[torch.Tensor] = [] # type: ignore
if cont_action_key in buff:

正在加载...
取消
保存