
Add Seeding, MaxStepReached, and Bootstrapping fix (#303)

* Add ability to seed learning (numpy, tensorflow, and Unity) with `--seed` flag.
* Add `maxStepReached` flag to Agents and Academy.
* Change the way value bootstrapping works in PPO to take advantage of timeouts (sketched below).
* Default size of GridWorld changed to 5x5 in order to validate bootstrapping changes.
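The bootstrapping change is the subtle one: when an episode ends only because the step limit was hit, the trajectory should still be bootstrapped from the critic's value estimate of the final state, and only a genuine terminal state should contribute zero future value. The snippet below is a minimal sketch of that idea under assumed inputs (`rewards`, `value_estimate`, and the two flags are hypothetical); the actual change is in `python/trainers/ppo_trainer.py` further down.

```python
# Minimal sketch of timeout-aware bootstrapping; not the trainer's exact code.
import numpy as np

def discounted_returns(rewards, value_estimate, done, max_step_reached, gamma=0.99):
    """Discounted returns for one agent's trajectory segment."""
    if done and not max_step_reached:
        value_next = 0.0             # genuine terminal: no future reward
    else:
        value_next = value_estimate  # timeout / horizon cut: bootstrap from the critic
    returns = np.zeros(len(rewards))
    running = value_next
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```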
/develop-generalizationTraining-TrainerController
GitHub · 7 years ago
Current commit
36d58cee
20 files changed, with 1272 insertions and 177 deletions
  1. python/learn.py (23 changes)
  2. python/trainer_configurations.yaml (27 changes)
  3. python/trainers/buffer.py (109 changes)
  4. python/trainers/ghost_trainer.py (2 changes)
  5. python/trainers/imitation_trainer.py (5 changes)
  6. python/trainers/ppo_models.py (7 changes)
  7. python/trainers/ppo_trainer.py (13 changes)
  8. python/unityagents/brain.py (4 changes)
  9. python/unityagents/environment.py (13 changes)
  10. unity-environment/Assets/ML-Agents/Examples/Banana/Scripts/BananaAgent.cs (9 changes)
  11. unity-environment/Assets/ML-Agents/Examples/GridWorld/GridWorld.unity (132 changes)
  12. unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (24 changes)
  13. unity-environment/Assets/ML-Agents/Scripts/Academy.cs (8 changes)
  14. unity-environment/Assets/ML-Agents/Scripts/Agent.cs (5 changes)
  15. unity-environment/Assets/ML-Agents/Scripts/Brain.cs (27 changes)
  16. unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs (31 changes)
  17. unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_5x5.bytes (1001 changes)
  18. unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_5x5.bytes.meta (9 changes)
  19. /unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_3x3.bytes (renamed, 0 changes)
  20. /unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_3x3.bytes.meta (renamed, 0 changes)

python/learn.py (23 changes)


# # Unity ML Agents
# ## ML-Agent Learning (PPO)
# ## ML-Agent Learning
# Launches trainers for each External Brain in a Unity Environment
import logging

from datetime import datetime
from trainers.ghost_trainer import GhostTrainer
from trainers.ppo_models import *
from trainers.ppo_trainer import PPOTrainer

Options:
--help Show this message.
--curriculum=<file> Curriculum json file for environment [default: None].
--slow Whether to run the game at training speed [default: False].
--run-path=<path> The sub-directory name for model and summary statistics [default: ppo].
--seed=<n> Random seed used for training [default: None].
--train Whether to train model, or only run inference [default: False].
--worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0].
'''

# General parameters
model_path = './models/{}'.format(str(options['--run-path']))
seed = int(options['--seed'])
load_model = options['--load']
train_model = options['--train']
save_freq = int(options['--save-freq'])

lesson = int(options['--lesson'])
fast_simulation = not bool(options['--slow'])
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
if seed is None:
seed = datetime.now()
np.random.seed(seed)
tf.set_random_seed(seed)
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file, seed=seed)
env.curriculum.set_lesson_number(lesson)
logger.info(str(env))

trainer_parameters_dict[brain_name]['original_brain_parameters'] = trainer_parameters_dict[
trainer_parameters_dict[brain_name]['brain_to_copy']]
trainers[brain_name] = GhostTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name],
train_model)
train_model, seed)
train_model)
train_model, seed)
train_model)
train_model, seed)
for k, t in trainers.items():
logger.info(t)
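For reference, docopt returns the `--seed` option as a string, and the hunk above falls back to the wall clock when no seed is supplied. The helper below is a self-contained sketch of that normalization; `seed_everything` and the clock-to-integer conversion are illustrative assumptions, not code from `learn.py`.

```python
# Hypothetical helper sketching how the --seed option could be normalized and applied.
import numpy as np
import tensorflow as tf
from datetime import datetime

def seed_everything(seed_option):
    """Turn the docopt --seed string into an int and seed numpy/tensorflow."""
    seed = None if seed_option in (None, 'None') else int(seed_option)
    if seed is None:
        # Assumption: derive an integer from the clock so the numpy/tf APIs accept it.
        seed = int(datetime.now().timestamp()) % (2 ** 31)
    np.random.seed(seed)
    tf.set_random_seed(seed)  # TF1-style API, matching the diff above
    return seed
```

The resulting integer is also what gets forwarded to `UnityEnvironment(..., seed=seed)`, which passes it on to the Unity executable as `--seed` (see `python/unityagents/environment.py` below).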

python/trainer_configurations.yaml (27 changes)


default:
batch_size: 256
beta: 2.5e-3
buffer_size: 5000
batch_size: 32
beta: 5.0e-3
buffer_size: 512
max_steps: 1.0e6
normalize: false
max_steps: 5.0e4
normalize: true
sequence_length: 32
sequence_length: 64
summary_freq: 1000
use_recurrent: false

batch_size: 1024
batch_size: 1000
buffer_size: 10000
hidden_units: 64
GridWorldBrain:
batch_size: 32
num_layers: 1
hidden_units: 256
beta: 5.0e-3
gamma: 0.9
buffer_size: 256
max_steps: 5.0e5
summary_freq: 2000
time_horizon: 5
TurretBrain: ExampleBrain
ghost-HunterBrain:
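The YAML above pairs a shared `default` block with per-brain sections such as `GridWorldBrain`. A plausible reading of how such a file is consumed is sketched below; `load_trainer_parameters` and the override rule are assumptions for illustration, not the repository's loading code.

```python
# Hedged sketch: per-brain sections are assumed to override the `default` block.
import yaml

def load_trainer_parameters(path, brain_name):
    with open(path) as f:
        config = yaml.safe_load(f)
    parameters = dict(config.get('default', {}))    # shared defaults first
    parameters.update(config.get(brain_name, {}))   # brain-specific keys win
    return parameters

# e.g. load_trainer_parameters('trainer_configurations.yaml', 'GridWorldBrain')
```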

python/trainers/buffer.py (109 changes)


class Buffer(dict):
"""
Buffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
"""
AgentBuffer contains a dictionary of AgentBufferFields. Each agent has its own AgentBuffer.
The keys correspond to the name of the field. Example: state, action
"""
AgentBufferField is a list of numpy arrays. When an agent collects a field, it can be added to its
AgentBufferField with the append method.
"""
def __str__(self):
return str(np.array(self).shape)

Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
"""
Sets the list of np.array to the input data
:param data: The np.array list to be set.
"""
Retrieve the last batch_size elements of length training_length
from the list of np.array
:param batch_size: The number of elements to retrieve. If None:
All elements will be retrieved.
:param training_length: The length of the sequence to be retrieved. If
None: only takes one element.
:param sequential: If true and training_length is not None: the elements
will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
[[a,b],[b,c],[c,d],[d,e]]
"""
if training_length is None:
# When the training length is None, the method returns a list of elements,
# not a list of sequences of elements.
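The sequential/overlapping distinction in the docstring above is easiest to see with a tiny standalone helper. `last_sequences` below is written purely for illustration; it mimics the described behaviour (including the zero left-padding) but is not `AgentBufferField.get_batch` itself.

```python
# Illustration of sequential vs. overlapping retrieval; not the buffer's implementation.
def last_sequences(data, batch_size, training_length, sequential=True):
    if sequential:
        # Non-overlapping windows taken from the end, left-padded with 0 if needed.
        padding = (-len(data)) % training_length
        padded = [0] * padding + list(data)
        windows = [padded[i:i + training_length]
                   for i in range(0, len(padded), training_length)]
    else:
        # Overlapping windows that slide one element at a time.
        windows = [list(data[i:i + training_length])
                   for i in range(len(data) - training_length + 1)]
    return windows[-batch_size:]

# last_sequences(['a','b','c','d','e'], 3, 2, sequential=True)  -> [[0,'a'], ['b','c'], ['d','e']]
# last_sequences(['a','b','c','d','e'], 4, 2, sequential=False) -> [['a','b'], ['b','c'], ['c','d'], ['d','e']]
```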

def reset_field(self):
"""
Resets the AgentBufferField
"""
self[:] = []
def __str__(self):

"""
Resets the AgentBuffer
"""
for k in self.keys():
self[k].reset_field()

def check_length(self, key_list):
"""
Some methods will require that some fields have the same length.
check_length will return true if the fields in key_list
have the same length.
:param key_list: The fields whose lengths will be compared
"""
if len(key_list) < 2:
return True
l = None

def shuffle(self, key_list=None):
"""
Shuffles the fields in key_list in a consistent way: The reordering will
be the same across fields.
:param key_list: The fields that must be shuffled.
"""
if key_list is None:
key_list = list(self.keys())
if not self.check_length(key_list):

def reset_update_buffer(self):
"""
Resets the update buffer
"""
# self.update_buffer.reset_agent()
Resets the update buffer and all the local buffers
"""
agent_ids = list(self.keys())
for k in agent_ids:
self[k].reset_agent()

Appends the buffer of an agent to the update buffer.
:param agent_id: The id of the agent whose data will be appended
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
if key_list is None:
key_list = self[agent_id].keys()
if not self[agent_id].check_length(key_list):

def append_all_agent_batch_to_update_buffer(self, key_list=None, batch_size=None, training_length=None):
"""
Appends the buffer of all agents to the update buffer.
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
for agent_id in self.keys():
self.append_update_buffer(agent_id, key_list, batch_size, training_length)

python/trainers/ghost_trainer.py (2 changes)


class GhostTrainer(Trainer):
"""Keeps copies of a PPOTrainer past graphs and uses them to other Trainers."""
def __init__(self, sess, env, brain_name, trainer_parameters, training):
def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
"""
Responsible for saving and reusing past models.
:param sess: Tensorflow session.

python/trainers/imitation_trainer.py (5 changes)


logger = logging.getLogger("unityagents")
class ImitationNN(object):
def __init__(self, state_size, action_size, h_size, lr, action_type, n_layers):
self.state = tf.placeholder(shape=[None, state_size], dtype=tf.float32, name="state")

self.update = optimizer.minimize(self.loss)
def __init__(self, sess, env, brain_name, trainer_parameters, training):
def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
"""
Responsible for collecting experiences and training PPO model.
:param sess: Tensorflow session.

s_size = self.brain.state_space_size * self.brain.stacked_states
a_size = self.brain.action_space_size
with tf.variable_scope(self.variable_scope):
tf.set_random_seed(seed)
self.network = ImitationNN(state_size = s_size,
action_size = a_size,
h_size = int(trainer_parameters['hidden_units']),

python/trainers/ppo_models.py (7 changes)


streams = []
for i in range(num_streams):
conv1 = tf.layers.conv2d(self.observation_in[-1], 16, kernel_size=[8, 8], strides=[4, 4],
activation=activation)
activation=tf.nn.elu)
activation=activation)
activation=tf.nn.elu)
hidden = c_layers.flatten(conv2)
for j in range(num_layers):

height_size, width_size = brain.camera_resolutions[i]['height'], brain.camera_resolutions[i]['width']
bw = brain.camera_resolutions[i]['blackAndWhite']
visual_encoders.append(
self.create_visual_encoder(height_size, width_size, bw, h_size, 2, tf.nn.tanh, num_layers)[0])
self.create_visual_encoder(height_size, width_size, bw, h_size, 1, tf.nn.elu, num_layers)[0])
hidden_visual = [tf.concat(visual_encoders, axis=1)]
if brain.state_space_size > 0:
s_size = brain.state_space_size * brain.stacked_states

self.memory_out = tf.identity(self.memory_out, name = 'recurrent_out')
self.policy = tf.layers.dense(hidden, a_size, activation=None, use_bias=False,
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))
self.probs = tf.nn.softmax(self.policy, name="action_probs")
self.output = tf.multinomial(self.policy, 1)

python/trainers/ppo_trainer.py (13 changes)


class PPOTrainer(Trainer):
"""The PPOTrainer is an implementation of the PPO algorythm."""
def __init__(self, sess, env, brain_name, trainer_parameters, training):
def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
"""
Responsible for collecting experiences and training PPO model.
:param sess: Tensorflow session.

self.m_size = env.brains[brain_name].memory_space_size
self.sequence_length = trainer_parameters["sequence_length"]
if self.use_recurrent:
if (self.m_size == 0):
if self.m_size == 0:
elif (self.m_size % 4 != 0):
elif self.m_size % 4 != 0:
tf.set_random_seed(seed)
self.model = create_agent_model(env.brains[brain_name],
lr=float(trainer_parameters['learning_rate']),
h_size=int(trainer_parameters['hidden_units']),

"""
steps = self.get_step
info = info[self.brain_name]
epsi = None
feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1}
run_list = [self.model.output, self.model.probs, self.model.value, self.model.entropy,
self.model.learning_rate]

for l in range(len(info.agents)):
agent_actions = self.training_buffer[info.agents[l]]['actions']
if ((info.local_done[l] or len(agent_actions) > self.trainer_parameters['time_horizon'])
and len(agent_actions) > 0):
if info.local_done[l]:
and len(agent_actions) > 0):
if info.local_done[l] and not info.max_reached[l]:
value_next = 0.0
else:
feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1}

python/unityagents/brain.py (4 changes)


class BrainInfo:
def __init__(self, observation, state, memory=None, reward=None, agents=None, local_done=None, action =None):
def __init__(self, observation, state, memory=None, reward=None, agents=None, local_done=None,
action=None, max_reached=None):
"""
Describes experience at current step of all agents linked to a brain.
"""

self.rewards = reward
self.local_done = local_done
self.max_reached = max_reached
self.agents = agents
self.previous_actions = action

python/unityagents/environment.py (13 changes)


from .exception import UnityEnvironmentException, UnityActionException, UnityTimeOutException
from .curriculum import Curriculum
from datetime import datetime
from PIL import Image
from sys import platform

class UnityEnvironment(object):
def __init__(self, file_name, worker_id=0,
base_port=5005, curriculum=None):
base_port=5005, curriculum=None,
seed=None):
"""
Starts a new unity environment and establishes a connection with the environment.
Notice: Currently communication between Unity and Python takes place over an open socket without authentication.

:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
"""
if seed is None:
seed = datetime.now()
atexit.register(self.close)
self.port = base_port + worker_id

# Launch Unity environment
proc1 = subprocess.Popen(
[launch_string,
'--port', str(self.port)])
'--port', str(self.port),
'--seed', str(seed)])
self._socket.settimeout(30)
try:

rewards = state_dict["rewards"]
dones = state_dict["dones"]
agents = state_dict["agents"]
# actions = state_dict["actions"]
maxes = state_dict["maxes"]
if n_agent > 0:
actions = np.array(state_dict["actions"]).reshape((n_agent, -1))
else:

observations.append(np.array(obs_n))
self._data[b] = BrainInfo(observations, states, memories, rewards, agents, dones, actions)
self._data[b] = BrainInfo(observations, states, memories, rewards, agents, dones, actions, max_reached=maxes)
try:
self._global_done = self._conn.recv(self._buffer_size).decode('utf-8') == 'True'

unity-environment/Assets/ML-Agents/Examples/Banana/Scripts/BananaAgent.cs (9 changes)


{
public GameObject area;
bool frozen;
bool shoot;
float frozenTime;
Rigidbody agentRB;
public float turnSpeed;

Vector3 localVelocity = transform.InverseTransformDirection(agentRB.velocity);
state.Add(localVelocity.x);
state.Add(localVelocity.z);
state.Add(System.Convert.ToInt32(frozen));
state.Add(System.Convert.ToInt32(shoot));
return state;
}

public void MoveAgent(float[] act)
{
shoot = false;
Monitor.Log("Bananas", bananas, MonitorType.text, gameObject.transform);
if (Time.time > frozenTime + 4f) {

Vector3 dirToGo = Vector3.zero;
Vector3 rotateDir = Vector3.zero;
bool shoot = false;
dirToGo = transform.forward * Mathf.Clamp(act[0], 0f, 1f);
dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
if (Mathf.Clamp(act[2], 0f, 1f) > 0.5f) {
shoot = true;

public override void AgentReset()
{
Unfreeze();
shoot = false;
agentRB.velocity = Vector3.zero;
bananas = 0;
myLazer.transform.localScale = new Vector3(0f, 0f, 0f);

unity-environment/Assets/ML-Agents/Examples/GridWorld/GridWorld.unity (132 changes)


waitTime: 0.2
isInference: 0
trainingConfiguration:
width: 300
height: 300
width: 84
height: 84
timeScale: 10
timeScale: 20
targetFrameRate: -1
inferenceConfiguration:
width: 1280

targetFrameRate: 60
defaultResetParameters:
- key: gridSize
value: 3
value: 5
maxStepReached: 0
episodeCount: 0
currentStep: 0
actorObjs: []

m_Father: {fileID: 0}
m_RootOrder: 1
m_LocalEulerAnglesHint: {x: 90, y: 0, z: 0}
--- !u!114 &183002472
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 41e9bda8f3cf1492fa74926a530f6f70, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_EditorClassIdentifier:
broadcast: 1
continuousPlayerActions: []
discretePlayerActions:
- key: 273
value: 0
- key: 274
value: 1
- key: 276
value: 2
- key: 275
value: 3
defaultAction: -1
brain: {fileID: 1535917239}
--- !u!1 &231883441
GameObject:
m_ObjectHideFlags: 0

m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 45, y: 45, z: 0}
--- !u!114 &299967728
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 943466ab374444748a364f9d6c3e2fe2, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_EditorClassIdentifier:
broadcast: 1
brain: {fileID: 1535917239}
--- !u!1 &363761396
GameObject:
m_ObjectHideFlags: 0

m_OcclusionCulling: 1
m_StereoConvergence: 10
m_StereoSeparation: 0.022
--- !u!114 &651992469
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 35813a1be64e144f887d7d5f15b963fa, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_EditorClassIdentifier:
brain: {fileID: 1535917239}
--- !u!1 &742849316
GameObject:
m_ObjectHideFlags: 0

m_Father: {fileID: 0}
m_RootOrder: 3
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &780827900
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 41e9bda8f3cf1492fa74926a530f6f70, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_EditorClassIdentifier:
broadcast: 1
continuousPlayerActions: []
discretePlayerActions:
- key: 273
value: 0
- key: 274
value: 1
- key: 276
value: 2
- key: 275
value: 3
defaultAction: -1
brain: {fileID: 1535917239}
--- !u!1 &959566328
GameObject:
m_ObjectHideFlags: 0

m_Father: {fileID: 486401524}
m_RootOrder: 1
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &1104052206
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 943466ab374444748a364f9d6c3e2fe2, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_EditorClassIdentifier:
broadcast: 1
brain: {fileID: 1535917239}
--- !u!1 &1208586857
GameObject:
m_ObjectHideFlags: 0

- component: {fileID: 1535917238}
- component: {fileID: 1535917239}
m_Layer: 0
m_Name: GridBrain
m_Name: GridWorldBrain
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0

m_EditorClassIdentifier:
brainParameters:
stateSize: 0
stackedStates: 1
actionSize: 4
memorySize: 0
cameraResolutions:

- Right
actionSpaceType: 0
stateSpaceType: 1
brainType: 0
brainType: 3
- {fileID: 183002472}
- {fileID: 1104052206}
- {fileID: 651992469}
- {fileID: 1737050307}
instanceID: 12658
- {fileID: 780827900}
- {fileID: 299967728}
- {fileID: 2102493396}
- {fileID: 1810907653}
instanceID: 12840
--- !u!1 &1553342942
GameObject:
m_ObjectHideFlags: 0

m_Father: {fileID: 486401524}
m_RootOrder: 4
m_LocalEulerAnglesHint: {x: 0, y: 90, z: 0}
--- !u!114 &1737050307
--- !u!114 &1810907653
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}

m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 8b23992c8eb17439887f5e944bf04a40, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
graphModel: {fileID: 4900000, guid: 426162c47484e466d8378d2321a5617c, type: 3}
graphModel: {fileID: 4900000, guid: 69bc5a8ff4f4b465b8b374d88541c8bd, type: 3}
graphScope:
graphPlaceholders: []
BatchSizePlaceholderName: batch_size

brain: {fileID: 1535917239}
observations:
- {fileID: 489340228}
maxStep: 100
maxStep: 25
stackedStates: []
maxStepReached: 0
value: 0
CumulativeReward: 0
stepCounter: 0

academy: {fileID: 2047663}
--- !u!114 &2102493396
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 35813a1be64e144f887d7d5f15b963fa, type: 3}
m_Name: (Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)(Clone)
m_EditorClassIdentifier:
brain: {fileID: 1535917239}

unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (24 changes)


if (blockTest.Where(col => col.gameObject.tag == "wall").ToArray().Length == 0)
{
transform.position = targetPos;
}
}
void OnTriggerEnter(Collider col)
{
if (col.gameObject.CompareTag("goal"))
{
reward = 1f;
done = true;
}
if (col.gameObject.CompareTag("pit"))
{
reward = -1f;
done = true;
if (blockTest.Where(col => col.gameObject.tag == "goal").ToArray().Length == 1)
{
done = true;
reward = 1;
}
if (blockTest.Where(col => col.gameObject.tag == "pit").ToArray().Length == 1)
{
done = true;
reward = -1;
}
}
}

unity-environment/Assets/ML-Agents/Scripts/Academy.cs (8 changes)


[HideInInspector]
public bool done;
/// <summary>
/// The max step reached.
/// </summary>
[HideInInspector]
public bool maxStepReached;
/**< \brief Increments each time the environment is reset. */
[HideInInspector]
public int episodeCount;

if (done)
{
brain.SendDone();
brain.SendMaxReached();
}
brain.ResetIfDone();

if ((currentStep >= maxSteps) && maxSteps > 0)
{
done = true;
maxStepReached = true;
}
if ((framesSinceAction > frameToSkip) || done)

unity-environment/Assets/ML-Agents/Scripts/Agent.cs (5 changes)


[HideInInspector]
public bool done;
/**< \brief Whether or not the max step is reached*/
[HideInInspector]
public bool maxStepReached;
/**< \brief The current value estimate of the agent */
/**< When using an External brain, you can pass value estimates to the
* agent at every step using env.Step(actions, values).

if ((stepCounter > maxStep) && (maxStep > 0))
{
done = true;
maxStepReached = true;
}
}

unity-environment/Assets/ML-Agents/Scripts/Brain.cs (27 changes)


public Dictionary<int, List<Camera>> currentCameras = new Dictionary<int, List<Camera>>(32);
public Dictionary<int, float> currentRewards = new Dictionary<int, float>(32);
public Dictionary<int, bool> currentDones = new Dictionary<int, bool>(32);
public Dictionary<int, bool> currentMaxes = new Dictionary<int, bool>(32);
public Dictionary<int, float[]> currentActions = new Dictionary<int, float[]>(32);
public Dictionary<int, float[]> currentMemories = new Dictionary<int, float[]>(32);

currentCameras.Clear();
currentRewards.Clear();
currentDones.Clear();
currentMaxes.Clear();
currentActions.Clear();
currentMemories.Clear();

currentCameras.Add(idAgent.Key, observations);
currentRewards.Add(idAgent.Key, idAgent.Value.reward);
currentDones.Add(idAgent.Key, idAgent.Value.done);
currentMaxes.Add(idAgent.Key, idAgent.Value.maxStepReached);
currentActions.Add(idAgent.Key, idAgent.Value.agentStoredAction);
currentMemories.Add(idAgent.Key, idAgent.Value.memory);
}

return currentDones;
}
/// Collects the maxStepReached flag of all the agents which subscribe to this brain
/// and returns a dictionary {id -> maxStepReached}
public Dictionary<int, bool> CollectMaxes()
{
currentMaxes.Clear();
foreach (KeyValuePair<int, Agent> idAgent in agents)
{
currentMaxes.Add(idAgent.Key, idAgent.Value.maxStepReached);
}
return currentMaxes;
}
/// Collects the actions of all the agents which subscribe to this brain
/// and returns a dictionary {id -> action}
public Dictionary<int, float[]> CollectActions()

}
}
///Sets all the agents which subscribe to the brain to maxStepReached
public void SendMaxReached()
{
foreach (KeyValuePair<int, Agent> idAgent in agents)
{
idAgent.Value.maxStepReached = true;
}
}
/// Uses coreBrain to call SendState on the CoreBrain
public void SendState()
{

{
agent.Reset();
agent.done = false;
agent.maxStepReached = false;
}
}

{
agent.ResetReward();
agent.done = false;
agent.maxStepReached = false;
}
}
}

unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs (31 changes)


Dictionary<string, Dictionary<int, float[]>> storedMemories;
Dictionary<string, Dictionary<int, float>> storedValues;
const int messageLength = 12000;
const int defaultNumAgents = 32;
const int defaultNumObservations = 32;
List<float> concatenatedStates = new List<float>(1024);
List<float> concatenatedRewards = new List<float>(32);
List<float> concatenatedMemories = new List<float>(1024);
List<bool> concatenatedDones = new List<bool>(32);
List<float> concatenatedActions = new List<float>(1024);
List<float> concatenatedStates = new List<float>(defaultNumAgents*defaultNumObservations);
List<float> concatenatedRewards = new List<float>(defaultNumAgents);
List<float> concatenatedMemories = new List<float>(defaultNumAgents * defaultNumObservations);
List<bool> concatenatedDones = new List<bool>(defaultNumAgents);
List<bool> concatenatedMaxes = new List<bool>(defaultNumAgents);
List<float> concatenatedActions = new List<float>(defaultNumAgents * defaultNumObservations);
private int comPort;
int comPort;
int randomSeed;
const int messageLength = 12000;
StreamWriter logWriter;
string logPath;

public List<float> actions;
public List<float> memories;
public List<bool> dones;
public List<bool> maxes;
}
StepMessage sMessage;

{
string[] args = System.Environment.GetCommandLineArgs();
var inputPort = "";
var inputSeed = "";
for (int i = 0; i < args.Length; i++)
{
if (args[i] == "--port")

if (args[i] == "--seed")
{
inputSeed = args[i + 1];
}
randomSeed = int.Parse(inputSeed);
Random.InitState(randomSeed);
}
/// Sends Academy parameters to external agent

concatenatedRewards.Clear();
concatenatedMemories.Clear();
concatenatedDones.Clear();
concatenatedMaxes.Clear();
concatenatedActions.Clear();
foreach (int id in current_agents[brainName])

concatenatedMemories.AddRange(brain.currentMemories[id].ToList());
concatenatedDones.Add(brain.currentDones[id]);
concatenatedMaxes.Add(brain.currentMaxes[id]);
concatenatedActions.AddRange(brain.currentActions[id].ToList());
}

sMessage.actions = concatenatedActions;
sMessage.memories = concatenatedMemories;
sMessage.dones = concatenatedDones;
sMessage.maxes = concatenatedMaxes;
sMessageString = JsonUtility.ToJson(sMessage);
sender.Send(AppendLength(Encoding.ASCII.GetBytes(sMessageString)));

unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_5x5.bytes (1001 changes)
The file diff is too large to display.

unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_5x5.bytes.meta (9 changes)


fileFormatVersion: 2
guid: 69bc5a8ff4f4b465b8b374d88541c8bd
timeCreated: 1517599883
licenseType: Pro
TextScriptImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

/unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld.bytes → /unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_3x3.bytes

/unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld.bytes.meta → /unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_3x3.bytes.meta
