
Some more bugfixes

/develop/nopreviousactions
Ervin Teng, 5 years ago
Current commit: 76ad64d7
3 files changed, 15 insertions and 2 deletions
  1. ml-agents/mlagents/trainers/ppo/optimizer.py (14 changes)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)
  3. ml-agents/mlagents/trainers/tf_policy.py (1 change)

ml-agents/mlagents/trainers/ppo/optimizer.py (14 changes)


            lr,
            max_step,
        )
        self.create_ppo_optimizer()
        self.update_dict.update(
            {
                "value_loss": self.value_loss,
                "policy_loss": self.abs_policy_loss,
                "update_batch": self.update_batch,
            }
        )

    def create_cc_critic(
        self, h_size: int, num_layers: int, vis_encode_type: EncoderType

    def construct_feed_dict(
        self, model: PPOModel, mini_batch: AgentBuffer, num_sequences: int
    ) -> Dict[tf.Tensor, Any]:
-           model.sequence_length: model.sequence_length,
+           model.sequence_length: len(mini_batch["advantages"])
+           / num_sequences,  # TODO: Fix LSTM
            model.mask_input: mini_batch["masks"],
            self.advantage: mini_batch["advantages"],
            self.all_old_log_probs: mini_batch["action_probs"],

        :param out_dict: Output dictionary mapping names to nodes.
        :return: Dictionary mapping names to input data.
        """
        print(feed_dict)
        network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict)
        run_out = dict(zip(list(out_dict.keys()), network_out))
        return run_out
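For readers following the optimizer.py hunks, the stand-alone sketch below illustrates the pattern they touch: a mini-batch is mapped onto graph placeholders, the effective sequence length is derived as batch size divided by the number of sequences, and a dict of named outputs is fetched in a single session.run call. This is a minimal sketch, not the ml-agents implementation; all tensor and key names here are illustrative assumptions.

# Minimal sketch of the feed-dict / batched-fetch pattern (assumed names, TF1-style graph).
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

advantages = tf.compat.v1.placeholder(tf.float32, [None], name="advantages")
value_loss = tf.reduce_mean(tf.square(advantages))   # stand-in for the real value loss
policy_loss = tf.reduce_mean(tf.abs(advantages))     # stand-in for abs_policy_loss

mini_batch = {"advantages": np.random.randn(12).astype(np.float32)}
num_sequences = 3
# Mirrors the construct_feed_dict change: sequence length = batch size / num_sequences.
sequence_length = len(mini_batch["advantages"]) // num_sequences  # 12 // 3 -> 4

feed_dict = {advantages: mini_batch["advantages"]}
out_dict = {"value_loss": value_loss, "policy_loss": policy_loss}

with tf.compat.v1.Session() as sess:
    # Fetch every named output in one run, then zip the names back onto the results.
    network_out = sess.run(list(out_dict.values()), feed_dict=feed_dict)
    run_out = dict(zip(out_dict.keys(), network_out))
    print(sequence_length, run_out)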

ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)


        buffer = self.update_buffer
        max_num_batch = buffer_length // batch_size
        for l in range(0, max_num_batch * batch_size, batch_size):
-           update_stats = self.policy.update(
+           update_stats = self.policy.optimizer.update(
                buffer.make_mini_batch(l, l + batch_size), n_sequences
            )
            for stat_name, value in update_stats.items():
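The trainer.py change reroutes the update call from the policy itself to the optimizer attached to it. The sketch below reproduces the surrounding mini-batching loop with stand-in classes so it runs on its own; FakeOptimizer, FakePolicy, and run_update are assumptions for illustration, not ml-agents classes.

# Hedged sketch of the trainer loop: stride over the buffer in batch_size steps and
# collect the stats returned by policy.optimizer.update(...).
from collections import defaultdict

class FakeOptimizer:
    def update(self, mini_batch, num_sequences):
        # Pretend losses derived from the mini-batch; the real code runs a TF update step.
        return {"value_loss": float(len(mini_batch)), "policy_loss": 0.1 * num_sequences}

class FakePolicy:
    def __init__(self):
        self.optimizer = FakeOptimizer()   # attached by the trainer in the real code

def run_update(policy, buffer, batch_size, n_sequences):
    stats = defaultdict(list)
    buffer_length = len(buffer)
    max_num_batch = buffer_length // batch_size
    for l in range(0, max_num_batch * batch_size, batch_size):
        mini_batch = buffer[l : l + batch_size]          # stands in for make_mini_batch
        update_stats = policy.optimizer.update(mini_batch, n_sequences)
        for stat_name, value in update_stats.items():
            stats[stat_name].append(value)
    return stats

print(run_update(FakePolicy(), list(range(100)), batch_size=32, n_sequences=4))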

ml-agents/mlagents/trainers/tf_policy.py (1 change)


        config.allow_soft_placement = True
        self.sess = tf.Session(config=config, graph=self.graph)
        self.saver = None
        self.optimizer = None
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
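Judging from the hunk, the single tf_policy.py insertion gives the policy an optimizer slot initialized to None, so a trainer can attach an optimizer object after the policy is built and later call it as policy.optimizer.update(...). The stand-alone sketch below shows that wiring; SketchPolicy and SketchOptimizer are hypothetical stand-ins, not the real TFPolicy or PPO optimizer.

# Sketch of the "reserve a slot, wire it up later" pattern implied by self.optimizer = None.
class SketchPolicy:
    def __init__(self, use_recurrent=False, trainer_parameters=None):
        trainer_parameters = trainer_parameters or {}
        self.optimizer = None                      # filled in later by the trainer
        self.use_recurrent = use_recurrent
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]

class SketchOptimizer:
    def __init__(self, policy):
        self.policy = policy

policy = SketchPolicy()
policy.optimizer = SketchOptimizer(policy)         # the trainer wires the two together
assert policy.optimizer is not None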
