
remove warning prints

Branch: /layernorm
Andrew Cohen, 4 years ago
Commit 231328ea
2 files changed, 14 insertions and 14 deletions
  1. ml-agents/mlagents/trainers/agent_processor.py (8 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (20 changes)

ml-agents/mlagents/trainers/agent_processor.py (8 changes)


             self._process_step(
                 ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
             )
-            if decision_steps.agent_id_to_index[local_id] == 9:
-                print("wrong index", ongoing_step, global_id, decision_steps.agent_id, decision_steps.obs, decision_steps.reward, terminal_steps.obs, terminal_steps.agent_id)
+            #if decision_steps.agent_id_to_index[local_id] == 9:
+            #    print("wrong index", ongoing_step, global_id, decision_steps.agent_id, decision_steps.obs, decision_steps.reward, terminal_steps.obs, terminal_steps.agent_id)
         for _gid in action_global_agent_ids:
             # If the ID doesn't have a last step result, the agent just reset,

     ) -> None:
         terminated = isinstance(step, TerminalStep)
         stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
-        if idx == 9:
-            print("Index was 9", global_id, stored_decision_step.agent_id)
+        #if idx == 9:
+        #    print("Index was 9", global_id, stored_decision_step.agent_id)
         stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
         if not terminated:
             # Index is needed to grab from last_take_action_outputs
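
For readers skimming the hunk above: `last_step_result.get(global_id, (None, None))` uses a tuple default so the two-name unpacking succeeds even for an id that has never been stored. A minimal standalone sketch of that pattern (the dictionary contents and ids here are illustrative, not taken from ml-agents):

```python
from typing import Dict, Optional, Tuple

# Illustrative stand-in for AgentProcessor.last_step_result:
# maps a global agent id to its last (decision_step, index) pair.
last_step_result: Dict[str, Tuple[Optional[object], Optional[int]]] = {
    "worker-0_agent-3": ("decision_step_placeholder", 3),
}

# The (None, None) default keeps the tuple unpacking safe for unknown ids.
stored_decision_step, idx = last_step_result.get("worker-0_agent-3", (None, None))
assert idx == 3

stored_decision_step, idx = last_step_result.get("never-seen-id", (None, None))
assert stored_decision_step is None and idx is None
```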

ml-agents/mlagents/trainers/torch/networks.py (20 changes)


         if total_enc_size == 0:
             raise Exception("No valid inputs to network.")
-        for _, tens in list(self.transformer.named_parameters()):
-            tens.retain_grad()
-        for _, tens in list(self.entity_embedding.named_parameters()):
-            tens.retain_grad()
+        #for _, tens in list(self.transformer.named_parameters()):
+        #    tens.retain_grad()
+        #for _, tens in list(self.entity_embedding.named_parameters()):
+        #    tens.retain_grad()
-        for _, tens in list(self.linear_encoder.named_parameters()):
-            tens.retain_grad()
-        for processor in self.processors:
-            if processor is not None:
-                for _, tens in list(processor.named_parameters()):
-                    tens.retain_grad()
+        #for _, tens in list(self.linear_encoder.named_parameters()):
+        #    tens.retain_grad()
+        #for processor in self.processors:
+        #    if processor is not None:
+        #        for _, tens in list(processor.named_parameters()):
+        #            tens.retain_grad()
         if self.use_lstm:
             self.lstm = LSTM(self.h_size, self.m_size)
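
The loops commented out above call `retain_grad()` on every parameter of the transformer, entity embedding, linear encoder, and observation processors, presumably as gradient-debugging aids. Worth noting: `nn.Parameter`s are leaf tensors, so PyTorch populates their `.grad` by default, and `retain_grad()` only changes behavior for intermediate (non-leaf) tensors. A minimal standalone sketch of that distinction (toy layer, not the ml-agents network):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
x = torch.randn(1, 4)

hidden = layer(x)     # non-leaf tensor: its grad is normally discarded
hidden.retain_grad()  # ask autograd to keep the intermediate gradient

hidden.sum().backward()

print(layer.weight.grad.shape)  # torch.Size([2, 4]); leaf grads are kept by default
print(hidden.grad)              # populated only because of retain_grad()
```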
