 from typing import Callable, Dict, List, Optional

 import numpy as np
-import tensorflow as tf
-import tensorflow.contrib.layers as c_layers
+from mlagents.tf_utils import tf

 from mlagents.trainers.trainer import UnityTrainerException
 from mlagents.envs.brain import CameraResolution

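Note on the import swap above: mlagents.tf_utils presumably centralizes the TensorFlow import so the contrib-free calls in this diff resolve against a 1.x-style API regardless of which TensorFlow build is installed. A hypothetical sketch of such a shim (illustrative only; the real mlagents.tf_utils module may differ):

# tf_utils.py -- illustrative sketch, not the actual ML-Agents implementation
import tensorflow as tf

if tf.__version__.startswith("2."):
    # Expose the 1.x-compatible surface so tf.layers, tf.variable_scope and
    # tf.nn.dynamic_rnn keep working without tf.contrib.
    import tensorflow.compat.v1 as tf  # noqa: F811
    tf.disable_v2_behavior()
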
     @staticmethod
     def scaled_init(scale):
-        return c_layers.variance_scaling_initializer(scale)
+        return tf.initializers.variance_scaling(scale)

     @staticmethod
     def swish(input_activation: tf.Tensor) -> tf.Tensor:

                     activation=activation,
                     reuse=reuse,
                     name="hidden_{}".format(i),
-                    kernel_initializer=c_layers.variance_scaling_initializer(1.0),
+                    kernel_initializer=tf.initializers.variance_scaling(1.0),
                 )
         return hidden

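For reference, this is the same substitution as in scaled_init: tf.initializers.variance_scaling replaces c_layers.variance_scaling_initializer with an explicit scale, and the layer definition is otherwise untouched. A small self-contained sketch under the TF 1.x graph API (placeholder name and sizes are made up for illustration):

import numpy as np
from mlagents.tf_utils import tf  # or: import tensorflow as tf (1.x API)

obs = tf.placeholder(shape=[None, 8], dtype=tf.float32, name="obs")
hidden = tf.layers.dense(
    obs,
    32,
    activation=tf.nn.relu,
    name="hidden_0",
    # was: c_layers.variance_scaling_initializer(1.0)
    kernel_initializer=tf.initializers.variance_scaling(1.0),
)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(hidden, {obs: np.zeros((1, 8), np.float32)}).shape)  # (1, 32)
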
                 reuse=reuse,
                 name="conv_2",
             )
-            hidden = c_layers.flatten(conv2)
+            hidden = tf.layers.flatten(conv2)

         with tf.variable_scope(scope + "/" + "flat_encoding"):
             hidden_flat = LearningModel.create_vector_observation_encoder(

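The flatten substitution here (and in the conv_3 and ResNet encoders below) is one-for-one: tf.layers.flatten, like c_layers.flatten, collapses every dimension after the batch dimension. A quick illustrative check with assumed 84x84x3 inputs:

from mlagents.tf_utils import tf  # or: import tensorflow as tf (1.x API)

images = tf.placeholder(shape=[None, 84, 84, 3], dtype=tf.float32)
conv = tf.layers.conv2d(images, 16, kernel_size=[8, 8], strides=[4, 4], activation=tf.nn.elu)
flat = tf.layers.flatten(conv)  # was: c_layers.flatten(conv)
print(conv.shape, flat.shape)  # (?, 20, 20, 16) (?, 6400)
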
                 reuse=reuse,
                 name="conv_3",
             )
-            hidden = c_layers.flatten(conv3)
+            hidden = tf.layers.flatten(conv3)

         with tf.variable_scope(scope + "/" + "flat_encoding"):
             hidden_flat = LearningModel.create_vector_observation_encoder(

                     )
                     hidden = tf.add(block_input, hidden)
             hidden = tf.nn.relu(hidden)
-            hidden = c_layers.flatten(hidden)
+            hidden = tf.layers.flatten(hidden)

         with tf.variable_scope(scope + "/" + "flat_encoding"):
             hidden_flat = LearningModel.create_vector_observation_encoder(

         memory_in = tf.reshape(memory_in[:, :], [-1, m_size])
         half_point = int(m_size / 2)
         with tf.variable_scope(name):
-            rnn_cell = tf.contrib.rnn.BasicLSTMCell(half_point)
-            lstm_vector_in = tf.contrib.rnn.LSTMStateTuple(
+            rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(half_point)
+            lstm_vector_in = tf.nn.rnn_cell.LSTMStateTuple(
                 memory_in[:, :half_point], memory_in[:, half_point:]
             )
             recurrent_output, lstm_state_out = tf.nn.dynamic_rnn(

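The recurrent encoder keeps the same structure; only the cell and state-tuple classes move from tf.contrib.rnn to tf.nn.rnn_cell, which are drop-in equivalents under the 1.x API. A standalone sketch of the pattern with made-up sizes (batch of one, sequence_length=4, m_size=32):

import numpy as np
from mlagents.tf_utils import tf  # or: import tensorflow as tf (1.x API)

sequence_length, s_size, m_size = 4, 8, 32
half_point = m_size // 2

input_state = tf.placeholder(shape=[None, s_size], dtype=tf.float32)
memory_in = tf.placeholder(shape=[None, m_size], dtype=tf.float32)

# Fold the flat batch of steps into [batch, time, features] for the LSTM.
lstm_input = tf.reshape(input_state, shape=[-1, sequence_length, s_size])
with tf.variable_scope("lstm"):
    rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(half_point)  # was tf.contrib.rnn.BasicLSTMCell
    state_in = tf.nn.rnn_cell.LSTMStateTuple(  # was tf.contrib.rnn.LSTMStateTuple
        memory_in[:, :half_point], memory_in[:, half_point:]
    )
    output, state_out = tf.nn.dynamic_rnn(rnn_cell, lstm_input, initial_state=state_in)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(output, {
        input_state: np.zeros((sequence_length, s_size), np.float32),
        memory_in: np.zeros((1, m_size), np.float32),
    })
    print(out.shape)  # (1, 4, 16)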