您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
204 行
6.3 KiB
204 行
6.3 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Unity ML Agents\n",
|
|
"## Proximal Policy Optimization (PPO)\n",
|
|
"This notebook contains an implementation of PPO as described in [Schulman et al. (2017)](https://arxiv.org/abs/1707.06347)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"import tensorflow as tf\n",
|
|
"\n",
|
|
"from ppo.history import *\n",
|
|
"from ppo.models import *\n",
|
|
"from ppo.trainer import Trainer\n",
|
|
"from unityagents import *"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Hyperparameters"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"### General parameters\n",
|
|
"max_steps = 5e5 # Set maximum number of steps to run environment.\n",
|
|
"run_path = \"ppo\" # The sub-directory name for model and summary statistics\n",
|
|
"load_model = False # Whether to load a saved model.\n",
|
|
"train_model = True # Whether to train the model.\n",
|
|
"summary_freq = 10000 # Frequency at which to save training statistics.\n",
|
|
"save_freq = 50000 # Frequency at which to save model.\n",
|
|
"env_name = \"simple\" # Name of the training environment file.\n",
|
|
"\n",
|
|
"### Algorithm-specific parameters for tuning\n",
|
|
"gamma = 0.99 # Reward discount rate.\n",
|
|
"lambd = 0.95 # Lambda parameter for GAE.\n",
|
|
"time_horizon = 2048 # How many steps to collect per agent before adding to buffer.\n",
|
|
"beta = 1e-3 # Strength of entropy regularization\n",
|
|
"num_epoch = 5 # Number of gradient descent steps per batch of experiences.\n",
|
|
"epsilon = 0.2 # Acceptable threshold around ratio of old and new policy probabilities.\n",
|
|
"buffer_size = 2048 # How large the experience buffer should be before gradient descent.\n",
|
|
"learning_rate = 3e-4 # Model learning rate.\n",
|
|
"hidden_units = 64 # Number of units in hidden layer.\n",
|
|
"batch_size = 64 # How many experiences per gradient descent update step."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load the environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"env = UnityEnvironment(file_name=env_name)\n",
|
|
"print(str(env))\n",
|
|
"brain_name = env.brain_names[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Train the Agent(s)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"tf.reset_default_graph()\n",
|
|
"\n",
|
|
"# Create the Tensorflow model graph\n",
|
|
"ppo_model = create_agent_model(env, lr=learning_rate,\n",
|
|
" h_size=hidden_units, epsilon=epsilon,\n",
|
|
" beta=beta)\n",
|
|
"\n",
|
|
"is_continuous = (env.brains[brain_name].action_space_type == \"continuous\")\n",
|
|
"use_observations = (env.brains[brain_name].number_observations > 0)\n",
|
|
"\n",
|
|
"model_path = './models/{}'.format(run_path)\n",
|
|
"summary_path = './summaries/{}'.format(run_path)\n",
|
|
"\n",
|
|
"if not os.path.exists(model_path):\n",
|
|
" os.makedirs(model_path)\n",
|
|
"\n",
|
|
"if not os.path.exists(summary_path):\n",
|
|
" os.makedirs(summary_path)\n",
|
|
"\n",
|
|
"init = tf.global_variables_initializer()\n",
|
|
"saver = tf.train.Saver()\n",
|
|
"\n",
|
|
"with tf.Session() as sess:\n",
|
|
" # Instantiate model parameters\n",
|
|
" if load_model:\n",
|
|
" print('Loading Model...')\n",
|
|
" ckpt = tf.train.get_checkpoint_state(model_path)\n",
|
|
" saver.restore(sess, ckpt.model_checkpoint_path)\n",
|
|
" else:\n",
|
|
" sess.run(init)\n",
|
|
" steps = sess.run(ppo_model.global_step)\n",
|
|
" summary_writer = tf.summary.FileWriter(summary_path)\n",
|
|
" info = env.reset(train_mode=train_model)[brain_name]\n",
|
|
" trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations)\n",
|
|
" while steps <= max_steps:\n",
|
|
" if env.global_done:\n",
|
|
" info = env.reset(train_mode=train_model)[brain_name]\n",
|
|
" # Decide and take an action\n",
|
|
" new_info = trainer.take_action(info, env, brain_name)\n",
|
|
" info = new_info\n",
|
|
" trainer.process_experiences(info, time_horizon, gamma, lambd)\n",
|
|
" if len(trainer.training_buffer['actions']) > buffer_size and train_model:\n",
|
|
" # Perform gradient descent with experience buffer\n",
|
|
" trainer.update_model(batch_size, num_epoch)\n",
|
|
" if steps % summary_freq == 0 and steps != 0 and train_model:\n",
|
|
" # Write training statistics to tensorboard.\n",
|
|
" trainer.write_summary(summary_writer, steps)\n",
|
|
" if steps % save_freq == 0 and steps != 0 and train_model:\n",
|
|
" # Save Tensorflow model\n",
|
|
" save_model(sess, model_path=model_path, steps=steps, saver=saver)\n",
|
|
" steps += 1\n",
|
|
" sess.run(ppo_model.increment_step)\n",
|
|
" # Final save Tensorflow model\n",
|
|
" if steps != 0 and train_model:\n",
|
|
" save_model(sess, model_path=model_path, steps=steps, saver=saver)\n",
|
|
"env.close()\n",
|
|
"export_graph(model_path, env_name)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Export the trained Tensorflow graph\n",
|
|
"Once the model has been trained and saved, we can export it as a `.bytes` file that Unity can embed."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"export_graph(model_path, env_name)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"anaconda-cloud": {},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|