
Move initialization call around

/develop/nopreviousactions
Ervin Teng, 4 years ago
Current commit 78671383
6 files changed, 14 insertions(+), 7 deletions(-)
  1. ml-agents/mlagents/trainers/common/nn_policy.py (4 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer.py (12 changes)
  3. ml-agents/mlagents/trainers/ppo/trainer.py (1 change)
  4. ml-agents/mlagents/trainers/sac/optimizer.py (2 changes)
  5. ml-agents/mlagents/trainers/sac/trainer.py (1 change)
  6. ml-agents/mlagents/trainers/tests/test_bcmodule.py (1 change)

ml-agents/mlagents/trainers/common/nn_policy.py (4 changes)


     if self.use_recurrent:
         self.inference_dict["memory_out"] = self.memory_out
+    # We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
+    # it will re-load the full graph
+    self._initialize_graph()

 @timed
 def evaluate(
     self, batched_step_result: BatchedStepResult, global_agent_ids: List[str]
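The comment above describes the intended split: the policy initializes a minimal graph at construction so inference works immediately, and a later optimizer re-initializes (or loads) the full graph once it has added its own ops. Below is an illustrative sketch of that pattern only, with stand-in classes rather than the actual ml-agents TensorFlow code:

# Sketch, not the ml-agents implementation: a policy that is usable out of the
# box, and an optimizer that extends it and then triggers a full
# re-initialization (or checkpoint load).

class SketchPolicy:
    def __init__(self):
        self.variables = {"policy/weights": 0.0}  # inference-only variables
        self._initialize_graph()                  # usable for inference right away

    def _initialize_graph(self):
        # pretend-initialize every variable known at this point
        self.initialized = set(self.variables)

    def initialize_or_load(self, checkpoint=None):
        # re-initialize (or restore) over the *full* variable set, including
        # anything an optimizer registered after construction
        if checkpoint is not None:
            self.variables.update(checkpoint)
        self._initialize_graph()


class SketchOptimizer:
    def __init__(self, policy):
        policy.variables["optimizer/learning_rate"] = 3e-4  # extend the graph
        self.policy = policy


policy = SketchPolicy()            # inference-ready with no optimizer attached
optimizer = SketchOptimizer(policy)
policy.initialize_or_load()        # trainer re-initializes after the optimizer exists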

ml-agents/mlagents/trainers/ppo/optimizer.py (12 changes)


     # Add some stuff to inference dict from optimizer
     self.policy.inference_dict["learning_rate"] = self.learning_rate
-    self.policy.initialize_or_load()

 def create_cc_critic(
     self, h_size: int, num_layers: int, vis_encode_type: EncoderType

"{}_value_estimates".format(name)
]
if "actions_pre" in mini_batch:
if self.policy.output_pre is not None and "actions_pre" in mini_batch:
feed_dict[self.policy.output_pre] = mini_batch["actions_pre"]
else:
feed_dict[self.policy.action_holder] = mini_batch["actions"]

for i, _ in enumerate(self.policy.visual_in):
feed_dict[self.policy.visual_in[i]] = mini_batch["visual_obs%d" % i]
if self.policy.use_recurrent:
feed_dict[self.policy.memory_in] = self._make_zero_mem(
self.policy.m_size, mini_batch.num_experiences
)
feed_dict[self.policy.memory_in] = [
mini_batch["memory"][i]
for i in range(
0, len(mini_batch["memory"]), self.policy.sequence_length
)
]
feed_dict[self.memory_in] = self._make_zero_mem(
self.m_size, mini_batch.num_experiences
)
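The recurrent branch above builds two memory feeds: the policy's memories now come from the stored minibatch, sampled once at the start of each sequence, while the optimizer's own memories stay zeroed per sequence. The following is a rough, self-contained sketch of those two operations; the helper names and shapes here are assumptions for illustration, not the actual `_make_zero_mem` on the optimizer base class:

# Sketch of the two memory feeds, with assumed shapes.

def make_zero_mem(m_size, num_experiences, sequence_length):
    """One all-zero memory vector of width m_size per sequence in the batch."""
    num_sequences = max(1, num_experiences // sequence_length)
    return [[0.0] * m_size for _ in range(num_sequences)]


def sample_stored_mem(memories, sequence_length):
    """Take the stored memory at the first step of every sequence."""
    return [memories[i] for i in range(0, len(memories), sequence_length)]


if __name__ == "__main__":
    stored = [[float(i)] * 4 for i in range(8)]          # 8 steps, m_size == 4
    print(make_zero_mem(4, 8, sequence_length=2))        # 4 zero vectors of width 4
    print(sample_stored_mem(stored, sequence_length=2))  # memories at steps 0, 2, 4, 6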

ml-agents/mlagents/trainers/ppo/trainer.py (1 change)


         raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()")
     self.policy = policy
     self.optimizer = PPOOptimizer(self.policy, self.trainer_parameters)
+    self.policy.initialize_or_load()
     for _reward_signal in self.optimizer.reward_signals.keys():
         self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
     # Needed to resume loads properly
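The trainer now owns the final `initialize_or_load()` call: it runs only after the optimizer has been constructed, so any variables the optimizer adds are covered by the initialization or checkpoint restore. The `collected_rewards` bookkeeping below it is a plain per-signal `defaultdict`; a small sketch with hypothetical reward-signal names:

# Sketch of the collected_rewards bookkeeping, with hypothetical signal names.
# Each reward signal gets its own defaultdict keyed by agent id, defaulting to
# 0 so a previously unseen agent can be incremented without a KeyError.
from collections import defaultdict

collected_rewards = {
    signal: defaultdict(lambda: 0)
    for signal in ("extrinsic", "curiosity")  # hypothetical reward signals
}

collected_rewards["extrinsic"]["agent-0"] += 1.5
print(dict(collected_rewards["extrinsic"]))  # {'agent-0': 1.5}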

ml-agents/mlagents/trainers/sac/optimizer.py (2 changes)


         [self.policy.update_normalization_op, target_update_norm]
     )
-    self.policy.initialize_or_load()
     self.stats_name_to_update_name = {
         "Losses/Value Loss": "value_loss",
         "Losses/Policy Loss": "policy_loss",

ml-agents/mlagents/trainers/sac/trainer.py (1 change)


         raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
     self.policy = policy
     self.optimizer = SACOptimizer(self.policy, self.trainer_parameters)
+    self.policy.initialize_or_load()
     for _reward_signal in self.optimizer.reward_signals.keys():
         self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
     # Needed to resume loads properly

ml-agents/mlagents/trainers/tests/test_bcmodule.py (1 change)


         default_num_epoch=3,
         **trainer_config["behavioral_cloning"],
     )
+    policy.initialize_or_load()
     return bc_module
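The test helper mirrors the trainers: build the BC module from the config's behavioral_cloning section, then call `initialize_or_load()` before the module is used. The keyword-unpacking step it relies on is sketched below with a hypothetical config dict and hypothetical keys; only the unpacking pattern itself is taken from the test:

# Sketch of merging explicit defaults with a config section via ** unpacking.

def make_bc_module(policy, default_batch_size=64, default_num_epoch=3, **bc_settings):
    # bc_settings carries whatever the config section supplies (hypothetical keys)
    settings = {"batch_size": default_batch_size, "num_epoch": default_num_epoch}
    settings.update(bc_settings)              # config values override the defaults
    return {"policy": policy, **settings}


trainer_config = {"behavioral_cloning": {"strength": 0.5, "steps": 10000}}  # hypothetical values
bc_module = make_bc_module(
    "policy-stub", default_num_epoch=3, **trainer_config["behavioral_cloning"]
)
print(bc_module["num_epoch"])  # 3: the default stands because the config did not override it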
