
Rename resample to reparameterize

/develop/nopreviousactions
Ervin Teng, 5 years ago
Current commit: d6eb262c
2 files changed, 8 insertions and 8 deletions
  1. ml-agents/mlagents/trainers/common/nn_policy.py (14 lines changed)
  2. ml-agents/mlagents/trainers/sac/trainer.py (2 lines changed)

ml-agents/mlagents/trainers/common/nn_policy.py (14 lines changed)


is_training: bool,
load: bool,
tanh_squash: bool = False,
-resample: bool = False,
+reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
create_tf_graph: bool = True,
):

:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
:param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
-:param resample: Whether we are using the resampling trick to update the policy in continuous output.
+:param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
"""
super().__init__(seed, brain, trainer_params, load)
self.grads = None
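Background, not part of this commit: the tanh_squash flag documented above selects between squashing the sampled continuous action through tanh, which keeps it smoothly inside (-1, 1), and hard-clipping the raw sample at the action bounds. A minimal sketch of the difference, using plain NumPy (illustrative values, not ml-agents code):

import numpy as np

raw = np.array([-3.0, -0.5, 0.0, 2.0])  # unbounded samples from the Gaussian policy
squashed = np.tanh(raw)                  # tanh_squash=True: smooth, always in (-1, 1)
clipped = np.clip(raw, -1.0, 1.0)        # tanh_squash=False: hard clip, saturates at the bounds
print(squashed)  # approx. [-0.995 -0.462  0.     0.964]
print(clipped)   # [-1.  -0.5  0.   1. ]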

self.vis_encode_type = EncoderType(
    trainer_params.get("vis_encode_type", "simple")
)
self.tanh_squash = tanh_squash
-self.resample = resample
+self.reparameterize = reparameterize
self.condition_sigma_on_obs = condition_sigma_on_obs
self.trainable_variables: List[tf.Variable] = []

self.num_layers,
self.vis_encode_type,
self.tanh_squash,
-self.resample,
+self.reparameterize,
self.condition_sigma_on_obs,
)
else:

num_layers: int,
vis_encode_type: EncoderType,
tanh_squash: bool = False,
-resample: bool = False,
+reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
) -> None:
"""

:param vis_encode_type: Type of visual encoder to use if visual input.
:param tanh_squash: Whether to use a tanh function, or a clipped output.
-:param resample: Whether we are using the resampling trick to update the policy.
+:param reparameterize: Whether we are using the resampling trick to update the policy.
"""
with tf.variable_scope("policy"):
hidden_stream = ModelUtils.create_observation_streams(

sampled_policy = mu + sigma * epsilon
# Stop gradient if we're not doing the resampling trick
-if not resample:
+if not reparameterize:
sampled_policy_probs = tf.stop_gradient(sampled_policy)
else:
sampled_policy_probs = sampled_policy
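This hunk is the substance behind the rename: sampled_policy = mu + sigma * epsilon is the reparameterization trick (also called the resampling trick, hence the old name). The action is a deterministic, differentiable function of the distribution parameters mu and sigma, with all randomness isolated in epsilon ~ N(0, 1), so gradients of a downstream loss can flow back into the policy. When reparameterize is False, tf.stop_gradient severs that path and the sample is treated as a constant. A self-contained sketch of the same pattern, assuming TF 1.x (the API family this file uses); the names below are illustrative, not from the commit:

import tensorflow as tf  # assumes TF 1.x, matching the tf.variable_scope style above

mu = tf.constant([0.0, 1.0])              # mean of the policy's Gaussian
log_sigma = tf.Variable([0.0, 0.0])       # learned log standard deviation
sigma = tf.exp(log_sigma)

epsilon = tf.random_normal(tf.shape(mu))  # all randomness lives in epsilon ~ N(0, 1)
sampled = mu + sigma * epsilon            # differentiable w.r.t. mu and sigma

reparameterize = False
if not reparameterize:
    sampled = tf.stop_gradient(sampled)   # sample becomes a constant: no gradient path

loss = tf.reduce_sum(tf.square(sampled))
print(tf.gradients(loss, [log_sigma]))    # [None] here; a real tensor when reparameterize=True

Algorithms that differentiate through the sampled action, like SAC below, must therefore construct the policy with reparameterize=True.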

ml-agents/mlagents/trainers/sac/trainer.py (2 lines changed)


self.is_training,
self.load,
tanh_squash=True,
-resample=True,
+reparameterize=True,
create_tf_graph=False,
)
for _reward_signal in policy.reward_signals.keys():
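Background on why SAC hard-codes reparameterize=True (and tanh_squash=True) here rather than exposing it: SAC's actor update minimizes an expectation of alpha * log_pi(a|s) - Q(s, a) over actions drawn from the current policy, and it backpropagates through the sampled action into the actor's parameters. That only works with the reparameterized sample; a stop-gradient sample would zero out the actor gradient. A hedged sketch of the dependency, where q_func and all tensors are hypothetical stand-ins, not ml-agents API:

import tensorflow as tf  # TF 1.x style sketch; illustrative only

obs = tf.constant([[0.5, -0.2]])
mu = tf.layers.dense(obs, 1)              # tiny stand-in actor producing the action mean
log_sigma = tf.Variable([0.0])
action = mu + tf.exp(log_sigma) * tf.random_normal(tf.shape(mu))  # reparameterized sample

def q_func(o, a):                         # stand-in critic, not the ml-agents one
    return tf.layers.dense(tf.concat([o, a], axis=1), 1)

policy_loss = tf.reduce_mean(-q_func(obs, action))
# Non-empty only because `action` is differentiable w.r.t. the actor's parameters;
# wrapping it in tf.stop_gradient(action) would make this gradient None.
grads = tf.gradients(policy_loss, [log_sigma])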
