
Curiosity Driven Exploration & Pyramids Environments (#739)

* Adds implementation of Curiosity-driven Exploration by Self-supervised Prediction (https://arxiv.org/abs/1705.05363) to PPO trainer.
* To enable, set use_curiosity flag to true in hyperparameter file.
* Includes refactor of unitytrainers model code to accommodate new feature.
* Adds new Pyramids environment (w/ documentation). Environment contains sparse reward, and can only be solved using PPO+Curiosity.
/develop-generalizationTraining-TrainerController
GitHub · 7 years ago
Current commit
c17937ef
59 files changed, with 7,466 insertions and 119 deletions
  1. docs/Learning-Environment-Examples.md (15)
  2. docs/Training-ML-Agents.md (5)
  3. docs/Training-PPO.md (16)
  4. docs/Using-Tensorboard.md (5)
  5. python/trainer_config.yaml (34)
  6. python/unitytrainers/bc/models.py (2)
  7. python/unitytrainers/models.py (179)
  8. python/unitytrainers/ppo/models.py (112)
  9. python/unitytrainers/ppo/trainer.py (160)
  10. python/unitytrainers/trainer_controller.py (3)
  11. unity-environment/Assets/ML-Agents/Scripts/Agent.cs (10)
  12. unity-environment/ProjectSettings/TagManager.asset (4)
  13. docs/images/pyramids.png (996)
  14. unity-environment/Assets/ML-Agents/Examples/Pyramids.meta (8)
  15. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials.meta (8)
  16. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Brick.mat (76)
  17. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Brick.mat.meta (8)
  18. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Gold.mat (76)
  19. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Gold.mat.meta (8)
  20. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/agent.mat (76)
  21. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/agent.mat.meta (8)
  22. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/black.mat (76)
  23. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/black.mat.meta (8)
  24. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/ground.mat (76)
  25. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/ground.mat.meta (8)
  26. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/red.mat (76)
  27. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/red.mat.meta (8)
  28. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/wall.mat (76)
  29. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/wall.mat.meta (8)
  30. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/white.mat (76)
  31. unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/white.mat.meta (8)
  32. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs.meta (8)
  33. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab (1001)
  34. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab.meta (8)
  35. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/BrickPyramid.prefab (1001)
  36. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/BrickPyramid.prefab.meta (8)
  37. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/StonePyramid.prefab (1001)
  38. unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/StonePyramid.prefab.meta (8)
  39. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scenes.meta (8)
  40. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scenes/Pyramids.unity (1001)
  41. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scenes/Pyramids.unity.meta (7)
  42. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts.meta (8)
  43. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAcademy.cs (18)
  44. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAcademy.cs.meta (11)
  45. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs (117)
  46. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs.meta (11)
  47. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs (55)
  48. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs.meta (11)
  49. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs (47)
  50. unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs.meta (11)
  51. unity-environment/Assets/ML-Agents/Examples/Pyramids/TFModels.meta (8)
  52. unity-environment/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.bytes (1001)
  53. unity-environment/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.bytes.meta (7)

15
docs/Learning-Environment-Examples.md


* Vector Action space: (Continuous) Size of 39, corresponding to target rotations applicable to the joints.
* Visual Observations: None
* Reset Parameters: None
## Pyramids
![Pyramids](images/pyramids.png)
* Set-up: Environment where the agent needs to press a button to spawn a pyramid, then navigate to the pyramid, knock it over, and move to the gold brick at the top.
* Goal: Move to the golden brick on top of the spawned pyramid.
* Agents: The environment contains one agent linked to a single brain.
* Agent Reward Function (independent):
* +2 for moving to the golden brick (minus 0.001 per step).
* Brains: One brain with the following observation/action space:
* Vector Observation space: (Continuous) Size of 148, corresponding to local ray-casts detecting the switch, bricks, the golden brick, and walls, plus a variable indicating the switch state.
* Vector Action space: (Discrete) Size of 4, corresponding to agent rotation and forward/backward movement (see the interaction sketch after this list).
* Visual Observations (Optional): First-person view for the agent.
* Reset Parameters: None
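
For reference, here is a minimal interaction sketch for this environment from the Python side. It assumes the `unityagents.UnityEnvironment` API bundled with this release, a built binary named `Pyramids`, and an external brain named `PyramidBrain`; the binary and brain names are illustrative, not taken from this commit.

```python
import numpy as np
from unityagents import UnityEnvironment

# Hypothetical build/brain names; substitute the ones used in your own scene or build.
env = UnityEnvironment(file_name="Pyramids")
brain_name = "PyramidBrain"

info = env.reset(train_mode=True)[brain_name]
done = False
while not done:
    # Vector observation: 148 values (ray-casts plus the switch-state variable) per agent.
    obs = info.vector_observations  # would be fed to a trained policy; here we act randomly
    # Discrete action space of size 4: rotation and forward/backward movement.
    actions = np.random.randint(0, 4, size=len(info.agents))
    info = env.step(actions)[brain_name]
    done = all(info.local_done)

env.close()
```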

5
docs/Training-ML-Agents.md


| batch_size | The number of experiences in each iteration of gradient descent.| PPO, BC |
| batches_per_epoch | In imitation learning, the number of batches of training examples to collect before training the model.| BC |
| beta | The strength of entropy regularization.| PPO, BC |
| brain_to_imitate | For imitation learning, the name of the GameObject containing the Brain component to imitate. | BC |
| brain\_to\_imitate | For imitation learning, the name of the GameObject containing the Brain component to imitate. | BC |
| curiosity\_enc\_size | The size of the encoding to use in the forward and inverse models in the Curiosity module. | PPO |
| curiosity_strength | Magnitude of intrinsic reward generated by Intrinsic Curiosity Module. | PPO |
| epsilon | Influences how rapidly the policy can evolve during training.| PPO, BC |
| gamma | The reward discount rate for the Generalized Advantage Estimator (GAE). | PPO |
| hidden_units | The number of units in the hidden layers of the neural network. | PPO, BC |

| summary_freq | How often, in steps, to save training statistics. This determines the number of data points shown by TensorBoard. | PPO, BC |
| time_horizon | How many steps of experience to collect per-agent before adding it to the experience buffer. | PPO, BC |
| trainer | The type of training to perform: "ppo" or "imitation".| PPO, BC |
| use_curiosity | Train using an additional intrinsic reward signal generated by the Intrinsic Curiosity Module (see the example following this table). | PPO |
| use_recurrent | Train using a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md).| PPO, BC |
|| PPO = Proximal Policy Optimization, BC = Behavioral Cloning (Imitation) ||
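
To show how the curiosity-related keys above combine with the rest of the PPO settings, here is a hedged sketch of the `trainer_parameters` dictionary a PPO trainer could receive once the YAML file is parsed. The values are placeholders drawn from the Pyramids configuration in this change, not general recommendations.

```python
# Illustrative PPO trainer parameters with the curiosity module enabled.
# Keys mirror the table above; values are placeholders.
trainer_parameters = {
    "trainer": "ppo",
    "batch_size": 128,
    "buffer_size": 2048,
    "beta": 1.0e-2,
    "hidden_units": 512,
    "num_layers": 2,
    "time_horizon": 128,
    "max_steps": 2.0e5,
    "summary_freq": 2000,
    "use_recurrent": False,
    # Curiosity-specific settings (PPO only):
    "use_curiosity": True,
    "curiosity_strength": 0.01,
    "curiosity_enc_size": 256,
}
```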

16
docs/Training-PPO.md


Typical Range: `64` - `512`
### (Optional) Intrinsic Curiosity Module Hyperparameters
The hyperparameters below are only used when `use_curiosity` is set to true.
#### Curiosity Encoding Size
`curiosity_enc_size` corresponds to the size of the hidden layer used to encode the observations within the intrinsic curiosity module. This value should be small enough to encourage the curiosity module to compress the original observation, but not so small that it prevents the module from learning the dynamics of the environment.
Typical Range: `64` - `256`
#### Curiosity Strength
`curiosity_strength` corresponds to the magnitude of the intrinsic reward generated by the intrinsic curiosity module. It should be scaled so that it is not overwhelmed by the extrinsic reward signal in the environment, and likewise not so large that it overwhelms the extrinsic reward signal.
Typical Range: `0.001` - `0.1`
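To make the role of `curiosity_strength` concrete, the sketch below mirrors how `create_forward_model` in `python/unitytrainers/ppo/models.py` turns the forward-model prediction error into the intrinsic reward (a NumPy stand-in for the TensorFlow ops; shapes and values are illustrative).

```python
import numpy as np

def intrinsic_reward(encoded_next_state, predicted_next_state, curiosity_strength=0.01):
    """Scaled forward-model prediction error, clipped to [0, 1] as in the curiosity module."""
    # 0.5 * squared error between true and predicted next-state encodings, per transition.
    squared_difference = 0.5 * np.sum((predicted_next_state - encoded_next_state) ** 2, axis=1)
    return np.clip(curiosity_strength * squared_difference, 0.0, 1.0)

# Illustrative batch of 4 transitions with a 128-dimensional state encoding.
rng = np.random.RandomState(0)
true_enc = rng.randn(4, 128)
pred_enc = true_enc + 0.1 * rng.randn(4, 128)  # small prediction error -> small intrinsic reward
print(intrinsic_reward(true_enc, pred_enc))    # larger errors (or a larger strength) raise the reward
```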
## Training Statistics
To view training statistics, use TensorBoard. For information on launching and using TensorBoard, see [here](./Getting-Started-with-Balance-Ball.md#observing-training-progress).

5
docs/Using-Tensorboard.md


well the model is able to predict the value of each state. This should increase
while the agent is learning, and then decrease once the reward stabilizes.
* _(Curiosity-Specific)_ Intrinsic Reward - This corresponds to the mean cumulative intrinsic reward generated per-episode.
* _(Curiosity-Specific)_ Forward Loss - The mean magnitude of the forward model loss function. Corresponds to how well the model is able to predict the new observation encoding.
* _(Curiosity-Specific)_ Inverse Loss - The mean magnitude of the inverse model loss function. Corresponds to how well the model is able to predict the action taken between two observations.

34
python/trainer_config.yaml


sequence_length: 64
summary_freq: 1000
use_recurrent: false
use_curiosity: false
curiosity_strength: 0.01
curiosity_enc_size: 128
BananaBrain:
normalize: false

BouncerBrain:
normalize: true
max_steps: 5.0e5
num_layers: 2
hidden_units: 64
PushBlockBrain:
max_steps: 5.0e4

time_horizon: 128
num_layers: 2
normalize: false
PyramidBrain:
use_curiosity: true
summary_freq: 2000
curiosity_strength: 0.01
curiosity_enc_size: 256
time_horizon: 128
batch_size: 128
buffer_size: 2048
hidden_units: 512
num_layers: 2
beta: 1.0e-2
max_steps: 2.0e5
num_epoch: 3
VisualPyramidBrain:
use_curiosity: true
time_horizon: 128
batch_size: 32
buffer_size: 1024
hidden_units: 256
num_layers: 2
beta: 1.0e-2
max_steps: 5.0e5
num_epoch: 3
Ball3DBrain:
normalize: true

2
python/unitytrainers/bc/models.py


LearningModel.__init__(self, m_size, normalize, use_recurrent, brain)
num_streams = 1
hidden_streams = self.create_new_obs(num_streams, h_size, n_layers)
hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers)
hidden = hidden_streams[0]
self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate")
hidden_reg = tf.layers.dropout(hidden, self.dropout_rate)

179
python/unitytrainers/models.py


self.normalize = normalize
self.use_recurrent = use_recurrent
self.a_size = brain.vector_action_space_size
self.o_size = brain.vector_observation_space_size * brain.num_stacked_vector_observations
self.v_size = brain.number_visual_observations
@staticmethod
def create_global_steps():

return tf.multiply(input_activation, tf.nn.sigmoid(input_activation))
@staticmethod
def create_visual_input(o_size_h, o_size_w, bw, name):
def create_visual_input(camera_parameters, name):
"""
Creates image input op.
:param camera_parameters: Parameters for visual observation from BrainInfo.
:param name: Desired name of input op.
:return: input op.
"""
o_size_h = camera_parameters['height']
o_size_w = camera_parameters['width']
bw = camera_parameters['blackAndWhite']
if bw:
c_channels = 1
else:

return visual_in
def create_vector_input(self, s_size):
def create_vector_input(self, name='vector_observation'):
"""
Creates ops for vector observation input.
:param name: Name of the placeholder op.
:param o_size: Size of stacked vector observation.
:return:
"""
self.vector_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32, name='vector_observation')
self.vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32, name=name)
self.running_mean = tf.get_variable("running_mean", [s_size], trainable=False, dtype=tf.float32,
self.running_mean = tf.get_variable("running_mean", [self.o_size], trainable=False, dtype=tf.float32,
self.running_variance = tf.get_variable("running_variance", [s_size], trainable=False, dtype=tf.float32,
initializer=tf.ones_initializer())
self.new_mean = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_mean')
self.new_variance = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_variance')
self.running_variance = tf.get_variable("running_variance", [self.o_size], trainable=False,
dtype=tf.float32, initializer=tf.ones_initializer())
self.new_mean = tf.placeholder(shape=[self.o_size], dtype=tf.float32, name='new_mean')
self.new_variance = tf.placeholder(shape=[self.o_size], dtype=tf.float32, name='new_variance')
self.update_mean = tf.assign(self.running_mean, self.new_mean)
self.update_variance = tf.assign(self.running_variance, self.new_variance)

return self.normalized_state
self.normalized_state = self.vector_in
return self.vector_in
return self.vector_in
def create_continuous_state_encoder(self, h_size, activation, num_layers):
@staticmethod
def create_continuous_observation_encoder(observation_input, h_size, activation, num_layers, scope, reuse):
:param reuse: Whether to re-use the weights within the same scope.
:param scope: Graph scope for the encoder ops.
:param observation_input: Input vector.
hidden = self.normalized_state
for j in range(num_layers):
hidden = tf.layers.dense(hidden, h_size, activation=activation,
kernel_initializer=c_layers.variance_scaling_initializer(1.0))
with tf.variable_scope(scope):
hidden = observation_input
for i in range(num_layers):
hidden = tf.layers.dense(hidden, h_size, activation=activation, reuse=reuse, name="hidden_{}".format(i),
kernel_initializer=c_layers.variance_scaling_initializer(1.0))
def create_visual_encoder(self, image_input, h_size, activation, num_layers):
def create_visual_observation_encoder(self, image_input, h_size, activation, num_layers, scope, reuse):
:param reuse: Whether to re-use the weights within the same scope.
:param scope: The scope of the graph within which to create the ops.
:param image_input: The placeholder for the image input to use.
:param h_size: Hidden layer size.
:param activation: What type of activation function to use for layers.

conv1 = tf.layers.conv2d(image_input, 16, kernel_size=[8, 8], strides=[4, 4],
activation=tf.nn.elu)
conv2 = tf.layers.conv2d(conv1, 32, kernel_size=[4, 4], strides=[2, 2],
activation=tf.nn.elu)
hidden = c_layers.flatten(conv2)
with tf.variable_scope(scope):
conv1 = tf.layers.conv2d(image_input, 16, kernel_size=[8, 8], strides=[4, 4],
activation=tf.nn.elu, reuse=reuse, name="conv_1")
conv2 = tf.layers.conv2d(conv1, 32, kernel_size=[4, 4], strides=[2, 2],
activation=tf.nn.elu, reuse=reuse, name="conv_2")
hidden = c_layers.flatten(conv2)
for j in range(num_layers):
hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
return hidden
hidden_flat = self.create_continuous_observation_encoder(hidden, h_size, activation, num_layers, scope, reuse)
return hidden_flat
def create_discrete_state_encoder(self, s_size, h_size, activation, num_layers):
@staticmethod
def create_discrete_observation_encoder(observation_input, s_size, h_size, activation,
num_layers, scope, reuse):
:param reuse: Whether to re-use the weights within the same scope.
:param scope: The scope of the graph within which to create the ops.
:param observation_input: Discrete observation.
:param s_size: state input size (discrete).
:param h_size: Hidden layer size.
:param activation: What type of activation function to use for layers.

vector_in = tf.reshape(self.vector_in, [-1])
state_onehot = tf.one_hot(vector_in, s_size)
hidden = state_onehot
for j in range(num_layers):
hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
with tf.name_scope(scope):
vector_in = tf.reshape(observation_input, [-1])
state_onehot = tf.one_hot(vector_in, s_size)
hidden = state_onehot
for i in range(num_layers):
hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation,
reuse=reuse, name="hidden_{}".format(i))
def create_new_obs(self, num_streams, h_size, num_layers):
def create_observation_streams(self, num_streams, h_size, num_layers):
"""
Creates encoding stream for observations.
:param num_streams: Number of streams to create.
:param h_size: Size of hidden linear layers in stream.
:param num_layers: Number of hidden linear layers in stream.
:return: List of encoded streams.
"""
s_size = brain.vector_observation_space_size * brain.num_stacked_vector_observations
if brain.vector_action_space_type == "continuous":
activation_fn = tf.nn.tanh
else:

for i in range(brain.number_visual_observations):
height_size, width_size = brain.camera_resolutions[i]['height'], brain.camera_resolutions[i]['width']
bw = brain.camera_resolutions[i]['blackAndWhite']
visual_input = self.create_visual_input(height_size, width_size, bw, name="visual_observation_" + str(i))
visual_input = self.create_visual_input(brain.camera_resolutions[i], name="visual_observation_" + str(i))
self.create_vector_input(s_size)
vector_observation_input = self.create_vector_input()
if brain.number_visual_observations > 0:
if self.v_size > 0:
encoded_visual = self.create_visual_encoder(self.visual_in[j], h_size, activation_fn, num_layers)
encoded_visual = self.create_visual_observation_encoder(self.visual_in[j], h_size,
activation_fn, num_layers,
"main_graph_{}".format(i), False)
s_size = brain.vector_observation_space_size * brain.num_stacked_vector_observations
hidden_state = self.create_continuous_state_encoder(h_size, activation_fn, num_layers)
hidden_state = self.create_continuous_observation_encoder(vector_observation_input,
h_size, activation_fn, num_layers,
"main_graph_{}".format(i), False)
hidden_state = self.create_discrete_state_encoder(s_size, h_size,
activation_fn, num_layers)
hidden_state = self.create_discrete_observation_encoder(vector_observation_input, self.o_size,
h_size, activation_fn, num_layers,
"main_graph_{}".format(i), False)
if hidden_state is not None and hidden_visual is not None:
final_hidden = tf.concat([hidden_visual, hidden_state], axis=1)
elif hidden_state is None and hidden_visual is not None:

final_hiddens.append(final_hidden)
return final_hiddens
def create_recurrent_encoder(self, input_state, memory_in, name='lstm'):
@staticmethod
def create_recurrent_encoder(input_state, memory_in, sequence_length, name='lstm'):
:param sequence_length: Length of sequence to unroll.
:param input_state: The input tensor to the LSTM cell.
:param memory_in: The input memory to the LSTM cell.
:param name: The scope of the LSTM cell.

lstm_input_state = tf.reshape(input_state, shape=[-1, self.sequence_length, s_size])
lstm_input_state = tf.reshape(input_state, shape=[-1, sequence_length, s_size])
memory_in = tf.reshape(memory_in[:, :], [-1, m_size])
recurrent_state, lstm_state_out = tf.nn.dynamic_rnn(rnn_cell, lstm_input_state,
initial_state=lstm_vector_in,
time_major=False,
dtype=tf.float32)
recurrent_output, lstm_state_out = tf.nn.dynamic_rnn(rnn_cell, lstm_input_state,
initial_state=lstm_vector_in)
recurrent_state = tf.reshape(recurrent_state, shape=[-1, _half_point])
return recurrent_state, tf.concat([lstm_state_out.c, lstm_state_out.h], axis=1)
recurrent_output = tf.reshape(recurrent_output, shape=[-1, _half_point])
return recurrent_output, tf.concat([lstm_state_out.c, lstm_state_out.h], axis=1)
num_streams = 1
hidden_streams = self.create_new_obs(num_streams, h_size, num_layers)
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(1, h_size, num_layers)
self.prev_action_oh = tf.one_hot(self.prev_action, self.a_size)
hidden = tf.concat([hidden, self.prev_action_oh], axis=1)
prev_action_oh = tf.one_hot(self.prev_action, self.a_size)
hidden = tf.concat([hidden, prev_action_oh], axis=1)
hidden, self.memory_out = self.create_recurrent_encoder(hidden, self.memory_in)
self.memory_out = tf.identity(self.memory_out, name='recurrent_out')
hidden, memory_out = self.create_recurrent_encoder(hidden, self.memory_in, self.sequence_length)
self.memory_out = tf.identity(memory_out, name='recurrent_out')
self.output = tf.multinomial(self.policy, 1)
self.output = tf.identity(self.output, name="action")
output = tf.multinomial(self.policy, 1)
self.output = tf.identity(output, name="action")
self.value = tf.layers.dense(hidden, 1, activation=None)
self.value = tf.identity(self.value, name="value_estimate")
value = tf.layers.dense(hidden, 1, activation=None)
self.value = tf.identity(value, name="value_estimate")
self.entropy = -tf.reduce_sum(self.all_probs * tf.log(self.all_probs + 1e-10), axis=1)
self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32)
self.selected_actions = tf.one_hot(self.action_holder, self.a_size)

self.old_probs = tf.expand_dims(tf.reduce_sum(self.all_old_probs * self.selected_actions, axis=1), 1)
def create_cc_actor_critic(self, h_size, num_layers):
num_streams = 2
hidden_streams = self.create_new_obs(num_streams, h_size, num_layers)
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(2, h_size, num_layers)
if self.use_recurrent:
tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)

hidden_streams[0], self.memory_in[:, :_half_point], name='lstm_policy')
hidden_streams[0], self.memory_in[:, :_half_point], self.sequence_length, name='lstm_policy')
hidden_streams[1], self.memory_in[:, _half_point:], name='lstm_value')
hidden_streams[1], self.memory_in[:, _half_point:], self.sequence_length, name='lstm_value')
self.memory_out = tf.concat([memory_policy_out, memory_value_out], axis=1, name='recurrent_out')
else:
hidden_policy = hidden_streams[0]

self.output_pre = mu + tf.sqrt(sigma_sq) * epsilon
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name='action')
self.selected_actions = tf.stop_gradient(output_post)
# Compute probability of model output.
a = tf.exp(-1 * tf.pow(tf.stop_gradient(self.output_pre) - mu, 2) / (2 * sigma_sq))

112
python/unitytrainers/ppo/models.py


class PPOModel(LearningModel):
def __init__(self, brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6,
normalize=False, use_recurrent=False, num_layers=2, m_size=None):
normalize=False, use_recurrent=False, num_layers=2, m_size=None, use_curiosity=False,
curiosity_strength=0.01, curiosity_enc_size=128):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.

:param m_size: Size of brain memory.
"""
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain)
self.use_curiosity = use_curiosity
if num_layers < 1:
num_layers = 1
self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder()

else:
self.create_dc_actor_critic(h_size, num_layers)
if self.use_curiosity:
self.curiosity_enc_size = curiosity_enc_size
self.curiosity_strength = curiosity_strength
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_ppo_optimizer(self.probs, self.old_probs, self.value,
self.entropy, beta, epsilon, lr, max_step)

update_reward = tf.assign(last_reward, new_reward)
return last_reward, new_reward, update_reward
def create_curiosity_encoders(self):
"""
Creates state encoders for current and future observations.
Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction
See https://arxiv.org/abs/1705.05363 for more details.
:return: current and future state encoder tensors.
"""
encoded_state_list = []
encoded_next_state_list = []
if self.v_size > 0:
self.next_visual_in = []
visual_encoders = []
next_visual_encoders = []
for i in range(self.v_size):
# Create input ops for next (t+1) visual observations.
next_visual_input = self.create_visual_input(self.brain.camera_resolutions[i],
name="next_visual_observation_" + str(i))
self.next_visual_in.append(next_visual_input)
# Create the encoder ops for current and next visual input. Note that these encoders are siamese.
encoded_visual = self.create_visual_observation_encoder(self.visual_in[i], self.curiosity_enc_size,
self.swish, 1, "visual_obs_encoder", False)
encoded_next_visual = self.create_visual_observation_encoder(self.next_visual_in[i], self.curiosity_enc_size,
self.swish, 1, "visual_obs_encoder", True)
visual_encoders.append(encoded_visual)
next_visual_encoders.append(encoded_next_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
hidden_next_visual = tf.concat(next_visual_encoders, axis=1)
encoded_state_list.append(hidden_visual)
encoded_next_state_list.append(hidden_next_visual)
if self.o_size > 0:
# Create input op for next (t+1) vector observation.
self.next_vector_obs = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32,
name='next_vector_observation')
# Create the encoder ops for current and next vector input. Note that these encoders are siamese.
encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in, self.curiosity_enc_size,
self.swish, 2, "vector_obs_encoder", False)
encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_obs,
self.curiosity_enc_size, self.swish,
2, "vector_obs_encoder", True)
encoded_state_list.append(encoded_vector_obs)
encoded_next_state_list.append(encoded_next_vector_obs)
encoded_state = tf.concat(encoded_state_list, axis=1)
encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
return encoded_state, encoded_next_state
def create_inverse_model(self, encoded_state, encoded_next_state):
"""
Creates inverse model TensorFlow ops for Curiosity module.
Predicts action taken given current and future encoded states.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=self.swish)
if self.brain.vector_action_space_type == "continuous":
pred_action = tf.layers.dense(hidden, self.a_size, activation=None)
squared_difference = tf.reduce_sum(tf.squared_difference(pred_action, self.selected_actions), axis=1)
self.inverse_loss = tf.reduce_mean(squared_difference)
else:
pred_action = tf.layers.dense(hidden, self.a_size, activation=tf.nn.softmax)
cross_entropy = tf.reduce_sum(-tf.log(pred_action + 1e-10) * self.selected_actions, axis=1)
self.inverse_loss = tf.reduce_mean(cross_entropy)
def create_forward_model(self, encoded_state, encoded_next_state):
"""
Creates forward model TensorFlow ops for Curiosity module.
Predicts encoded future state based on encoded current state and given action.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, self.selected_actions], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=self.swish)
pred_next_state = tf.layers.dense(hidden, self.curiosity_enc_size, activation=None)
squared_difference = 0.5 * tf.reduce_sum(tf.squared_difference(pred_next_state, encoded_next_state), axis=1)
self.intrinsic_reward = tf.clip_by_value(self.curiosity_strength * squared_difference, 0, 1)
self.forward_loss = tf.reduce_mean(squared_difference)
def create_ppo_optimizer(self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step):
"""
Creates training-specific Tensorflow ops for PPO models.

:param lr: Learning rate
:param max_step: Total number of training steps.
"""
self.returns_holder = tf.placeholder(shape=[None], dtype=tf.float32, name='discounted_rewards')
self.advantage = tf.placeholder(shape=[None, 1], dtype=tf.float32, name='advantages')
self.learning_rate = tf.train.polynomial_decay(lr, self.global_step, max_step, 1e-10, power=1.0)

decay_beta = tf.train.polynomial_decay(beta, self.global_step, max_step, 1e-5, power=1.0)
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.mask = tf.equal(self.mask_input, 1.0)
mask = tf.equal(self.mask_input, 1.0)
clipped_value_estimate = self.old_value + tf.clip_by_value(tf.reduce_sum(value, axis=1) - self.old_value,
- decay_epsilon, decay_epsilon)

self.value_loss = tf.reduce_mean(tf.boolean_mask(tf.maximum(v_opt_a, v_opt_b), self.mask))
self.value_loss = tf.reduce_mean(tf.boolean_mask(tf.maximum(v_opt_a, v_opt_b), mask))
self.r_theta = probs / (old_probs + 1e-10)
self.p_opt_a = self.r_theta * self.advantage
self.p_opt_b = tf.clip_by_value(self.r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * self.advantage
self.policy_loss = -tf.reduce_mean(tf.boolean_mask(tf.minimum(self.p_opt_a, self.p_opt_b), self.mask))
r_theta = probs / (old_probs + 1e-10)
p_opt_a = r_theta * self.advantage
p_opt_b = tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * self.advantage
self.policy_loss = -tf.reduce_mean(tf.boolean_mask(tf.minimum(p_opt_a, p_opt_b), mask))
tf.boolean_mask(entropy, self.mask))
tf.boolean_mask(entropy, mask))
if self.use_curiosity:
self.loss += 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss)
self.update_batch = optimizer.minimize(self.loss)

160
python/unitytrainers/ppo/trainer.py


class PPOTrainer(Trainer):
"""The PPOTrainer is an implementation of the PPO algorythm."""
"""The PPOTrainer is an implementation of the PPO algorithm."""
def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
"""

:param training: Whether the trainer is set for training.
"""
self.param_keys = ['batch_size', 'beta', 'buffer_size', 'epsilon', 'gamma', 'hidden_units', 'lambd',
'learning_rate',
'max_steps', 'normalize', 'num_epoch', 'num_layers', 'time_horizon', 'sequence_length',
'summary_freq',
'use_recurrent', 'graph_scope', 'summary_path', 'memory_size']
'learning_rate', 'max_steps', 'normalize', 'num_epoch', 'num_layers',
'time_horizon', 'sequence_length', 'summary_freq', 'use_recurrent',
'graph_scope', 'summary_path', 'memory_size', 'use_curiosity', 'curiosity_strength',
'curiosity_enc_size']
for k in self.param_keys:
if k not in trainer_parameters:

super(PPOTrainer, self).__init__(sess, env, brain_name, trainer_parameters, training)
self.use_recurrent = trainer_parameters["use_recurrent"]
self.use_curiosity = bool(trainer_parameters['use_curiosity'])
self.has_updated = False
self.m_size = None
if self.use_recurrent:
self.m_size = trainer_parameters["memory_size"]

normalize=trainer_parameters['normalize'],
use_recurrent=trainer_parameters['use_recurrent'],
num_layers=int(trainer_parameters['num_layers']),
m_size=self.m_size)
m_size=self.m_size,
use_curiosity=bool(trainer_parameters['use_curiosity']),
curiosity_strength=float(trainer_parameters['curiosity_strength']),
curiosity_enc_size=float(trainer_parameters['curiosity_enc_size']))
if self.use_curiosity:
stats['forward_loss'] = []
stats['inverse_loss'] = []
stats['intrinsic_reward'] = []
self.intrinsic_rewards = {}
self.stats = stats
self.training_buffer = Buffer()

self.is_continuous_observation = (env.brains[brain_name].vector_observation_space_type == "continuous")
self.use_observations = (env.brains[brain_name].number_visual_observations > 0)
self.use_states = (env.brains[brain_name].vector_observation_space_size > 0)
self.use_visual_obs = (env.brains[brain_name].number_visual_observations > 0)
self.use_vector_obs = (env.brains[brain_name].vector_observation_space_size > 0)
self.summary_path = trainer_parameters['summary_path']
if not os.path.exists(self.summary_path):
os.makedirs(self.summary_path)

def take_action(self, all_brain_info: AllBrainInfo):
"""
Decides actions given state/observation information, and takes them in environment.
Decides actions given observation information, and takes them in the environment.
:param all_brain_info: A dictionary of brain names and BrainInfo from environment.
:return: a tuple containing action, memories, values and an object
to be passed to add experiences

run_list.append(self.model.output_pre)
if self.use_recurrent:
feed_dict[self.model.prev_action] = np.reshape(curr_brain_info.previous_vector_actions, [-1])
if self.use_observations:
for i, _ in enumerate(curr_brain_info.visual_observations):
feed_dict[self.model.visual_in[i]] = curr_brain_info.visual_observations[i]
if self.use_states:
feed_dict[self.model.vector_in] = curr_brain_info.vector_observations
if self.use_recurrent:
if self.use_visual_obs:
for i, _ in enumerate(curr_brain_info.visual_observations):
feed_dict[self.model.visual_in[i]] = curr_brain_info.visual_observations[i]
if self.use_vector_obs:
feed_dict[self.model.vector_in] = curr_brain_info.vector_observations
self.use_states and self.trainer_parameters['normalize']):
self.use_vector_obs and self.trainer_parameters['normalize']):
new_mean, new_variance = self.running_average(
curr_brain_info.vector_observations, steps, self.model.running_mean, self.model.running_variance)
feed_dict[self.model.new_mean] = new_mean

curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]
intrinsic_rewards = np.array([])
if self.use_curiosity:
feed_dict = {self.model.batch_size: len(curr_info.vector_observations), self.model.sequence_length: 1}
run_list = [self.model.intrinsic_reward]
if self.is_continuous_action:
run_list.append(self.model.output)
else:
feed_dict[self.model.action_holder] = np.reshape(take_action_outputs[self.model.output], [-1])
if self.use_visual_obs:
for i, _ in enumerate(curr_info.visual_observations):
feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
if self.use_vector_obs:
feed_dict[self.model.vector_in] = curr_info.vector_observations
feed_dict[self.model.next_vector_obs] = next_info.vector_observations
if self.use_recurrent:
feed_dict[self.model.prev_action] = np.reshape(curr_info.previous_vector_actions, [-1])
if curr_info.memories.shape[1] == 0:
curr_info.memories = np.zeros((len(curr_info.agents), self.m_size))
feed_dict[self.model.memory_in] = curr_info.memories
run_list += [self.model.memory_out]
intrinsic_rewards = self.sess.run(self.model.intrinsic_reward, feed_dict=feed_dict) * \
float(self.has_updated)
for agent_id in curr_info.agents:
self.training_buffer[agent_id].last_brain_info = curr_info
self.training_buffer[agent_id].last_take_action_outputs = take_action_outputs

idx = stored_info.agents.index(agent_id)
next_idx = next_info.agents.index(agent_id)
if not stored_info.local_done[idx]:
if self.use_observations:
if self.use_visual_obs:
self.training_buffer[agent_id]['observations%d' % i].append(stored_info.visual_observations[i][idx])
if self.use_states:
self.training_buffer[agent_id]['states'].append(stored_info.vector_observations[idx])
self.training_buffer[agent_id]['visual_obs%d' % i].append(
stored_info.visual_observations[i][idx])
self.training_buffer[agent_id]['next_visual_obs%d' % i].append(
next_info.visual_observations[i][idx])
if self.use_vector_obs:
self.training_buffer[agent_id]['vector_obs'].append(stored_info.vector_observations[idx])
self.training_buffer[agent_id]['next_vector_obs'].append(
next_info.vector_observations[next_idx])
if self.use_recurrent:
if stored_info.memories.shape[1] == 0:
stored_info.memories = np.zeros((len(stored_info.agents), self.m_size))

actions_pre = stored_take_action_outputs[self.model.output_pre]
self.training_buffer[agent_id]['actions_pre'].append(actions_pre[idx])
if self.is_continuous_action:
self.training_buffer[agent_id]['actions_pre'].append(actions_pre[idx])
self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx])
if self.use_curiosity:
self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx] +
intrinsic_rewards[next_idx])
else:
self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx])
if self.use_curiosity:
if agent_id not in self.intrinsic_rewards:
self.intrinsic_rewards[agent_id] = 0
self.intrinsic_rewards[agent_id] += intrinsic_rewards[next_idx]
if not next_info.local_done[next_idx]:
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0

for l in range(len(info.agents)):
agent_actions = self.training_buffer[info.agents[l]]['actions']
if ((info.local_done[l] or len(agent_actions) > self.trainer_parameters['time_horizon'])
and len(agent_actions) > 0):
and len(agent_actions) > 0):
agent_id = info.agents[l]
if info.local_done[l] and not info.max_reached[l]:
value_next = 0.0

else:
bootstrapping_info = info
idx = l
feed_dict = {self.model.batch_size: len(bootstrapping_info.vector_observations), self.model.sequence_length: 1}
if self.use_observations:
feed_dict = {self.model.batch_size: len(bootstrapping_info.vector_observations),
self.model.sequence_length: 1}
if self.use_visual_obs:
if self.use_states:
if self.use_vector_obs:
bootstrapping_info.memories = np.zeros((len(bootstrapping_info.vector_observations), self.m_size))
bootstrapping_info.memories = np.zeros(
(len(bootstrapping_info.vector_observations), self.m_size))
feed_dict[self.model.memory_in] = bootstrapping_info.memories
if not self.is_continuous_action and self.use_recurrent:
feed_dict[self.model.prev_action] = np.reshape(bootstrapping_info.previous_vector_actions, [-1])

self.stats['episode_length'].append(self.episode_steps[agent_id])
self.cumulative_rewards[agent_id] = 0
self.episode_steps[agent_id] = 0
if self.use_curiosity:
self.stats['intrinsic_reward'].append(self.intrinsic_rewards[agent_id])
self.intrinsic_rewards[agent_id] = 0
def end_episode(self):
"""

self.cumulative_rewards[agent_id] = 0
for agent_id in self.episode_steps:
self.episode_steps[agent_id] = 0
if self.use_curiosity:
for agent_id in self.intrinsic_rewards:
self.intrinsic_rewards[agent_id] = 0
def is_ready_update(self):
"""

"""
num_epoch = self.trainer_parameters['num_epoch']
n_sequences = max(int(self.trainer_parameters['batch_size'] / self.sequence_length), 1)
total_v, total_p = [], []
value_total, policy_total, forward_total, inverse_total = [], [], [], []
advantages = self.training_buffer.update_buffer['advantages'].get_batch()
self.training_buffer.update_buffer['advantages'].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10))

if self.use_recurrent:
feed_dict[self.model.prev_action] = np.array(
_buffer['prev_action'][start:end]).reshape([-1])
if self.use_states:
if self.use_vector_obs:
_buffer['states'][start:end]).reshape(
_buffer['vector_obs'][start:end]).reshape(
if self.use_curiosity:
feed_dict[self.model.next_vector_obs] = np.array(
_buffer['next_vector_obs'][start:end]).reshape(
[-1,
self.brain.vector_observation_space_size * self.brain.num_stacked_vector_observations])
_buffer['states'][start:end]).reshape([-1, self.brain.num_stacked_vector_observations])
if self.use_observations:
_buffer['vector_obs'][start:end]).reshape([-1, self.brain.num_stacked_vector_observations])
if self.use_visual_obs:
_obs = np.array(_buffer['observations%d' % i][start:end])
_obs = np.array(_buffer['visual_obs%d' % i][start:end])
if self.use_curiosity:
for i, _ in enumerate(self.model.visual_in):
_obs = np.array(_buffer['next_visual_obs%d' % i][start:end])
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
feed_dict[self.model.memory_in] = np.array(_buffer['memory'][start:end])[:, 0, :]
v_loss, p_loss, _ = self.sess.run(
[self.model.value_loss, self.model.policy_loss,
self.model.update_batch], feed_dict=feed_dict)
total_v.append(v_loss)
total_p.append(np.abs(p_loss))
self.stats['value_loss'].append(np.mean(total_v))
self.stats['policy_loss'].append(np.mean(total_p))
mem_in = np.array(_buffer['memory'][start:end])[:, 0, :]
feed_dict[self.model.memory_in] = mem_in
run_list = [self.model.value_loss, self.model.policy_loss, self.model.update_batch]
if self.use_curiosity:
run_list.extend([self.model.forward_loss, self.model.inverse_loss])
values = self.sess.run(run_list, feed_dict=feed_dict)
self.has_updated = True
run_out = dict(zip(run_list, values))
value_total.append(run_out[self.model.value_loss])
policy_total.append(np.abs(run_out[self.model.policy_loss]))
if self.use_curiosity:
inverse_total.append(run_out[self.model.inverse_loss])
forward_total.append(run_out[self.model.forward_loss])
self.stats['value_loss'].append(np.mean(value_total))
self.stats['policy_loss'].append(np.mean(policy_total))
if self.use_curiosity:
self.stats['forward_loss'].append(np.mean(forward_total))
self.stats['inverse_loss'].append(np.mean(inverse_total))
self.training_buffer.reset_update_buffer()
def discount_rewards(r, gamma=0.99, value_next=0.0):

3
python/unitytrainers/trainer_controller.py


take_action_outputs[brain_name]) = trainer.take_action(curr_info)
new_info = self.env.step(vector_action=take_action_vector, memory=take_action_memories,
text_action=take_action_text)
for brain_name, trainer in self.trainers.items():
trainer.add_experiences(curr_info, new_info, take_action_outputs[brain_name])
trainer.process_experiences(curr_info, new_info)

curr_info = new_info
# Final save Tensorflow model
if global_step != 0 and self.train_model:
self._save_model(sess, steps=global_step, saver=saver)
self._save_model(sess, steps=global_step, saver=saver)
except KeyboardInterrupt:
if self.train_model:
self.logger.info("Learning was interrupted. Please wait while the graph is generated.")

10
unity-environment/Assets/ML-Agents/Scripts/Agent.cs


info.vectorObservation.Add(observation.z);
info.vectorObservation.Add(observation.w);
}
/// <summary>
/// Adds a boolean observation to the vector observation of the agent.
/// Increases the size of the agent's vector observation by 1.
/// </summary>
/// <param name="observation"></param>
protected void AddVectorObs(bool observation)
{
info.vectorObservation.Add(observation ? 1f : 0f);
}
/// <summary>
/// Sets the text observation.

4
unity-environment/ProjectSettings/TagManager.asset


- orangeBlock
- block
- orangeGoal
- switchOff
- pyramid
- switchOn
- stone
layers:
- Default
- TransparentFX

996
docs/images/pyramids.png

Width: 1652  |  Height: 914  |  Size: 283 KiB

8
unity-environment/Assets/ML-Agents/Examples/Pyramids.meta


fileFormatVersion: 2
guid: d970a35d94c53437b9ebc56130744a23
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials.meta


fileFormatVersion: 2
guid: e21d506a8dc40465eae48bae17b75e49
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Brick.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: Brick
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.5660378, g: 0.37924042, b: 0.18956926, a: 1}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Brick.mat.meta


fileFormatVersion: 2
guid: 3061b37049eb04ddfa822f606369919d
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 2100000
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Gold.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: Gold
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords: _EMISSION
m_LightmapFlags: 1
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0.5
- _GlossyReflections: 1
- _Metallic: 0.926
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.5660378, g: 0.5660378, b: 0.5660378, a: 1}
- _EmissionColor: {r: 1, g: 0.505618, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/Gold.mat.meta


fileFormatVersion: 2
guid: 860625a788df041aeb5c35413765e9df
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 2100000
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/agent.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: agent
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0.5
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.24345203, g: 0.4278206, b: 0.503, a: 1}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/agent.mat.meta


fileFormatVersion: 2
guid: 7ec5b27785bab4f37a9ae5faa93c92b7
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 0
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/black.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: black
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.19852942, g: 0.19852942, b: 0.19852942, a: 1}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/black.mat.meta


fileFormatVersion: 2
guid: 45735b9be79ab49b887c5f09cbb914b9
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 0
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/ground.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: ground
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0.2
- _GlossyReflections: 1
- _Metallic: 0.2
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.49056602, g: 0.4321629, b: 0.3910644, a: 1}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/ground.mat.meta


fileFormatVersion: 2
guid: 7f92b7019e6e5485fa6201e1db7c5658
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 0
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/red.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: red
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.72794116, g: 0.35326555, b: 0.35326555, a: 1}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/red.mat.meta


fileFormatVersion: 2
guid: fee4f54a87e3a4da494cc22082809bb4
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 0
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/wall.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: wall
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0.5
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.44705883, g: 0.4509804, b: 0.4627451, a: 0}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/wall.mat.meta


fileFormatVersion: 2
guid: 816cdc1b97eae4fe1bd0e092e5f7ed04
NativeFormatImporter:
externalObjects: {}
mainObjectFileID: 0
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/white.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
  serializedVersion: 6
  m_ObjectHideFlags: 0
  m_PrefabParentObject: {fileID: 0}
  m_PrefabInternal: {fileID: 0}
  m_Name: white
  m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
  m_ShaderKeywords:
  m_LightmapFlags: 4
  m_EnableInstancingVariants: 0
  m_DoubleSidedGI: 0
  m_CustomRenderQueue: -1
  stringTagMap: {}
  disabledShaderPasses: []
  m_SavedProperties:
    serializedVersion: 3
    m_TexEnvs:
    - _BumpMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _DetailAlbedoMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _DetailMask:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _DetailNormalMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _EmissionMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _MainTex:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _MetallicGlossMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _OcclusionMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    - _ParallaxMap:
        m_Texture: {fileID: 0}
        m_Scale: {x: 1, y: 1}
        m_Offset: {x: 0, y: 0}
    m_Floats:
    - _BumpScale: 1
    - _Cutoff: 0.5
    - _DetailNormalMapScale: 1
    - _DstBlend: 0
    - _GlossMapScale: 1
    - _Glossiness: 0
    - _GlossyReflections: 1
    - _Metallic: 0
    - _Mode: 0
    - _OcclusionStrength: 1
    - _Parallax: 0.02
    - _SmoothnessTextureChannel: 0
    - _SpecularHighlights: 1
    - _SrcBlend: 1
    - _UVSec: 0
    - _ZWrite: 1
    m_Colors:
    - _Color: {r: 1, g: 1, b: 1, a: 1}
    - _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Materials/white.mat.meta


fileFormatVersion: 2
guid: 83456930795894bb5b51f4e8a620bc8b
NativeFormatImporter:
  externalObjects: {}
  mainObjectFileID: 0
  userData:
  assetBundleName:
  assetBundleVariant:

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs.meta


fileFormatVersion: 2
guid: 3ce93d04f41114481ac56aefa2c93bb2
folderAsset: yes
DefaultImporter:
  externalObjects: {}
  userData:
  assetBundleName:
  assetBundleVariant:

1001
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab
File diff suppressed because it is too large

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/AreaPB.prefab.meta


fileFormatVersion: 2
guid: bd804431e808a492bb5658bcd296e58e
NativeFormatImporter:
  externalObjects: {}
  mainObjectFileID: 100100000
  userData:
  assetBundleName:
  assetBundleVariant:

1001
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/BrickPyramid.prefab
File diff suppressed because it is too large

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/BrickPyramid.prefab.meta


fileFormatVersion: 2
guid: 8be2b3870e2cd4ad8bbf080059b2a132
NativeFormatImporter:
  externalObjects: {}
  mainObjectFileID: 100100000
  userData:
  assetBundleName:
  assetBundleVariant:

1001
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/StonePyramid.prefab
File diff suppressed because it is too large

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Prefabs/StonePyramid.prefab.meta


fileFormatVersion: 2
guid: 41512dd84b60643ceb3855fcf9d7d318
NativeFormatImporter:
  externalObjects: {}
  mainObjectFileID: 100100000
  userData:
  assetBundleName:
  assetBundleVariant:

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scenes.meta


fileFormatVersion: 2
guid: f391d733a889c4e2f92d1cfdf0912976
folderAsset: yes
DefaultImporter:
  externalObjects: {}
  userData:
  assetBundleName:
  assetBundleVariant:

1001
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scenes/Pyramids.unity
File diff suppressed because it is too large

7
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scenes/Pyramids.unity.meta


fileFormatVersion: 2
guid: 35afbc150a44b4aa69ca04685486b5c4
DefaultImporter:
  externalObjects: {}
  userData:
  assetBundleName:
  assetBundleVariant:

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts.meta


fileFormatVersion: 2
guid: 4c6b273d7fcab4956958a9049c2a850c
folderAsset: yes
DefaultImporter:
  externalObjects: {}
  userData:
  assetBundleName:
  assetBundleVariant:

18
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAcademy.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;

public class PyramidAcademy : Academy
{
    public override void AcademyReset()
    {
    }

    public override void AcademyStep()
    {
    }
}

11
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAcademy.cs.meta


fileFormatVersion: 2
guid: dba8df9c8b16946dc88d331a301d0ab3
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

117
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs


using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Random = UnityEngine.Random;

public class PyramidAgent : Agent
{
    public GameObject area;
    private PyramidArea myArea;
    private Rigidbody agentRb;
    private RayPerception rayPer;
    private PyramidSwitch switchLogic;
    public GameObject areaSwitch;
    public bool useVectorObs;

    public override void InitializeAgent()
    {
        base.InitializeAgent();
        agentRb = GetComponent<Rigidbody>();
        myArea = area.GetComponent<PyramidArea>();
        rayPer = GetComponent<RayPerception>();
        switchLogic = areaSwitch.GetComponent<PyramidSwitch>();
    }
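
    // Observations: three sets of ray casts against the detectable object tags,
    // plus the switch state and the agent's local-space velocity.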
    public override void CollectObservations()
    {
        if (useVectorObs)
        {
            const float rayDistance = 35f;
            float[] rayAngles = {20f, 90f, 160f, 45f, 135f, 70f, 110f};
            float[] rayAngles1 = {25f, 95f, 165f, 50f, 140f, 75f, 115f};
            float[] rayAngles2 = {15f, 85f, 155f, 40f, 130f, 65f, 105f};
            string[] detectableObjects = {"block", "wall", "goal", "switchOff", "switchOn", "stone"};
            AddVectorObs(rayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));
            AddVectorObs(rayPer.Perceive(rayDistance, rayAngles1, detectableObjects, 0f, 5f));
            AddVectorObs(rayPer.Perceive(rayDistance, rayAngles2, detectableObjects, 0f, 10f));
            AddVectorObs(switchLogic.GetState());
            AddVectorObs(transform.InverseTransformDirection(agentRb.velocity));
        }
    }
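
    // Converts the action vector into movement: forward/backward force and rotation
    // around the up axis, for both continuous and discrete action spaces.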
    public void MoveAgent(float[] act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;
        if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
        {
            dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
            rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
        }
        else
        {
            var action = Mathf.FloorToInt(act[0]);
            switch (action)
            {
                case 0:
                    dirToGo = transform.forward * 1f;
                    break;
                case 1:
                    dirToGo = transform.forward * -1f;
                    break;
                case 2:
                    rotateDir = transform.up * 1f;
                    break;
                case 3:
                    rotateDir = transform.up * -1f;
                    break;
            }
        }
        transform.Rotate(rotateDir, Time.deltaTime * 200f);
        agentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
    }
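
    // A small per-step penalty (-1 / maxStep) encourages reaching the goal quickly.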
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        AddReward(-1f / agentParameters.maxStep);
        MoveAgent(vectorAction);
    }
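
    // Shuffles the nine spawn areas, then uses them to place the agent, the switch,
    // and six stone pyramids for the new episode.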
    public override void AgentReset()
    {
        var enumerable = Enumerable.Range(0, 9).OrderBy(x => Guid.NewGuid()).Take(9);
        var items = enumerable.ToArray();
        myArea.CleanPyramidArea();
        agentRb.velocity = Vector3.zero;
        myArea.PlaceObject(gameObject, items[0]);
        transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
        switchLogic.ResetSwitch(items[1], items[2]);
        myArea.CreateStonePyramid(1, items[3]);
        myArea.CreateStonePyramid(1, items[4]);
        myArea.CreateStonePyramid(1, items[5]);
        myArea.CreateStonePyramid(1, items[6]);
        myArea.CreateStonePyramid(1, items[7]);
        myArea.CreateStonePyramid(1, items[8]);
    }
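
    // Touching an object tagged "goal" ends the episode with a reward of +2.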
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.CompareTag("goal"))
        {
            SetReward(2f);
            Done();
        }
    }

    public override void AgentOnDone()
    {
    }
}

11
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs.meta


fileFormatVersion: 2
guid: b8db44472779248d3be46895c4d562d5
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

55
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class PyramidArea : Area
{
    public GameObject pyramid;
    public GameObject stonePyramid;
    public GameObject[] spawnAreas;
    public int numPyra;
    public float range;

    public void CreatePyramid(int numObjects, int spawnAreaIndex)
    {
        CreateObject(numObjects, pyramid, spawnAreaIndex);
    }

    public void CreateStonePyramid(int numObjects, int spawnAreaIndex)
    {
        CreateObject(numObjects, stonePyramid, spawnAreaIndex);
    }

    private void CreateObject(int numObjects, GameObject desiredObject, int spawnAreaIndex)
    {
        for (var i = 0; i < numObjects; i++)
        {
            var newObject = Instantiate(desiredObject, Vector3.zero,
                Quaternion.Euler(0f, 0f, 0f), transform);
            PlaceObject(newObject, spawnAreaIndex);
        }
    }
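
    // Positions the object at a random point within the chosen spawn area's footprint,
    // 2 units above the area's origin.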
    public void PlaceObject(GameObject objectToPlace, int spawnAreaIndex)
    {
        var spawnTransform = spawnAreas[spawnAreaIndex].transform;
        var xRange = spawnTransform.localScale.x / 2.1f;
        var zRange = spawnTransform.localScale.z / 2.1f;
        objectToPlace.transform.position = new Vector3(Random.Range(-xRange, xRange), 2f, Random.Range(-zRange, zRange))
                                           + spawnTransform.position;
    }
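
    // Destroys all previously spawned pyramids (children tagged "pyramid") before the next episode.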
    public void CleanPyramidArea()
    {
        foreach (Transform child in transform)
        {
            if (child.CompareTag("pyramid"))
            {
                Destroy(child.gameObject);
            }
        }
    }

    public override void ResetArea()
    {
    }
}

11
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs.meta


fileFormatVersion: 2
guid: e048de15d0b8a4643a75c2b09981792e
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

47
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class PyramidSwitch : MonoBehaviour
{
    public Material onMaterial;
    public Material offMaterial;
    public GameObject myButton;
    private bool state;
    private GameObject area;
    private PyramidArea areaComponent;
    private int pyramidIndex;

    public bool GetState()
    {
        return state;
    }

    private void Start()
    {
        area = gameObject.transform.parent.gameObject;
        areaComponent = area.GetComponent<PyramidArea>();
    }
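
    // Re-places the switch, resets it to the "off" state and material, and records
    // which spawn area should receive the pyramid created once the switch is pressed.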
    public void ResetSwitch(int spawnAreaIndex, int pyramidSpawnIndex)
    {
        areaComponent.PlaceObject(gameObject, spawnAreaIndex);
        state = false;
        pyramidIndex = pyramidSpawnIndex;
        tag = "switchOff";
        transform.rotation = Quaternion.Euler(0f, 0f, 0f);
        myButton.GetComponent<Renderer>().material = offMaterial;
    }
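
    // The first time the agent touches the switch, flip it on and spawn the pyramid
    // at the recorded spawn index.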
    private void OnCollisionEnter(Collision other)
    {
        if (other.gameObject.CompareTag("agent") && state == false)
        {
            myButton.GetComponent<Renderer>().material = onMaterial;
            state = true;
            areaComponent.CreatePyramid(1, pyramidIndex);
            tag = "switchOn";
        }
    }
}

11
unity-environment/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs.meta


fileFormatVersion: 2
guid: abd01d977612744528db278c446e9a11
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

8
unity-environment/Assets/ML-Agents/Examples/Pyramids/TFModels.meta


fileFormatVersion: 2
guid: c577914cc4ace45baa8c4dd54778ae00
folderAsset: yes
DefaultImporter:
  externalObjects: {}
  userData:
  assetBundleName:
  assetBundleVariant:

1001
unity-environment/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.bytes
File diff suppressed because it is too large

7
unity-environment/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.bytes.meta


fileFormatVersion: 2
guid: 97f59608051e548d9a79803894260d13
TextScriptImporter:
  externalObjects: {}
  userData:
  assetBundleName:
  assetBundleVariant: