        is_training: bool,
        load: bool,
        tanh_squash: bool = False,
        resample: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):

        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        :param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
        :param resample: Whether we are using the resampling trick to update the policy for continuous actions.
        :param reparameterize: Whether we are using the reparameterization trick to update the policy for continuous actions.
        """
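        # The base-class constructor handles seeding, brain configuration, and loading
        # of a pre-trained model when requested.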
        super().__init__(seed, brain, trainer_params, load)
        self.grads = None
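        # The visual encoder architecture is read from the trainer hyperparameters,
        # falling back to the "simple" encoder.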
        self.vis_encode_type = EncoderType(
            trainer_params.get("vis_encode_type", "simple")
        )
        self.tanh_squash = tanh_squash
        self.resample = resample
        self.reparameterize = reparameterize
        self.condition_sigma_on_obs = condition_sigma_on_obs
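        # Filled with the policy's variables once the TF graph has been built.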
        self.trainable_variables: List[tf.Variable] = []
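                # Network size, visual encoder type, and action-sampling flags handed to
                # the continuous-action actor construction.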
                self.num_layers,
                self.vis_encode_type,
                self.tanh_squash,
                self.resample,
                self.reparameterize,
                self.condition_sigma_on_obs,
            )
        else:
        num_layers: int,
        vis_encode_type: EncoderType,
        tanh_squash: bool = False,
        resample: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
    ) -> None:
        """
        :param vis_encode_type: Type of visual encoder to use if visual input.
        :param tanh_squash: Whether to use a tanh function, or a clipped output.
        :param resample: Whether we are using the resampling trick to update the policy.
        :param reparameterize: Whether we are using the reparameterization trick to update the policy.
        """
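        # Group all actor ops under the "policy" variable scope so their variables can
        # be identified and collected separately from the rest of the graph.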
        with tf.variable_scope("policy"):
            hidden_stream = ModelUtils.create_observation_streams(
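            # Reparameterization trick: the sampled action is a deterministic function of
            # the distribution parameters (mu, sigma) and external noise epsilon, so
            # gradients can flow back through mu and sigma when needed.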
            sampled_policy = mu + sigma * epsilon
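            # When neither flag is set, the sampled action is treated as a constant:
            # gradients reach the policy only through the action log-probabilities,
            # not through the sample itself.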
            # Stop gradient if we're not doing the resampling/reparameterization trick
            if not resample and not reparameterize:
                sampled_policy_probs = tf.stop_gradient(sampled_policy)
            else:
                sampled_policy_probs = sampled_policy