浏览代码

Cleanup visual obs setup (#2647)

* DRY up the setup code

* fstrings
/develop-gpu-test
GitHub 5 年前
当前提交
4980b904
共有 1 个文件被更改,包括 28 次插入42 次删除
  1. 70
      ml-agents/mlagents/trainers/models.py

70
ml-agents/mlagents/trainers/models.py


)
return hidden
@staticmethod
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,

hidden = c_layers.flatten(conv2)
with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = self.create_vector_observation_encoder(
hidden_flat = LearningModel.create_vector_observation_encoder(
@staticmethod
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,

hidden = c_layers.flatten(conv3)
with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = self.create_vector_observation_encoder(
hidden_flat = LearningModel.create_vector_observation_encoder(
@staticmethod
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,

hidden = c_layers.flatten(hidden)
with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = self.create_vector_observation_encoder(
hidden_flat = LearningModel.create_vector_observation_encoder(
hidden, h_size, activation, num_layers, scope, reuse
)
return hidden_flat

num_layers: int,
vis_encode_type: EncoderType = EncoderType.SIMPLE,
stream_scopes: List[str] = None,
) -> tf.Tensor:
) -> List[tf.Tensor]:
"""
Creates encoding stream for observations.
:param num_streams: Number of streams to create.

self.visual_in.append(visual_input)
vector_observation_input = self.create_vector_input()
# Pick the encoder function based on the EncoderType
create_encoder_func = LearningModel.create_visual_observation_encoder
if vis_encode_type == EncoderType.RESNET:
create_encoder_func = LearningModel.create_resnet_visual_observation_encoder
elif vis_encode_type == EncoderType.NATURE_CNN:
create_encoder_func = (
LearningModel.create_nature_cnn_visual_observation_encoder
)
final_hiddens = []
for i in range(num_streams):
visual_encoders = []

if vis_encode_type == EncoderType.RESNET:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_resnet_visual_observation_encoder(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)
elif vis_encode_type == EncoderType.NATURE_CNN:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_nature_cnn_visual_observation_encoder(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)
else:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_visual_observation_encoder(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)
for j in range(brain.number_visual_observations):
encoded_visual = create_encoder_func(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
scope=f"{_scope_add}main_graph_{i}_encoder{j}",
reuse=False,
)
visual_encoders.append(encoded_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
if brain.vector_observation_space_size > 0:
hidden_state = self.create_vector_observation_encoder(

num_layers,
_scope_add + "main_graph_{}".format(i),
False,
scope=f"{_scope_add}main_graph_{i}",
reuse=False,
)
if hidden_state is not None and hidden_visual is not None:
final_hidden = tf.concat([hidden_visual, hidden_state], axis=1)

正在加载...
取消
保存