            )
        else:
            self._create_dc_actor(
                self.encoder, self.h_size, policy_layers, separate_train, separate_policy_net
            )

        # Gather the trainable variables created under the "policy" scope.
        self.policy_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
        )

    def _create_dc_actor(
        self,
        encoded: tf.Tensor,
        h_size: int,
        num_layers: int,
        separate_train: bool = False,
        separate_net: bool = False,
    ) -> None:
""" |
|
|
|
Creates Discrete control actor-critic model. |
|
|
|
|
|
|
""" |
        with tf.variable_scope("policy"):
            if separate_net:
                encoded = self._create_encoder_general(
                    self.visual_in,
                    self.processed_vector_in,
                    h_size,
                    self.feature_size,
                    num_layers,
                    self.vis_encode_type,
                    scope="policy_enc",
                )
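            # Design note: building the separate encoder inside the "policy"
            # variable scope (sub-scope "policy_enc") keeps its weights with the
            # rest of the actor, so a scope-based tf.get_collection(...,
            # scope="policy") at the call site picks them up as well.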
            if self.use_recurrent:
                self.prev_action = tf.placeholder(
                    shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
                )
                prev_action_oh = tf.concat(
                    [
                        tf.one_hot(self.prev_action[:, i], self.act_size[i])
                        for i in range(len(self.act_size))
                    ],
                    axis=1,
                )
                hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
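                # Worked example: with act_size == [3, 2] and a previous action
                # row [2, 0], the one-hot concat above is [0, 0, 1] + [1, 0]
                # -> [0, 0, 1, 1, 0], a vector of width sum(self.act_size)
                # appended to the encoded observation.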

                self.memory_in = tf.placeholder(
                    shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
                )
                hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
                    hidden_policy,
                    self.memory_in,
                    self.sequence_length_ph,
                    name="lstm_policy",
                )
                self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
            else:
                hidden_policy = encoded
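            # In the recurrent branch, tf.identity pins a stable tensor name
            # ("recurrent_out") on the LSTM state so it can be fetched at
            # inference time, mirroring the "recurrent_in" placeholder.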
            if not separate_net:
                if separate_train:
                    hidden_policy = tf.stop_gradient(hidden_policy)
                hidden_policy = ModelUtils.create_vector_observation_encoder(
                    hidden_policy,
                    h_size,
                    ModelUtils.swish,
                    num_layers,
                    scope="main_graph",
                    reuse=False,
                )

            self.action_masks = tf.placeholder(
                shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
            )
            distribution = MultiCategoricalDistribution(
                hidden_policy, self.act_size, self.action_masks
            )
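            # Rough sketch of the masking each branch receives (assuming
            # ml-agents' MultiCategoricalDistribution semantics; names below
            # are illustrative only):
            #   probs  = tf.nn.softmax(branch_logits) * branch_mask
            #   probs /= tf.reduce_sum(probs, axis=1, keepdims=True)
            # Actions whose mask entry is 0 end up with ~zero probability;
            # `action_masks` (width sum(self.act_size)) holds one mask slice
            # per branch.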