
SAC LSTM isn't broken

/develop/nopreviousactions
Ervin Teng, 5 years ago
Commit 5ec49542
4 files changed, 34 insertions(+), 72 deletions(-)
  1. ml-agents/mlagents/trainers/optimizer.py (5 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer.py (7 changes)
  3. ml-agents/mlagents/trainers/sac/network.py (51 changes)
  4. ml-agents/mlagents/trainers/sac/optimizer.py (43 changes)

ml-agents/mlagents/trainers/optimizer.py (5 changes)


  network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict)
  run_out = dict(zip(list(out_dict.keys()), network_out))
  return run_out

+ def _make_zero_mem(self, m_size: int, length: int) -> List[np.ndarray]:
+     return [
+         np.zeros((m_size)) for i in range(0, length, self.policy.sequence_length)
+     ]
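
For readers skimming the diff, here is a minimal standalone sketch of what the new helper computes. The free function and the sizes below are illustrative only (the real method lives on the optimizer and reads the sequence length from the policy): it yields one zero-filled memory vector per LSTM sequence in a batch, so every sequence starts the update from a blank recurrent state.

    import numpy as np
    from typing import List

    def make_zero_mem(m_size: int, length: int, sequence_length: int) -> List[np.ndarray]:
        # One all-zero memory vector for each sequence of `sequence_length`
        # experiences in a batch of `length` experiences.
        return [np.zeros(m_size) for _ in range(0, length, sequence_length)]

    # e.g. a batch of 128 experiences split into sequences of 16 -> 8 zero vectors
    mems = make_zero_mem(m_size=256, length=128, sequence_length=16)
    assert len(mems) == 8 and mems[0].shape == (256,)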

ml-agents/mlagents/trainers/ppo/optimizer.py (7 changes)


  import logging
- from typing import Optional, Any, Dict, List
+ from typing import Optional, Any, Dict
  import numpy as np
  from mlagents.tf_utils import tf

          self.m_size, mini_batch.num_experiences
      )
      return feed_dict

- def _make_zero_mem(self, m_size: int, length: int) -> List[np.ndarray]:
-     return [
-         np.zeros((m_size)) for i in range(0, length, self.policy.sequence_length)
-     ]
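
The helper itself is unchanged; it moves from the PPO optimizer to the shared base optimizer so the SAC optimizer further down can reuse it, while PPO keeps feeding zeroed recurrent state at update time through the call that remains above. A hedged sketch of that call pattern, with invented class and method names standing in for the real mlagents classes:

    import numpy as np

    class BaseOptimizerSketch:
        # Stand-in for the shared optimizer base class that now owns _make_zero_mem.
        def __init__(self, sequence_length: int):
            self.sequence_length = sequence_length

        def _make_zero_mem(self, m_size, length):
            return [np.zeros(m_size) for _ in range(0, length, self.sequence_length)]

    class PPOOptimizerSketch(BaseOptimizerSketch):
        def recurrent_feed(self, m_size, num_experiences):
            # The update batch starts every sequence from a zero memory vector.
            return {"memory_in": self._make_zero_mem(m_size, num_experiences)}

    feed = PPOOptimizerSketch(sequence_length=16).recurrent_feed(m_size=128, num_experiences=64)
    assert len(feed["memory_in"]) == 64 // 16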

ml-agents/mlagents/trainers/sac/network.py (51 changes)


  self.activ_fn = LearningModel.swish
  self.sequence_length_ph = tf.placeholder(
-     shape=None, dtype=tf.int32, name="sequence_length"
+     shape=None, dtype=tf.int32, name="sac_sequence_length"
  )
  self.policy_memory_in: Optional[tf.Tensor] = None

  if self.policy.use_recurrent:
      self.memory_in = tf.placeholder(
-         shape=[None, self.policy.m_size],
-         dtype=tf.float32,
-         name="recurrent_in",
+         shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in"
      )
      self.value_memory_in = self.memory_in
  hidden_streams = LearningModel.create_observation_streams(

  class SACPolicyNetwork(SACNetwork):
      """
      Instantiation for SAC policy network. Contains a dual Q estimator,
-     a value estimator, and the actual policy network.
+     a value estimator, and a reference to the actual policy network.
      """
      def __init__(

          vis_encode_type,
      )
      if self.policy.use_recurrent:
-         self.create_memory_ins(self.policy.m_size)
+         self.create_memory_ins(m_size)
+         # Use the sequence length of the policy
+         self.sequence_length_ph = self.policy.sequence_length_ph
      if self.policy.use_continuous_act:
          self.create_cc_critic(hidden_critic, POLICY_SCOPE)

  if self.use_recurrent:
-     mem_outs = [
-         self.value_memory_out,
-         self.q1_memory_out,
-         self.q2_memory_out,
-         self.policy_memory_out,
-     ]
+     mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out]
      self.memory_out = tf.concat(mem_outs, axis=1)
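
The optimizer-side network now concatenates only the three critic memories (value, Q1, Q2); the policy LSTM state is owned by the policy itself. A small numpy illustration of the resulting layout, with made-up sizes, not the actual TensorFlow graph code:

    import numpy as np

    m = 4  # per-network memory size (illustrative)
    value_mem, q1_mem, q2_mem = (np.full((1, m), fill) for fill in (0.0, 1.0, 2.0))

    # Analogue of tf.concat(mem_outs, axis=1) over the three critic memories.
    memory_out = np.concatenate([value_mem, q1_mem, q2_mem], axis=1)
    assert memory_out.shape == (1, 3 * m)  # optimizer memory is 3x the policy memory
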
  def create_memory_ins(self, m_size):

      """
-     # Create the Policy input separate from the rest
-     # This is so in inference we only have to run the Policy network.
-     # Barracuda will grab the recurrent_in and recurrent_out named tensors.
-     self.inference_memory_in = tf.placeholder(
-         shape=[None, m_size // 4], dtype=tf.float32, name="recurrent_in"
-     )
-     # We assume m_size is divisible by 4
-     # Create the non-Policy inputs
-     # Use a default placeholder here so nothing has to be provided during
-     # Barracuda inference. Note that the default value is just the tiled input
-     # for the policy, which is thrown away.
-     three_fourths_m_size = m_size * 3 // 4
-     self.other_memory_in = tf.placeholder_with_default(
-         input=tf.tile(self.inference_memory_in, [1, 3]),
-         shape=[None, three_fourths_m_size],
-         name="other_recurrent_in",
-     )
-     # Concat and use this as the "placeholder"
-     # for training
-     self.memory_in = tf.concat(
-         [self.other_memory_in, self.inference_memory_in], axis=1
-     )
+     self.memory_in = tf.placeholder(
+         shape=[None, m_size * 3], dtype=tf.float32, name="value_recurrent_in"
+     )
-     num_mems = 4
+     num_mems = 3
+     input_size = self.memory_in.get_shape().as_list()[1]
-     _start = m_size // num_mems * i
-     _end = m_size // num_mems * (i + 1)
+     _start = input_size // num_mems * i
+     _end = input_size // num_mems * (i + 1)
-     self.policy_memory_in = mem_ins[3]

  def create_observation_in(self, vis_encode_type):
      """

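create_memory_ins is the mirror image of that concatenation: the optimizer network takes one placeholder of width 3 * m_size and slices it back into equal chunks for the value, Q1, and Q2 encoders. A rough numpy sketch of the slicing arithmetic, with illustrative sizes rather than the real TensorFlow tensors:

    import numpy as np

    m_size = 4            # policy-side memory size (illustrative)
    num_mems = 3          # value, Q1, Q2
    memory_in = np.arange(3 * m_size, dtype=np.float32).reshape(1, -1)  # stand-in for the placeholder

    input_size = memory_in.shape[1]
    mem_ins = []
    for i in range(num_mems):
        _start = input_size // num_mems * i
        _end = input_size // num_mems * (i + 1)
        mem_ins.append(memory_in[:, _start:_end])

    value_memory_in, q1_memory_in, q2_memory_in = mem_ins
    assert all(chunk.shape == (1, m_size) for chunk in mem_ins)
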
ml-agents/mlagents/trainers/sac/optimizer.py (43 changes)


      trainer_params.get("vis_encode_type", "simple")
  )
  self.tau = trainer_params.get("tau", 0.005)
- m_size = self.policy.m_size
  self.init_entcoef = trainer_params.get("init_entcoef", 1.0)
  stream_names = self.reward_signals.keys()
  # Use to reduce "survivor bonus" when using Curiosity or GAIL.

  self.policy_network = SACPolicyNetwork(
      policy=self.policy,
-     m_size=m_size,
+     m_size=self.policy.m_size,  # 3x policy.m_size
      h_size=h_size,
      normalize=self.policy.normalize,
      use_recurrent=self.policy.use_recurrent,

  )
  self.target_network = SACTargetNetwork(
      policy=self.policy,
-     m_size=m_size // 4 if m_size else None,
+     m_size=self.policy.m_size,  # 1x policy.m_size
      h_size=h_size,
      normalize=self.policy.normalize,
      use_recurrent=self.policy.use_recurrent,

  )
+ # The optimizer's m_size is 3 times the policy (Q1, Q2, and Value)
+ self.m_size = 3 * self.policy.m_size
  self.create_inputs_and_outputs()
  self.learning_rate = LearningModel.create_learning_rate(
      lr_schedule, lr, self.policy.global_step, int(max_step)

  # Add some stuff to inference dict from optimizer
  self.policy.inference_dict["learning_rate"] = self.learning_rate
  if self.policy.use_recurrent:
      self.policy.inference_dict["optimizer_memory_out"] = self.memory_out

  def create_inputs_and_outputs(self) -> None:
      """

      if self.policy.use_recurrent:
          self.memory_in = self.policy_network.memory_in
          self.memory_out = self.policy_network.memory_out
-         # For Barracuda
-         self.inference_memory_out = tf.identity(
-             self.policy_network.policy_memory_out, name="recurrent_out"
-         )
          if not self.policy.use_continuous_act:
              self.prev_action = self.policy_network.prev_action
          self.next_memory_in = self.target_network.memory_in

  stats_needed.update(self.reward_signals[name].stats_name_to_update_name)

  def construct_feed_dict(
-     self, policy: TFPolicy, batch: Dict[str, Any], num_sequences: int
+     self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int
  ) -> Dict[tf.Tensor, Any]:
      """
      Builds the feed dict for updating the SAC model.

      _obs = batch["next_visual_obs%d" % i]
      feed_dict[self.next_visual_in[i]] = _obs
      if self.policy.use_recurrent:
-         mem_in = [
-             batch["memory"][i]
-             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
-         ]
-         # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
-         offset = 1 if self.policy.sequence_length > 1 else 0
-         next_mem_in = [
-             batch["memory"][i][
-                 : self.policy.m_size // 4
-             ]  # only pass value part of memory to target network
-             for i in range(
-                 offset, len(batch["memory"]), self.policy.sequence_length
-             )
-         ]
-         feed_dict[policy.memory_in] = mem_in
-         feed_dict[self.next_memory_in] = next_mem_in
+         feed_dict[policy.memory_in] = self._make_zero_mem(
+             self.policy.m_size, batch.num_experiences
+         )
+         feed_dict[self.policy_network.memory_in] = self._make_zero_mem(
+             self.m_size, batch.num_experiences
+         )
+         feed_dict[self.target_network.memory_in] = self._make_zero_mem(
+             self.policy.m_size, batch.num_experiences
+         )
      feed_dict[self.dones_holder] = batch["done"]
      return feed_dict
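
The net effect on the SAC update: instead of reading stored LSTM memories out of the buffer (and slicing off a quarter of each for the target network), the optimizer now feeds zero memories of the appropriate width to the policy, the optimizer network, and the target network. A hedged sketch of the resulting feed shapes, with made-up sizes and a plain dict with descriptive string keys standing in for the TensorFlow feed_dict:

    import numpy as np

    policy_m_size = 64                    # width of the policy LSTM memory (illustrative)
    optimizer_m_size = 3 * policy_m_size  # value + Q1 + Q2, as in the constructor above
    sequence_length = 16
    num_experiences = 128                 # size of the sampled update batch

    def make_zero_mem(m_size, length):
        # One zero vector per sequence in the sampled batch.
        return [np.zeros(m_size) for _ in range(0, length, sequence_length)]

    feed = {
        "policy_recurrent_in": make_zero_mem(policy_m_size, num_experiences),
        "value_recurrent_in": make_zero_mem(optimizer_m_size, num_experiences),
        "target_recurrent_in": make_zero_mem(policy_m_size, num_experiences),
    }
    assert all(len(v) == num_experiences // sequence_length for v in feed.values())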