Fix some typing issues with curiosity

/develop/nopreviousactions
Ervin Teng, 5 years ago
Current commit
b61d2fa1
2 files changed, with 5 insertions and 5 deletions
  1. ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (8 changes)
  2. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (2 changes)

ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (8 changes)


  # Create the encoder ops for current and next visual input.
  # Note that these encoders are siamese.
- encoded_visual = self.policy_model.create_visual_observation_encoder(
+ encoded_visual = LearningModel.create_visual_observation_encoder(
      self.policy_model.visual_in[i],
      self.encoding_size,
      LearningModel.swish,

  )
- encoded_next_visual = self.policy_model.create_visual_observation_encoder(
+ encoded_next_visual = LearningModel.create_visual_observation_encoder(
      self.next_visual_in[i],
      self.encoding_size,
      LearningModel.swish,

      name="curiosity_next_vector_observation",
  )
- encoded_vector_obs = self.policy_model.create_vector_observation_encoder(
+ encoded_vector_obs = LearningModel.create_vector_observation_encoder(
      self.policy_model.vector_in,
      self.encoding_size,
      LearningModel.swish,

  )
- encoded_next_vector_obs = self.policy_model.create_vector_observation_encoder(
+ encoded_next_vector_obs = LearningModel.create_vector_observation_encoder(
      self.next_vector_in,
      self.encoding_size,
      LearningModel.swish,
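
The changes to model.py replace calls made through the `self.policy_model` instance with calls made through the `LearningModel` class. The following is a minimal sketch of why the two spellings can behave identically at runtime, assuming the encoder helpers are static methods (the diff suggests this but does not show their definitions). The class is a toy stand-in written for illustration, not the ml-agents implementation, and its encoder body is invented. One plausible reason the class-qualified form helps a type checker is that it no longer depends on the declared type of `policy_model`, though the commit message does not spell this out.

```python
import math
from typing import Callable, List


class LearningModel:
    """Toy stand-in for illustration only; not the ml-agents class."""

    @staticmethod
    def swish(x: float) -> float:
        # swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
        return x / (1.0 + math.exp(-x))

    @staticmethod
    def create_vector_observation_encoder(
        observation: List[float],
        encoding_size: int,
        activation: Callable[[float], float],
    ) -> List[float]:
        # Hypothetical "encoder": truncate or zero-pad the observation
        # to encoding_size, then apply the activation element-wise.
        padded = (observation + [0.0] * encoding_size)[:encoding_size]
        return [activation(value) for value in padded]


model = LearningModel()
obs = [1.0, -2.0, 0.5]

# Call through an instance (the style the commit removes).
via_instance = model.create_vector_observation_encoder(obs, 4, LearningModel.swish)

# Call through the class (the style the commit adds).
via_class = LearningModel.create_vector_observation_encoder(obs, 4, LearningModel.swish)

# A static method ignores the instance, so both calls give the same result.
assert via_instance == via_class
```

Because a static method receives no `self`, the class-qualified call needs no instance at all, which keeps the call site valid even where `policy_model` is declared with a narrower or optional type.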

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (2 changes)


  :return: Feed_dict needed for update.
  """
  feed_dict = {
-     policy.batch_size: num_sequences,
+     policy.batch_size_ph: num_sequences,
      policy.sequence_length: self.policy.sequence_length,
      policy.mask_input: mini_batch["masks"],
  }
