formatting and added documentation

Branch: develop-generalizationTraining-TrainerController
vincentpierre, 7 years ago
Commit cde3c8f7

20 files changed, 727 insertions(+), 465 deletions(-)
 1. README.md | 3
 2. python/Basics.ipynb | 20
 3. python/PPO.ipynb | 14
 4. python/README.md | 4
 5. python/ppo.py | 14
 6. python/ppo/trainer.py | 2
 7. python/setup.py | 2
 8. python/unityagents/__init__.py | 2
 9. python/unityagents/environment.py | 4
10. unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAcademy.cs | 15
11. unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs | 180
12. unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DDecision.cs | 32
13. unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAcademy.cs | 151
14. unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs | 225
15. unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAcademy.cs | 43
16. unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs | 119
17. unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/hitWall.cs | 153
18. unity-environment/Assets/ML-Agents/Template/Scripts/TemplateDecision.cs | 2
19. unity-environment/README.md | 11
20. images/unity-wide.png | 196

README.md | 3

- ![alt text](images/banner.png "Unity ML - Agents")
+ <img src="images/unity-wide.png" align="middle" width="3000"/>
  # Unity ML - Agents
  **Unity Machine Learning Agents** allows researchers and developers to create games and simulations using the Unity Editor which serve as environments where intelligent agents can be trained using reinforcement learning, neuroevolution, or other machine learning methods through a simple-to-use Python API. For more information, see the [wiki page](../../wiki).
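As a sketch of the "simple-to-use Python API" mentioned above: only `UnityEnvironment(file_name=...)`, `reset(train_mode=...)`, and `global_done` appear verbatim in the diffs below; `brain_names`, `agents`, `step`, and `close` are assumed from the same era of the `unityagents` package, and the environment name and action shape are illustrative.

```python
from unityagents import UnityEnvironment

# Launch a built environment binary (the name "3DBall" is illustrative).
env = UnityEnvironment(file_name="3DBall")
brain_name = env.brain_names[0]  # assumed attribute: first registered brain

# reset() returns a dict of BrainInfo keyed by brain name (as used in ppo.py below);
# train_mode=True runs the simulation at training speed.
info = env.reset(train_mode=True)[brain_name]
for _ in range(100):
    # One action value per agent; real action sizes come from the brain parameters.
    info = env.step([0.0] * len(info.agents))[brain_name]
    if env.global_done:
        info = env.reset(train_mode=True)[brain_name]
env.close()
```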

python/Basics.ipynb | 20

  "source": [
  "# Unity ML Agents\n",
  "## Environment Basics\n",
- "This notebook contains a walkthrough of the basic functions of the Python API for Unity ML Agents. For instructions on building a Unity environment, see [here](https://github.com/Unity-Technologies/python-rl-control/tree/master/Projects)."
+ "This notebook contains a walkthrough of the basic functions of the Python API for Unity ML Agents. For instructions on building a Unity environment, see [here](https://github.com/Unity-Technologies/ml-agents/wiki/Getting-Started-with-Balance-Ball)."
  ]
  },
  {

  {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {},
+ "metadata": {
+ "collapsed": true
+ },
  "outputs": [],
  "source": [
  "env = UnityEnvironment(file_name=env_name)\n",

  {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {},
+ "metadata": {
+ "collapsed": true
+ },
  "outputs": [],
  "source": [
  "# Reset the environment\n",

  "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
- "display_name": "Python 2",
+ "display_name": "Python 3",
- "name": "python2"
+ "name": "python3"
- "version": 2
+ "version": 3
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
+ "pygments_lexer": "ipython3",
+ "version": "3.6.2"
  }
  },
  "nbformat": 4,

python/PPO.ipynb | 14

  "run_path = \"ppo\" # The sub-directory name for model and summary statistics\n",
  "load_model = False # Whether to load a saved model.\n",
  "train_model = True # Whether to train the model.\n",
- "summary_freq = 1000 # Frequency at which to save training statistics.\n",
- "save_freq = 20000 # Frequency at which to save model.\n",
+ "summary_freq = 10000 # Frequency at which to save training statistics.\n",
+ "save_freq = 50000 # Frequency at which to save model.\n",
  "env_name = \"simple\" # Name of the training environment file.\n",
  "\n",
  "### Algorithm-specific parameters for tuning\n",

  "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
- "display_name": "Python 2",
+ "display_name": "Python 3",
- "name": "python2"
+ "name": "python3"
- "version": 2
+ "version": 3
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
+ "pygments_lexer": "ipython3",
+ "version": "3.6.2"
  }
  },
  "nbformat": 4,

python/README.md | 4

  ![alt text](../images/banner.png "Unity ML - Agents")
  # Unity ML - Agents (Python API)
  ## Python Setup

  To train using PPO without the notebook, run:
- `python3 ppo.py <env_name> --train-model`
+ `python3 ppo.py <env_name> --train`
  For a list of additional hyperparameters, run:
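For example, a training run that also overrides the new logging defaults might look like `python3 ppo.py <env_name> --train --summary-freq=10000 --save-freq=50000`. The flag names are the ones defined in ppo.py below; the particular combination is illustrative.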

python/ppo.py | 14

  --help                Show this message.
  --max-step=<n>        Maximum number of steps to run environment [default: 5e6].
  --run-path=<path>     The sub-directory name for model and summary statistics [default: ppo].
- --load-model          Whether to load the model or randomly initialize [default: False].
- --train-model         Whether to train model, or only run inference [default: True].
- --summary-freq=<n>    Frequency at which to save training statistics [default: 5000].
- --save-freq=<n>       Frequency at which to save model [default: 20000].
+ --load                Whether to load the model or randomly initialize [default: False].
+ --train               Whether to train model, or only run inference [default: True].
+ --summary-freq=<n>    Frequency at which to save training statistics [default: 10000].
+ --save-freq=<n>       Frequency at which to save model [default: 50000].
  --gamma=<n>           Reward discount rate [default: 0.99].
  --lambd=<n>           Lambda parameter for GAE [default: 0.95].
  --time-horizon=<n>    How many steps to collect per agent before adding to buffer [default: 2048].

  max_steps = float(options['--max-step'])
  model_path = './models/{}'.format(str(options['--run-path']))
  summary_path = './summaries/{}'.format(str(options['--run-path']))
- load_model = options['--load-model']
- train_model = options['--train-model']
+ load_model = options['--load']
+ train_model = options['--train']
  summary_freq = int(options['--summary-freq'])
  save_freq = int(options['--save-freq'])
  env_name = options['<env>']

  summary_writer = tf.summary.FileWriter(summary_path)
  info = env.reset(train_mode=train_model)[brain_name]
  trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations)
- while steps <= max_steps:
+ while steps <= max_steps or not train_model:
      if env.global_done:
          info = env.reset(train_mode=train_model)[brain_name]
      # Decide and take an action
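Note the widened loop condition: with `--train` disabled, the script now keeps stepping the environment for inference indefinitely instead of stopping at `--max-step`.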

python/ppo/trainer.py | 2

  summary = tf.Summary()
  for key in self.stats:
      if len(self.stats[key]) > 0:
-         stat_mean = np.mean(self.stats[key])
+         stat_mean = float(np.mean(self.stats[key]))
          summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
          self.stats[key] = []
  summary_writer.add_summary(summary, steps)
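The `float()` cast presumably ensures that `simple_value`, a plain float field on the summary protobuf, receives a native Python float rather than a NumPy scalar.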

python/setup.py | 2

  description='Unity Machine Learning Agents',
  license='Apache License 2.0',
  author='Unity Technologies',
- author_email='ml@unity3d.com',
+ author_email='ML-Agents@unity3d.com',
  url='https://github.com/Unity-Technologies/ml-agents',
  packages=find_packages(exclude = ['ppo']),
  install_requires = required,

python/unityagents/__init__.py | 2

  from .environment import *
  from .brain import *
  from .exception import *

python/unityagents/environment.py | 4

  import subprocess
  import signal
- from .brain import *
- from .exception import *
+ from .brain import BrainInfo, BrainParameters
+ from .exception import UnityEnvironmentException, UnityActionException
  from PIL import Image
  from sys import platform
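Swapping the wildcard imports for explicit names makes the module's actual dependencies (`BrainInfo`, `BrainParameters`, and the two exception types) visible at the import site.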

unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAcademy.cs | 15

  using System.Collections.Generic;
  using UnityEngine;

- public class Ball3DAcademy : Academy {
+ public class Ball3DAcademy : Academy
+ {
      public override void AcademyReset()
      {
      }
  }

unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs | 180

  using System.Collections.Generic;
  using UnityEngine;

- public class Ball3DAgent : Agent {
+ public class Ball3DAgent : Agent
+ {
      [Header("Specific to Ball3D")]
      public GameObject ball;

      public override List<float> CollectState()
      {
          List<float> state = new List<float>();
          state.Add(gameObject.transform.rotation.z);
          state.Add(gameObject.transform.rotation.x);
          state.Add((ball.transform.position.x - gameObject.transform.position.x) / 5f);
          state.Add((ball.transform.position.y - gameObject.transform.position.y) / 5f);
          state.Add((ball.transform.position.z - gameObject.transform.position.z) / 5f);
          state.Add(ball.transform.GetComponent<Rigidbody>().velocity.x / 5f);
          state.Add(ball.transform.GetComponent<Rigidbody>().velocity.y / 5f);
          state.Add(ball.transform.GetComponent<Rigidbody>().velocity.z / 5f);
          return state;
      }

      // to be implemented by the developer
      public override void AgentStep(float[] act)
      {
          if (brain.brainParameters.actionSpaceType == StateType.continuous)
          {
              float action_z = act[0];
              if (action_z > 2f)
              {
                  action_z = 2f;
              }
              if (action_z < -2f)
              {
                  action_z = -2f;
              }
              if ((gameObject.transform.rotation.z < 0.25f && action_z > 0f) ||
                  (gameObject.transform.rotation.z > -0.25f && action_z < 0f))
              {
                  gameObject.transform.Rotate(new Vector3(0, 0, 1), action_z);
              }
              float action_x = act[1];
              if (action_x > 2f)
              {
                  action_x = 2f;
              }
              if (action_x < -2f)
              {
                  action_x = -2f;
              }
              if ((gameObject.transform.rotation.x < 0.25f && action_x > 0f) ||
                  (gameObject.transform.rotation.x > -0.25f && action_x < 0f))
              {
                  gameObject.transform.Rotate(new Vector3(1, 0, 0), action_x);
              }
              if (done == false)
              {
                  reward = 0.1f;
              }
          }
          else
          {
              int action = (int)act[0];
              if (action == 0 || action == 1)
              {
                  action = (action * 2) - 1;
                  float changeValue = action * 2f;
                  if ((gameObject.transform.rotation.z < 0.25f && changeValue > 0f) ||
                      (gameObject.transform.rotation.z > -0.25f && changeValue < 0f))
                  {
                      gameObject.transform.Rotate(new Vector3(0, 0, 1), changeValue);
                  }
              }
              if (action == 2 || action == 3)
              {
                  action = ((action - 2) * 2) - 1;
                  float changeValue = action * 2f;
                  if ((gameObject.transform.rotation.x < 0.25f && changeValue > 0f) ||
                      (gameObject.transform.rotation.x > -0.25f && changeValue < 0f))
                  {
                      gameObject.transform.Rotate(new Vector3(1, 0, 0), changeValue);
                  }
              }
              if (done == false)
              {
                  reward = 0.1f;
              }
          }
          if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
              Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
              Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
          {
              done = true;
          }
      }

      // to be implemented by the developer
      public override void AgentReset()
      {
          gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
          ball.GetComponent<Rigidbody>().velocity = new Vector3(0f, 0f, 0f);
      }
  }

unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DDecision.cs | 32

  using System.Collections.Generic;
  using UnityEngine;

- public class Ball3DDecision : MonoBehaviour, Decision {
+ public class Ball3DDecision : MonoBehaviour, Decision
+ {
      public float[] Decide(List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
      {
          if (gameObject.GetComponent<Brain>().brainParameters.actionSpaceType == StateType.continuous)
          {
              return new float[4]{ 0f, 0f, 0f, 0.0f };
          }
          else
          {
              return new float[1]{ 1f };
          }
      }

      public float[] MakeMemory(List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
      {
          return new float[0];
      }
  }

unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAcademy.cs | 151

  public class GridAcademy : Academy
  {
      [HideInInspector]
      public List<GameObject> actorObjs;
      [HideInInspector]
      public string[] players;
      [HideInInspector]
      public GameObject trueAgent;
      [HideInInspector]
      public GameObject visualAgent;

      public override void InitializeAcademy()
      {
      }

      public void SetEnvironment()
      {
          trueAgent = GameObject.Find("trueAgent");
          List<string> playersList = new List<string>();
          actorObjs = new List<GameObject>();
          for (int i = 0; i < (int)resetParameters["numObstacles"]; i++)
          {
              playersList.Add("pit");
          }
          playersList.Add("agent");
          for (int i = 0; i < (int)resetParameters["numGoals"]; i++)
          {
              playersList.Add("goal");
          }
          players = playersList.ToArray();
          Camera cam = GameObject.Find("Main Camera").GetComponent<Camera>();
          cam.transform.position = new Vector3(-((int)resetParameters["gridSize"] - 1) / 2f, (int)resetParameters["gridSize"] * 1.25f, -((int)resetParameters["gridSize"] - 1) / 2f);
          cam.orthographicSize = ((int)resetParameters["gridSize"] + 5f) / 2f;
          GameObject.Find("Plane").transform.localScale = new Vector3((int)resetParameters["gridSize"] / 10.0f, 1f, (int)resetParameters["gridSize"] / 10.0f);
          GameObject.Find("Plane").transform.position = new Vector3(((int)resetParameters["gridSize"] - 1) / 2f, -0.5f, ((int)resetParameters["gridSize"] - 1) / 2f);
          GameObject.Find("sN").transform.localScale = new Vector3(1, 1, (int)resetParameters["gridSize"] + 2);
          GameObject.Find("sS").transform.localScale = new Vector3(1, 1, (int)resetParameters["gridSize"] + 2);
          GameObject.Find("sN").transform.position = new Vector3(((int)resetParameters["gridSize"] - 1) / 2f, 0.0f, (int)resetParameters["gridSize"]);
          GameObject.Find("sS").transform.position = new Vector3(((int)resetParameters["gridSize"] - 1) / 2f, 0.0f, -1);
          GameObject.Find("sE").transform.localScale = new Vector3(1, 1, (int)resetParameters["gridSize"] + 2);
          GameObject.Find("sW").transform.localScale = new Vector3(1, 1, (int)resetParameters["gridSize"] + 2);
          GameObject.Find("sE").transform.position = new Vector3((int)resetParameters["gridSize"], 0.0f, ((int)resetParameters["gridSize"] - 1) / 2f);
          GameObject.Find("sW").transform.position = new Vector3(-1, 0.0f, ((int)resetParameters["gridSize"] - 1) / 2f);
          Camera aCam = GameObject.Find("agentCam").GetComponent<Camera>();
          aCam.orthographicSize = ((int)resetParameters["gridSize"]) / 2f;
          aCam.transform.position = new Vector3(((int)resetParameters["gridSize"] - 1) / 2f, (int)resetParameters["gridSize"] + 1f, ((int)resetParameters["gridSize"] - 1) / 2f);
      }

      public override void AcademyReset()
      {
          foreach (GameObject actor in actorObjs)
          {
              DestroyImmediate(actor);
          }
          SetEnvironment();

          actorObjs = new List<GameObject>();
          HashSet<int> numbers = new HashSet<int>();
          while (numbers.Count < players.Length)
          {
              numbers.Add(Random.Range(0, (int)resetParameters["gridSize"] * (int)resetParameters["gridSize"]));
          }
          int[] numbersA = Enumerable.ToArray(numbers);

          for (int i = 0; i < players.Length; i++)
          {
              int x = (numbersA[i]) / (int)resetParameters["gridSize"];
              int y = (numbersA[i]) % (int)resetParameters["gridSize"];
              GameObject actorObj = (GameObject)GameObject.Instantiate(Resources.Load(players[i]));
              actorObj.transform.position = new Vector3(x, -0.25f, y);
              actorObj.name = players[i];
              actorObjs.Add(actorObj);
              if (players[i] == "agent")
              {
                  trueAgent.transform.position = actorObj.transform.position;
                  trueAgent.transform.rotation = actorObj.transform.rotation;
                  visualAgent = actorObj;
              }
          }
      }

      public override void AcademyStep()
      {
      }
  }

unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs | 225

  using System.Net.Sockets;
  using System.Text;

- public class GridAgent : Agent {
+ public class GridAgent : Agent
+ {
      [Header("Specific to GridWorld")]
      public GridAcademy academy;
      [HideInInspector]
      public int gridSize;

      GameObject trueAgent;

      public override void InitializeAgent()
      {
          trueAgent = gameObject;
          gridSize = (int)academy.resetParameters["gridSize"];
      }

      public override List<float> CollectState()
      {
          int closestGoalDistance = 2 * (int)academy.resetParameters["gridSize"];
          GameObject currentClosestGoal = academy.actorObjs[0];
          int closestPitDistance = 2 * (int)academy.resetParameters["gridSize"];
          GameObject currentClosestPit = academy.actorObjs[0];
          GameObject agent = academy.actorObjs[0];
          List<float> state = new List<float>();
          foreach (GameObject actor in academy.actorObjs)
          {
              if (actor.tag == "agent")
              {
                  agent = actor;
                  state.Add(actor.transform.position.x / (gridSize + 1));
                  state.Add(actor.transform.position.z / (gridSize + 1));
                  continue;
              }
          }
          foreach (GameObject actor in academy.actorObjs)
          {
              if (actor.tag == "goal")
              {
                  int distance = (int)Mathf.Abs(agent.transform.position.x - actor.transform.position.x) + (int)Mathf.Abs(agent.transform.position.z - actor.transform.position.z);
                  if (closestGoalDistance > distance)
                  {
                      closestGoalDistance = distance;
                      currentClosestGoal = actor;
                  }
              }
              if (actor.tag == "pit")
              {
                  int distance = (int)Mathf.Abs(agent.transform.position.x - actor.transform.position.x) + (int)Mathf.Abs(agent.transform.position.z - actor.transform.position.z);
                  if (closestPitDistance > distance)
                  {
                      closestPitDistance = distance;
                      currentClosestPit = actor;
                  }
              }
          }
          state.Add(currentClosestGoal.transform.position.x / (gridSize + 1));
          state.Add(currentClosestGoal.transform.position.z / (gridSize + 1));
          state.Add(currentClosestPit.transform.position.x / (gridSize + 1));
          state.Add(currentClosestPit.transform.position.z / (gridSize + 1));
          return state;
      }

      // to be implemented by the developer
      public override void AgentStep(float[] act)
      {
          reward = -0.01f;
          int action = Mathf.FloorToInt(act[0]);
          // 0 - Forward, 1 - Backward, 2 - Left, 3 - Right
          if (action == 3)
          {
              Collider[] blockTest = Physics.OverlapBox(new Vector3(trueAgent.transform.position.x + 1, 0, trueAgent.transform.position.z), new Vector3(0.3f, 0.3f, 0.3f));
              if (blockTest.Where(col => col.gameObject.tag == "wall").ToArray().Length == 0)
              {
                  trueAgent.transform.position = new Vector3(trueAgent.transform.position.x + 1, 0, trueAgent.transform.position.z);
              }
          }
          if (action == 2)
          {
              Collider[] blockTest = Physics.OverlapBox(new Vector3(trueAgent.transform.position.x - 1, 0, trueAgent.transform.position.z), new Vector3(0.3f, 0.3f, 0.3f));
              if (blockTest.Where(col => col.gameObject.tag == "wall").ToArray().Length == 0)
              {
                  trueAgent.transform.position = new Vector3(trueAgent.transform.position.x - 1, 0, trueAgent.transform.position.z);
              }
          }
          if (action == 0)
          {
              Collider[] blockTest = Physics.OverlapBox(new Vector3(trueAgent.transform.position.x, 0, trueAgent.transform.position.z + 1), new Vector3(0.3f, 0.3f, 0.3f));
              if (blockTest.Where(col => col.gameObject.tag == "wall").ToArray().Length == 0)
              {
                  trueAgent.transform.position = new Vector3(trueAgent.transform.position.x, 0, trueAgent.transform.position.z + 1);
              }
          }
          if (action == 1)
          {
              Collider[] blockTest = Physics.OverlapBox(new Vector3(trueAgent.transform.position.x, 0, trueAgent.transform.position.z - 1), new Vector3(0.3f, 0.3f, 0.3f));
              if (blockTest.Where(col => col.gameObject.tag == "wall").ToArray().Length == 0)
              {
                  trueAgent.transform.position = new Vector3(trueAgent.transform.position.x, 0, trueAgent.transform.position.z - 1);
              }
          }
          Collider[] hitObjects = Physics.OverlapBox(trueAgent.transform.position, new Vector3(0.3f, 0.3f, 0.3f));
          if (hitObjects.Where(col => col.gameObject.tag == "goal").ToArray().Length == 1)
          {
              reward = 1f;
              done = true;
          }
          if (hitObjects.Where(col => col.gameObject.tag == "pit").ToArray().Length == 1)
          {
              reward = -1f;
              done = true;
          }
          //if (trainMode == "train") {
          if (true)
          {
              academy.visualAgent.transform.position = trueAgent.transform.position;
              academy.visualAgent.transform.rotation = trueAgent.transform.rotation;
          }
      }

      // to be implemented by the developer
      public override void AgentReset()
      {
          academy.AcademyReset();
      }
  }

unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAcademy.cs | 43

  using UnityEngine;
  using UnityEngine.UI;

- public class TennisAcademy : Academy {
+ public class TennisAcademy : Academy
+ {
      [Header("Specific to Tennis")]
      public GameObject ball;

      public override void AcademyReset()
      {
          float ballOut = Random.Range(4f, 11f);
          int flip = Random.Range(0, 2);
          if (flip == 0)
          {
              ball.transform.position = new Vector3(-ballOut, 5f, 5f);
          }
          else
          {
              ball.transform.position = new Vector3(ballOut, 5f, 5f);
          }
          ball.GetComponent<Rigidbody>().velocity = new Vector3(0f, 0f, 0f);
          ball.transform.localScale = new Vector3(1, 1, 1) * resetParameters["ballSize"];
      }

      public override void AcademyStep()
      {
      }
  }

unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs | 119

  using UnityEngine;
  using UnityEngine.UI;

- public class TennisAgent : Agent {
+ public class TennisAgent : Agent
+ {
      [Header("Specific to Tennis")]
      public GameObject ball;

      public override List<float> CollectState()
      {
          List<float> state = new List<float>();
          state.Add(gameObject.transform.position.y / 2f);
          state.Add(invertMult * gameObject.GetComponent<Rigidbody>().velocity.x / 10f);
          state.Add(gameObject.GetComponent<Rigidbody>().velocity.y / 10f);
          state.Add(invertMult * ball.transform.position.x / 8f);
          state.Add(ball.transform.position.y / 8f);
          state.Add(invertMult * ball.GetComponent<Rigidbody>().velocity.x / 10f);
          state.Add(ball.GetComponent<Rigidbody>().velocity.y / 10f);
          return state;
      }

      // to be implemented by the developer
      public override void AgentStep(float[] act)
      {
          // ...
          if (act[0] == 0f)
          {
              moveX = invertMult * -0.25f;
          }
          if (act[0] == 1f)
          {
              moveX = invertMult * 0.25f;
          }
          if (act[0] == 2f)
          {
              moveX = 0.0f;
          }
          if (act[0] == 3f)
          {
              moveY = 0.5f;
          }
          if (gameObject.transform.position.y > -1.9f)
          {
              moveY = 0f;
          }
          else
          {
              gameObject.GetComponent<Rigidbody>().velocity = new Vector3(0f, moveY * 12f, 0f);
          }
          gameObject.transform.position = new Vector3(gameObject.transform.position.x + moveX, gameObject.transform.position.y, 5f);
          // ...
          if (gameObject.transform.position.x > -(invertMult) * 11f)
          {
              gameObject.transform.position = new Vector3(-(invertMult) * 11f, gameObject.transform.position.y, 5f);
          }
          if (gameObject.transform.position.x < -(invertMult) * 2f)
          {
              gameObject.transform.position = new Vector3(-(invertMult) * 2f, gameObject.transform.position.y, 5f);
          }
          // ...
          if (gameObject.transform.position.x < -(invertMult) * 11f)
          {
              gameObject.transform.position = new Vector3(-(invertMult) * 11f, gameObject.transform.position.y, 5f);
          }
          if (gameObject.transform.position.x > -(invertMult) * 2f)
          {
              gameObject.transform.position = new Vector3(-(invertMult) * 2f, gameObject.transform.position.y, 5f);
          }
          // ...
          if (gameObject.transform.position.y < -2f)
          {
              gameObject.transform.position = new Vector3(gameObject.transform.position.x, -2f, 5f);
          }
-         //if (gameObject.transform.position.y > 1f) { gameObject.transform.position = new Vector3(gameObject.transform.position.x, 1f, 5f); }
      }

      // to be implemented by the developer
      public override void AgentReset()
      {
          if (invertX)
          {
              // ...
          }
          else
          {
              gameObject.transform.position = new Vector3(-(invertMult) * 7f, -1.5f, 5f);
          }
      }
  }

unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/hitWall.cs | 153

  using System.Collections.Generic;
  using UnityEngine;

- public class hitWall : MonoBehaviour {
+ public class hitWall : MonoBehaviour
+ {
      // Use this for initialization
-     void Start () {
-         lastAgentHit = 0;
-     }
+     void Start()
+     {
+         lastAgentHit = -1;
+     }

      // Update is called once per frame
      void Update()
      {
      }

      // ...
      TennisAgent agentB = GameObject.Find("AgentB").GetComponent<TennisAgent>();
      TennisAcademy academy = GameObject.Find("Academy").GetComponent<TennisAcademy>();

      if (collision.gameObject.tag == "iWall")
      academy.done = true;
      if (collision.gameObject.name == "wallA")
      {
          if (lastAgentHit == 0)
          {
              // ...
              agentB.score += 1;
          }
          else
          {
              agentA.reward = 0;
              agentB.reward = -0.1f;
              agentA.score += 1;
          }
      }
      else if (collision.gameObject.name == "wallB")
      {
          // ...
              agentB.reward = 0;
              agentB.score += 1;
          }
          else
          {
              agentA.reward = 0;
              agentB.reward = -0.1f;
              agentA.score += 1;
          }
      }
      else if (collision.gameObject.name == "floorA")
      {
-         if (lastAgentHit == 0)
+         if (lastAgentHit != 1)
          {
              agentA.reward = -0.1f;
              agentB.reward = 0;
              agentB.score += 1;
          }
          else
          {
              agentA.reward = -0.1f;
              agentB.score += 1;
          }
      }
      else if (collision.gameObject.name == "floorB")
      {
          if (lastAgentHit == 0)
          {
              agentB.reward = -0.1f;
              agentA.score += 1;
          }
          else
          {
              agentA.reward = 0;
              agentB.reward = -0.1f;
              agentA.score += 1;
          }
      }
      else if (collision.gameObject.name == "net")
      {
          if (lastAgentHit == 0)
          {
              agentA.reward = -0.1f;
              agentB.reward = 0.0f;
              agentB.score += 1;
          }
          else
          {
              agentB.reward = -0.1f;
              agentA.score += 1;
          }
      }
      // ...
-     if (lastAgentHit == 1)
+     if (lastAgentHit != 0)
      {
          agentA.reward = 0.1f;
          agentB.reward = 0.05f;
          // ...
      }
      else
-     if (lastAgentHit == 0)
+     if (lastAgentHit != 1)
      // ...
  }
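The substantive change in this file: `lastAgentHit` now starts at -1 instead of 0, and the paired conditions flip from `== 0` to `!= 1` (and `== 1` to `!= 0`), so a point where no racket has touched the ball yet (-1) is no longer scored as if agent A had hit it last. The values 0 and 1 presumably identify agents A and B.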

unity-environment/Assets/ML-Agents/Template/Scripts/TemplateDecision.cs | 2

  using System.Collections.Generic;
  using UnityEngine;

- public class Template : MonoBehaviour, Decision {
+ public class TemplateDecision : MonoBehaviour, Decision {
      public float[] Decide (List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
      {

unity-environment/README.md | 11

  - Save environment binary to a sub-directory containing the model to use for training *(you may need to click on the down arrow on the file chooser to be able to select that folder)*

  ## Example Projects
  The `Examples` subfolder contains a set of example environments to use either as starting points or templates for designing your own environments.
  * **3DBalanceBall** - Physics-based game where the agent must rotate a 3D platform to keep a ball in the air. Supports both discrete and continuous control.
  * **GridWorld** - A simple gridworld containing regions which provide positive and negative reward. The agent must learn to move to the rewarding regions (green) and avoid the negatively rewarding ones (red). Supports discrete control.
  * **Tennis** - An adversarial game where two agents control rackets, which must be used to bounce a ball back and forth between them. Supports continuous control.
  * **Template** - An empty Unity scene with a single _Academy_, _Brain_, and _Agent_. Designed to be used as a template for new environments.

  ## Agents SDK Package
- A link to Unity package containing the Agents SDK for Unity 2017.1 can be downloaded from `Control/UnitEnvironment_0.1.unitpackage`.
+ A Unity package containing the Agents SDK for Unity 2017.1 can be downloaded here:
+ * [ML-Agents package without TensorflowSharp](https://s3.amazonaws.com/unity-agents/ML-AgentsNoPlugin.unitypackage)
+ * [ML-Agents package with TensorflowSharp](https://s3.amazonaws.com/unity-agents/ML-AgentsWithPlugin.unitypackage)
- For information on the use of each script, see the comments and documentation within the files themselves, or read the documentation linked to below.
+ For information on the use of each script, see the comments and documentation within the files themselves, or read the [documentation](../../../wiki).

  ## Creating your own Unity Environment
  For information on how to create a new Unity Environment, see the walkthrough [here](../../../wiki/Making-a-new-Unity-Environment). If you have questions or run into issues, please feel free to create issues through the repo, and we will do our best to address them.

  1. Make sure you are using Unity 2017.1 or newer.
- 2. Make sure the TensorflowSharp plugin is in your Assets folder. A Plugins folder which includes TF# can be downloaded [here](https://drive.google.com/file/d/0BxZSPcA0DrkfSEFWcFFCNVZ6U2s/view?usp=sharing).
+ 2. Make sure the TensorflowSharp plugin is in your Assets folder. A Plugins folder which includes TF# can be downloaded [here](https://s3.amazonaws.com/unity-agents/TFSharpPlugin.unitypackage).
  3. Go to `Edit` -> `Project Settings` -> `Player`.
  4. For each of the platforms you target (**`PC, Mac and Linux Standalone`**, **`iOS`**, or **`Android`**):
     1. Go into `Other Settings`.

images/unity-wide.png | 196 (binary)

Width: 3000 | Height: 500 | Size: 27 KiB