GitHub
3 年前
当前提交
c3cc3add
共有 7 个文件被更改,包括 1385 次插入 和 5 次删除
-
1.pre-commit-config.yaml
-
1com.unity.ml-agents/CHANGELOG.md
-
13utils/validate_release_links.py
-
494colab/Colab_UnityEnvironment_1_Run.ipynb
-
588colab/Colab_UnityEnvironment_2_Train.ipynb
-
293colab/Colab_UnityEnvironment_3_SideChannel.ipynb
|
|||
{ |
|||
"nbformat": 4, |
|||
"nbformat_minor": 0, |
|||
"metadata": { |
|||
"colab": { |
|||
"name": "Colab-UnityEnvironment-1-Run.ipynb", |
|||
"private_outputs": true, |
|||
"provenance": [], |
|||
"collapsed_sections": [], |
|||
"toc_visible": true |
|||
}, |
|||
"kernelspec": { |
|||
"name": "python3", |
|||
"display_name": "Python 3" |
|||
} |
|||
}, |
|||
"cells": [ |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "pbVXrmEsLXDt" |
|||
}, |
|||
"source": [ |
|||
"# ML-Agents Open a UnityEnvironment\n", |
|||
"<img src=\"https://github.com/Unity-Technologies/ml-agents/blob/release_1/docs/images/image-banner.png?raw=true\" align=\"middle\" width=\"435\"/>" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "WNKTwHU3d2-l" |
|||
}, |
|||
"source": [ |
|||
"## Setup" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "htb-p1hSNX7D" |
|||
}, |
|||
"source": [ |
|||
"#@title Install Rendering Dependencies { display-mode: \"form\" }\n", |
|||
"#@markdown (You only need to run this code when using Colab's hosted runtime)\n", |
|||
"\n", |
|||
"from IPython.display import HTML, display\n", |
|||
"\n", |
|||
"def progress(value, max=100):\n", |
|||
" return HTML(\"\"\"\n", |
|||
" <progress\n", |
|||
" value='{value}'\n", |
|||
" max='{max}',\n", |
|||
" style='width: 100%'\n", |
|||
" >\n", |
|||
" {value}\n", |
|||
" </progress>\n", |
|||
" \"\"\".format(value=value, max=max))\n", |
|||
"\n", |
|||
"pro_bar = display(progress(0, 100), display_id=True)\n", |
|||
"\n", |
|||
"try:\n", |
|||
" import google.colab\n", |
|||
" IN_COLAB = True\n", |
|||
"except ImportError:\n", |
|||
" IN_COLAB = False\n", |
|||
"\n", |
|||
"if IN_COLAB:\n", |
|||
" with open('frame-buffer', 'w') as writefile:\n", |
|||
" writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n", |
|||
"XVFB=/usr/bin/Xvfb\n", |
|||
"XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n", |
|||
"PIDFILE=./frame-buffer.pid\n", |
|||
"case \"$1\" in\n", |
|||
" start)\n", |
|||
" echo -n \"Starting virtual X frame buffer: Xvfb\"\n", |
|||
" /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n", |
|||
" echo \".\"\n", |
|||
" ;;\n", |
|||
" stop)\n", |
|||
" echo -n \"Stopping virtual X frame buffer: Xvfb\"\n", |
|||
" /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n", |
|||
" rm $PIDFILE\n", |
|||
" echo \".\"\n", |
|||
" ;;\n", |
|||
" restart)\n", |
|||
" $0 stop\n", |
|||
" $0 start\n", |
|||
" ;;\n", |
|||
" *)\n", |
|||
" echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n", |
|||
" exit 1\n", |
|||
"esac\n", |
|||
"exit 0\n", |
|||
" \"\"\")\n", |
|||
" pro_bar.update(progress(5, 100))\n", |
|||
" !apt-get install daemon >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(10, 100))\n", |
|||
" !apt-get install wget >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(20, 100))\n", |
|||
" !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(30, 100))\n", |
|||
" !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(40, 100))\n", |
|||
" !dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(50, 100))\n", |
|||
" !dpkg -i xvfb.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(70, 100))\n", |
|||
" !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n", |
|||
" pro_bar.update(progress(80, 100))\n", |
|||
" !rm xvfb.deb\n", |
|||
" pro_bar.update(progress(90, 100))\n", |
|||
" !bash frame-buffer start\n", |
|||
" import os\n", |
|||
" os.environ[\"DISPLAY\"] = \":1\"\n", |
|||
"pro_bar.update(progress(100, 100))" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "Pzj7wgapAcDs" |
|||
}, |
|||
"source": [ |
|||
"### Installing ml-agents" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "N8yfQqkbebQ5" |
|||
}, |
|||
"source": [ |
|||
"try:\n", |
|||
" import mlagents\n", |
|||
" print(\"ml-agents already installed\")\n", |
|||
"except ImportError:\n", |
|||
" !pip install -q mlagents==0.25.1\n", |
|||
" print(\"Installed ml-agents\")" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "_u74YhSmW6gD" |
|||
}, |
|||
"source": [ |
|||
"## Run the Environment" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "DpZPbRvRuLZv" |
|||
}, |
|||
"source": [ |
|||
"#@title Select Environment { display-mode: \"form\" }\n", |
|||
"env_id = \"GridWorld\" #@param ['Basic', '3DBall', '3DBallHard', 'GridWorld', 'Hallway', 'VisualHallway', 'CrawlerDynamicTarget', 'CrawlerStaticTarget', 'Bouncer', 'SoccerTwos', 'PushBlock', 'VisualPushBlock', 'WallJump', 'Tennis', 'Reacher', 'Pyramids', 'VisualPyramids', 'Walker', 'FoodCollector', 'VisualFoodCollector', 'StrikersVsGoalie', 'WormStaticTarget', 'WormDynamicTarget']\n" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "P-r_cB2rqp5x" |
|||
}, |
|||
"source": [ |
|||
"### Start Environment from the registry" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "YSf-WhxbqtLw" |
|||
}, |
|||
"source": [ |
|||
"# -----------------\n", |
|||
"# This code is used to close an env that might not have been closed before\n", |
|||
"try:\n", |
|||
" env.close()\n", |
|||
"except:\n", |
|||
" pass\n", |
|||
"# -----------------\n", |
|||
"\n", |
|||
"from mlagents_envs.registry import default_registry\n", |
|||
"\n", |
|||
"env = default_registry[env_id].make()" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "h1lIx3_l24OP" |
|||
}, |
|||
"source": [ |
|||
"### Reset the environment\n", |
|||
"To reset the environment, simply call `env.reset()`. This method takes no argument and returns nothing but will send a signal to the simulation to reset." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "dhtl0mpeqxYi" |
|||
}, |
|||
"source": [ |
|||
"env.reset()" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "k1rwnVq2qyoO" |
|||
}, |
|||
"source": [ |
|||
"### Behavior Specs\n" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "TrD0rSv92T8A" |
|||
}, |
|||
"source": [ |
|||
"#### Get the Behavior Specs from the Environment" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "a7KatdThq7OV" |
|||
}, |
|||
"source": [ |
|||
"# We will only consider the first Behavior\n", |
|||
"behavior_name = list(env.behavior_specs)[0] \n", |
|||
"print(f\"Name of the behavior : {behavior_name}\")\n", |
|||
"spec = env.behavior_specs[behavior_name]" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "m1L8DHADrAbe" |
|||
}, |
|||
"source": [ |
|||
"#### Get the Observation Space from the Behavior Specs" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "PqDTV5mSrJF5" |
|||
}, |
|||
"source": [ |
|||
"# Examine the number of observations per Agent\n", |
|||
"print(\"Number of observations : \", len(spec.observation_specs))\n", |
|||
"\n", |
|||
"# Is there a visual observation ?\n", |
|||
"# Visual observation have 3 dimensions: Height, Width and number of channels\n", |
|||
"vis_obs = any(len(spec.shape) == 3 for spec in spec.observation_specs)\n", |
|||
"print(\"Is there a visual observation ?\", vis_obs)" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "yVLN_wbG1G5-" |
|||
}, |
|||
"source": [ |
|||
"#### Get the Action Space from the Behavior Specs" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "M9zk1-az1L-G" |
|||
}, |
|||
"source": [ |
|||
"# Is the Action continuous or multi-discrete ?\n", |
|||
"if spec.action_spec.continuous_size > 0:\n", |
|||
" print(f\"There are {spec.action_spec.continuous_size} continuous actions\")\n", |
|||
"if spec.action_spec.is_discrete():\n", |
|||
" print(f\"There are {spec.action_spec.discrete_size} discrete actions\")\n", |
|||
"\n", |
|||
"\n", |
|||
"# How many actions are possible ?\n", |
|||
"#print(f\"There are {spec.action_size} action(s)\")\n", |
|||
"\n", |
|||
"# For discrete actions only : How many different options does each action has ?\n", |
|||
"if spec.action_spec.discrete_size > 0:\n", |
|||
" for action, branch_size in enumerate(spec.action_spec.discrete_branches):\n", |
|||
" print(f\"Action number {action} has {branch_size} different options\")\n", |
|||
" \n" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "3cX07SGw22Lm" |
|||
}, |
|||
"source": [ |
|||
"### Stepping the environment" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "xO5p0s0prfsQ" |
|||
}, |
|||
"source": [ |
|||
"#### Get the steps from the Environment\n", |
|||
"You can do this with the `env.get_steps(behavior_name)` method. If there are multiple behaviors in the Environment, you can call this method with each of the behavior's names.\n", |
|||
"_Note_ This will not move the simulation forward." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "ePZtcHXUrjyf" |
|||
}, |
|||
"source": [ |
|||
"decision_steps, terminal_steps = env.get_steps(behavior_name)" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "f-Oj3ix530mx" |
|||
}, |
|||
"source": [ |
|||
"#### Set actions for each behavior\n", |
|||
"You can set the actions for the Agents of a Behavior by calling `env.set_actions()` you will need to specify the behavior name and pass a tensor of dimension 2. The first dimension of the action must be equal to the number of Agents that requested a decision during the step." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "KB-nxfbw337g" |
|||
}, |
|||
"source": [ |
|||
"env.set_actions(behavior_name, spec.action_spec.empty_action(len(decision_steps)))" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "MQCybRs84cmq" |
|||
}, |
|||
"source": [ |
|||
"#### Move the simulation forward\n", |
|||
"Call `env.step()` to move the simulation forward. The simulation will progress until an Agent requestes a decision or terminates." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "nl3K40ZR4bh2" |
|||
}, |
|||
"source": [ |
|||
"env.step()" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "i9gdextn2vJy" |
|||
}, |
|||
"source": [ |
|||
"### Observations" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "iAMqnnddr8Xo" |
|||
}, |
|||
"source": [ |
|||
"#### Show the observations for one of the Agents\n", |
|||
"`DecisionSteps.obs` is a tuple containing all of the observations for all of the Agents with the provided Behavior name.\n", |
|||
"Each value in the tuple is an observation tensor containing the observation data for all of the agents." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "OJpta61TsBiO" |
|||
}, |
|||
"source": [ |
|||
"import matplotlib.pyplot as plt\n", |
|||
"%matplotlib inline\n", |
|||
"\n", |
|||
"for index, obs_spec in enumerate(spec.observation_specs):\n", |
|||
" if len(obs_spec.shape) == 3:\n", |
|||
" print(\"Here is the first visual observation\")\n", |
|||
" plt.imshow(decision_steps.obs[index][0,:,:,:])\n", |
|||
" plt.show()\n", |
|||
"\n", |
|||
"for index, obs_spec in enumerate(spec.observation_specs):\n", |
|||
" if len(obs_spec.shape) == 1:\n", |
|||
" print(\"First vector observations : \", decision_steps.obs[index][0,:])" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "y60u21sys8kA" |
|||
}, |
|||
"source": [ |
|||
"### Run the Environment for a few episodes" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "a2uQUsoMtIUK" |
|||
}, |
|||
"source": [ |
|||
"for episode in range(3):\n", |
|||
" env.reset()\n", |
|||
" decision_steps, terminal_steps = env.get_steps(behavior_name)\n", |
|||
" tracked_agent = -1 # -1 indicates not yet tracking\n", |
|||
" done = False # For the tracked_agent\n", |
|||
" episode_rewards = 0 # For the tracked_agent\n", |
|||
" while not done:\n", |
|||
" # Track the first agent we see if not tracking \n", |
|||
" # Note : len(decision_steps) = [number of agents that requested a decision]\n", |
|||
" if tracked_agent == -1 and len(decision_steps) >= 1:\n", |
|||
" tracked_agent = decision_steps.agent_id[0] \n", |
|||
"\n", |
|||
" # Generate an action for all agents\n", |
|||
" action = spec.action_spec.random_action(len(decision_steps))\n", |
|||
"\n", |
|||
" # Set the actions\n", |
|||
" env.set_actions(behavior_name, action)\n", |
|||
"\n", |
|||
" # Move the simulation forward\n", |
|||
" env.step()\n", |
|||
"\n", |
|||
" # Get the new simulation results\n", |
|||
" decision_steps, terminal_steps = env.get_steps(behavior_name)\n", |
|||
" if tracked_agent in decision_steps: # The agent requested a decision\n", |
|||
" episode_rewards += decision_steps[tracked_agent].reward\n", |
|||
" if tracked_agent in terminal_steps: # The agent terminated its episode\n", |
|||
" episode_rewards += terminal_steps[tracked_agent].reward\n", |
|||
" done = True\n", |
|||
" print(f\"Total rewards for episode {episode} is {episode_rewards}\")\n" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "R-3grXNEtJPa" |
|||
}, |
|||
"source": [ |
|||
"### Close the Environment to free the port it is using" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "vdWG6_SqtNtv" |
|||
}, |
|||
"source": [ |
|||
"env.close()\n", |
|||
"print(\"Closed environment\")" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
} |
|||
] |
|||
} |
|
|||
{ |
|||
"nbformat": 4, |
|||
"nbformat_minor": 0, |
|||
"metadata": { |
|||
"colab": { |
|||
"name": "Colab-UnityEnvironment-2-Train.ipynb", |
|||
"private_outputs": true, |
|||
"provenance": [], |
|||
"collapsed_sections": [], |
|||
"toc_visible": true |
|||
}, |
|||
"kernelspec": { |
|||
"name": "python3", |
|||
"display_name": "Python 3" |
|||
} |
|||
}, |
|||
"cells": [ |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "pbVXrmEsLXDt" |
|||
}, |
|||
"source": [ |
|||
"# ML-Agents Q-Learning with GridWorld\n", |
|||
"<img src=\"https://github.com/Unity-Technologies/ml-agents/blob/release_2/docs/images/gridworld.png?raw=true\" align=\"middle\" width=\"435\"/>" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "WNKTwHU3d2-l" |
|||
}, |
|||
"source": [ |
|||
"## Setup" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "htb-p1hSNX7D" |
|||
}, |
|||
"source": [ |
|||
"#@title Install Rendering Dependencies { display-mode: \"form\" }\n", |
|||
"#@markdown (You only need to run this code when using Colab's hosted runtime)\n", |
|||
"\n", |
|||
"from IPython.display import HTML, display\n", |
|||
"\n", |
|||
"def progress(value, max=100):\n", |
|||
" return HTML(\"\"\"\n", |
|||
" <progress\n", |
|||
" value='{value}'\n", |
|||
" max='{max}',\n", |
|||
" style='width: 100%'\n", |
|||
" >\n", |
|||
" {value}\n", |
|||
" </progress>\n", |
|||
" \"\"\".format(value=value, max=max))\n", |
|||
"\n", |
|||
"pro_bar = display(progress(0, 100), display_id=True)\n", |
|||
"\n", |
|||
"try:\n", |
|||
" import google.colab\n", |
|||
" IN_COLAB = True\n", |
|||
"except ImportError:\n", |
|||
" IN_COLAB = False\n", |
|||
"\n", |
|||
"if IN_COLAB:\n", |
|||
" with open('frame-buffer', 'w') as writefile:\n", |
|||
" writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n", |
|||
"XVFB=/usr/bin/Xvfb\n", |
|||
"XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n", |
|||
"PIDFILE=./frame-buffer.pid\n", |
|||
"case \"$1\" in\n", |
|||
" start)\n", |
|||
" echo -n \"Starting virtual X frame buffer: Xvfb\"\n", |
|||
" /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n", |
|||
" echo \".\"\n", |
|||
" ;;\n", |
|||
" stop)\n", |
|||
" echo -n \"Stopping virtual X frame buffer: Xvfb\"\n", |
|||
" /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n", |
|||
" rm $PIDFILE\n", |
|||
" echo \".\"\n", |
|||
" ;;\n", |
|||
" restart)\n", |
|||
" $0 stop\n", |
|||
" $0 start\n", |
|||
" ;;\n", |
|||
" *)\n", |
|||
" echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n", |
|||
" exit 1\n", |
|||
"esac\n", |
|||
"exit 0\n", |
|||
" \"\"\")\n", |
|||
" pro_bar.update(progress(5, 100))\n", |
|||
" !apt-get install daemon >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(10, 100))\n", |
|||
" !apt-get install wget >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(20, 100))\n", |
|||
" !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(30, 100))\n", |
|||
" !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(40, 100))\n", |
|||
" !dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(50, 100))\n", |
|||
" !dpkg -i xvfb.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(70, 100))\n", |
|||
" !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n", |
|||
" pro_bar.update(progress(80, 100))\n", |
|||
" !rm xvfb.deb\n", |
|||
" pro_bar.update(progress(90, 100))\n", |
|||
" !bash frame-buffer start\n", |
|||
" import os\n", |
|||
" os.environ[\"DISPLAY\"] = \":1\"\n", |
|||
"pro_bar.update(progress(100, 100))" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "Pzj7wgapAcDs" |
|||
}, |
|||
"source": [ |
|||
"### Installing ml-agents" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "N8yfQqkbebQ5" |
|||
}, |
|||
"source": [ |
|||
"try:\n", |
|||
" import mlagents\n", |
|||
" print(\"ml-agents already installed\")\n", |
|||
"except ImportError:\n", |
|||
" !pip install -q mlagents==0.25.1\n", |
|||
" print(\"Installed ml-agents\")" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "jz81TWAkbuFY" |
|||
}, |
|||
"source": [ |
|||
"## Train the GridWorld Environment with Q-Learning" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "29n3dt1Zx5ty" |
|||
}, |
|||
"source": [ |
|||
"### What is the GridWorld Environment" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "pZhVRfdoyPmv" |
|||
}, |
|||
"source": [ |
|||
"The [GridWorld](https://github.com/Unity-Technologies/ml-agents/blob/release_2/docs/Learning-Environment-Examples.md#gridworld) Environment is a simple Unity visual environment. The Agent is a blue square in a 3x3 grid that is trying to reach a green __`+`__ while avoiding a red __`x`__.\n", |
|||
"\n", |
|||
"The observation is an image obtained by a camera on top of the grid.\n", |
|||
"\n", |
|||
"The Action can be one of 5 : \n", |
|||
" - Do not move\n", |
|||
" - Move up\n", |
|||
" - Move down\n", |
|||
" - Move right\n", |
|||
" - Move left\n", |
|||
"\n", |
|||
"The Agent receives a reward of _1.0_ if it reaches the green __`+`__, a penalty of _-1.0_ if it touches the red __`x`__ and a penalty of `-0.01` at every step (to force the Agent to solve the task as fast as possible)\n", |
|||
"\n", |
|||
"__Note__ There are 9 Agents, each in their own grid, at once in the environment. This alows for faster data collection.\n", |
|||
"\n", |
|||
"\n" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "4Gt-ZydJyJWD" |
|||
}, |
|||
"source": [ |
|||
"### The Q-Learning Algorithm\n" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "KA1qOgfq0Xdv" |
|||
}, |
|||
"source": [ |
|||
"In this Notebook, we will implement a very simple Q-Learning algorithm. We will use [pytorch](https://pytorch.org/) to do so.\n", |
|||
"\n", |
|||
"Below is the code to create the neural network we will use in the Notebook." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "q79rUp_Sx6A_" |
|||
}, |
|||
"source": [ |
|||
"import torch\n", |
|||
"from typing import Tuple\n", |
|||
"from math import floor\n", |
|||
"\n", |
|||
"\n", |
|||
"class VisualQNetwork(torch.nn.Module):\n", |
|||
" def __init__(\n", |
|||
" self,\n", |
|||
" input_shape: Tuple[int, int, int], \n", |
|||
" encoding_size: int, \n", |
|||
" output_size: int\n", |
|||
" ):\n", |
|||
" \"\"\"\n", |
|||
" Creates a neural network that takes as input a batch of images (3\n", |
|||
" dimensional tensors) and outputs a batch of outputs (1 dimensional\n", |
|||
" tensors)\n", |
|||
" \"\"\"\n", |
|||
" super(VisualQNetwork, self).__init__()\n", |
|||
" height = input_shape[0]\n", |
|||
" width = input_shape[1]\n", |
|||
" initial_channels = input_shape[2]\n", |
|||
" conv_1_hw = self.conv_output_shape((height, width), 8, 4)\n", |
|||
" conv_2_hw = self.conv_output_shape(conv_1_hw, 4, 2)\n", |
|||
" self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32\n", |
|||
" self.conv1 = torch.nn.Conv2d(initial_channels, 16, [8, 8], [4, 4])\n", |
|||
" self.conv2 = torch.nn.Conv2d(16, 32, [4, 4], [2, 2])\n", |
|||
" self.dense1 = torch.nn.Linear(self.final_flat, encoding_size)\n", |
|||
" self.dense2 = torch.nn.Linear(encoding_size, output_size)\n", |
|||
"\n", |
|||
" def forward(self, visual_obs: torch.tensor):\n", |
|||
" visual_obs = visual_obs.permute(0, 3, 1, 2)\n", |
|||
" conv_1 = torch.relu(self.conv1(visual_obs))\n", |
|||
" conv_2 = torch.relu(self.conv2(conv_1))\n", |
|||
" hidden = self.dense1(conv_2.reshape([-1, self.final_flat]))\n", |
|||
" hidden = torch.relu(hidden)\n", |
|||
" hidden = self.dense2(hidden)\n", |
|||
" return hidden\n", |
|||
"\n", |
|||
" @staticmethod\n", |
|||
" def conv_output_shape(\n", |
|||
" h_w: Tuple[int, int],\n", |
|||
" kernel_size: int = 1,\n", |
|||
" stride: int = 1,\n", |
|||
" pad: int = 0,\n", |
|||
" dilation: int = 1,\n", |
|||
" ):\n", |
|||
" \"\"\"\n", |
|||
" Computes the height and width of the output of a convolution layer.\n", |
|||
" \"\"\"\n", |
|||
" h = floor(\n", |
|||
" ((h_w[0] + (2 * pad) - (dilation * (kernel_size - 1)) - 1) / stride) + 1\n", |
|||
" )\n", |
|||
" w = floor(\n", |
|||
" ((h_w[1] + (2 * pad) - (dilation * (kernel_size - 1)) - 1) / stride) + 1\n", |
|||
" )\n", |
|||
" return h, w\n" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "EZoaEBAo2L0F" |
|||
}, |
|||
"source": [ |
|||
"We will now create a few classes to help us store the data we will use to train the Q-Learning algorithm." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "L772fe2q39DO" |
|||
}, |
|||
"source": [ |
|||
"import numpy as np\n", |
|||
"from typing import NamedTuple, List\n", |
|||
"\n", |
|||
"\n", |
|||
"class Experience(NamedTuple):\n", |
|||
" \"\"\"\n", |
|||
" An experience contains the data of one Agent transition.\n", |
|||
" - Observation\n", |
|||
" - Action\n", |
|||
" - Reward\n", |
|||
" - Done flag\n", |
|||
" - Next Observation\n", |
|||
" \"\"\"\n", |
|||
"\n", |
|||
" obs: np.ndarray\n", |
|||
" action: np.ndarray\n", |
|||
" reward: float\n", |
|||
" done: bool\n", |
|||
" next_obs: np.ndarray\n", |
|||
"\n", |
|||
"# A Trajectory is an ordered sequence of Experiences\n", |
|||
"Trajectory = List[Experience]\n", |
|||
"\n", |
|||
"# A Buffer is an unordered list of Experiences from multiple Trajectories\n", |
|||
"Buffer = List[Experience]" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "6HsM1d5I3_Tj" |
|||
}, |
|||
"source": [ |
|||
"Now, we can create our trainer class. The role of this trainer is to collect data from the Environment according to a Policy, and then train the Q-Network with that data." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "KkzBoRJCb18t" |
|||
}, |
|||
"source": [ |
|||
"from mlagents_envs.environment import ActionTuple, BaseEnv\n", |
|||
"from typing import Dict\n", |
|||
"import random\n", |
|||
"\n", |
|||
"\n", |
|||
"class Trainer:\n", |
|||
" @staticmethod\n", |
|||
" def generate_trajectories(\n", |
|||
" env: BaseEnv, q_net: VisualQNetwork, buffer_size: int, epsilon: float\n", |
|||
" ):\n", |
|||
" \"\"\"\n", |
|||
" Given a Unity Environment and a Q-Network, this method will generate a\n", |
|||
" buffer of Experiences obtained by running the Environment with the Policy\n", |
|||
" derived from the Q-Network.\n", |
|||
" :param BaseEnv: The UnityEnvironment used.\n", |
|||
" :param q_net: The Q-Network used to collect the data.\n", |
|||
" :param buffer_size: The minimum size of the buffer this method will return.\n", |
|||
" :param epsilon: Will add a random normal variable with standard deviation.\n", |
|||
" epsilon to the value heads of the Q-Network to encourage exploration.\n", |
|||
" :returns: a Tuple containing the created buffer and the average cumulative\n", |
|||
" the Agents obtained.\n", |
|||
" \"\"\"\n", |
|||
" # Create an empty Buffer\n", |
|||
" buffer: Buffer = []\n", |
|||
"\n", |
|||
" # Reset the environment\n", |
|||
" env.reset()\n", |
|||
" # Read and store the Behavior Name of the Environment\n", |
|||
" behavior_name = list(env.behavior_specs)[0]\n", |
|||
" # Read and store the Behavior Specs of the Environment\n", |
|||
" spec = env.behavior_specs[behavior_name]\n", |
|||
"\n", |
|||
" # Create a Mapping from AgentId to Trajectories. This will help us create\n", |
|||
" # trajectories for each Agents\n", |
|||
" dict_trajectories_from_agent: Dict[int, Trajectory] = {}\n", |
|||
" # Create a Mapping from AgentId to the last observation of the Agent\n", |
|||
" dict_last_obs_from_agent: Dict[int, np.ndarray] = {}\n", |
|||
" # Create a Mapping from AgentId to the last observation of the Agent\n", |
|||
" dict_last_action_from_agent: Dict[int, np.ndarray] = {}\n", |
|||
" # Create a Mapping from AgentId to cumulative reward (Only for reporting)\n", |
|||
" dict_cumulative_reward_from_agent: Dict[int, float] = {}\n", |
|||
" # Create a list to store the cumulative rewards obtained so far\n", |
|||
" cumulative_rewards: List[float] = []\n", |
|||
"\n", |
|||
" while len(buffer) < buffer_size: # While not enough data in the buffer\n", |
|||
" # Get the Decision Steps and Terminal Steps of the Agents\n", |
|||
" decision_steps, terminal_steps = env.get_steps(behavior_name)\n", |
|||
"\n", |
|||
" # For all Agents with a Terminal Step:\n", |
|||
" for agent_id_terminated in terminal_steps:\n", |
|||
" # Create its last experience (is last because the Agent terminated)\n", |
|||
" last_experience = Experience(\n", |
|||
" obs=dict_last_obs_from_agent[agent_id_terminated].copy(),\n", |
|||
" reward=terminal_steps[agent_id_terminated].reward,\n", |
|||
" done=not terminal_steps[agent_id_terminated].interrupted,\n", |
|||
" action=dict_last_action_from_agent[agent_id_terminated].copy(),\n", |
|||
" next_obs=terminal_steps[agent_id_terminated].obs[0],\n", |
|||
" )\n", |
|||
" # Clear its last observation and action (Since the trajectory is over)\n", |
|||
" dict_last_obs_from_agent.pop(agent_id_terminated)\n", |
|||
" dict_last_action_from_agent.pop(agent_id_terminated)\n", |
|||
" # Report the cumulative reward\n", |
|||
" cumulative_reward = (\n", |
|||
" dict_cumulative_reward_from_agent.pop(agent_id_terminated)\n", |
|||
" + terminal_steps[agent_id_terminated].reward\n", |
|||
" )\n", |
|||
" cumulative_rewards.append(cumulative_reward)\n", |
|||
" # Add the Trajectory and the last experience to the buffer\n", |
|||
" buffer.extend(dict_trajectories_from_agent.pop(agent_id_terminated))\n", |
|||
" buffer.append(last_experience)\n", |
|||
"\n", |
|||
" # For all Agents with a Decision Step:\n", |
|||
" for agent_id_decisions in decision_steps:\n", |
|||
" # If the Agent does not have a Trajectory, create an empty one\n", |
|||
" if agent_id_decisions not in dict_trajectories_from_agent:\n", |
|||
" dict_trajectories_from_agent[agent_id_decisions] = []\n", |
|||
" dict_cumulative_reward_from_agent[agent_id_decisions] = 0\n", |
|||
"\n", |
|||
" # If the Agent requesting a decision has a \"last observation\"\n", |
|||
" if agent_id_decisions in dict_last_obs_from_agent:\n", |
|||
" # Create an Experience from the last observation and the Decision Step\n", |
|||
" exp = Experience(\n", |
|||
" obs=dict_last_obs_from_agent[agent_id_decisions].copy(),\n", |
|||
" reward=decision_steps[agent_id_decisions].reward,\n", |
|||
" done=False,\n", |
|||
" action=dict_last_action_from_agent[agent_id_decisions].copy(),\n", |
|||
" next_obs=decision_steps[agent_id_decisions].obs[0],\n", |
|||
" )\n", |
|||
" # Update the Trajectory of the Agent and its cumulative reward\n", |
|||
" dict_trajectories_from_agent[agent_id_decisions].append(exp)\n", |
|||
" dict_cumulative_reward_from_agent[agent_id_decisions] += (\n", |
|||
" decision_steps[agent_id_decisions].reward\n", |
|||
" )\n", |
|||
" # Store the observation as the new \"last observation\"\n", |
|||
" dict_last_obs_from_agent[agent_id_decisions] = (\n", |
|||
" decision_steps[agent_id_decisions].obs[0]\n", |
|||
" )\n", |
|||
"\n", |
|||
" # Generate an action for all the Agents that requested a decision\n", |
|||
" # Compute the values for each action given the observation\n", |
|||
" actions_values = (\n", |
|||
" q_net(torch.from_numpy(decision_steps.obs[0])).detach().numpy()\n", |
|||
" )\n", |
|||
" # Add some noise with epsilon to the values\n", |
|||
" actions_values += epsilon * (\n", |
|||
" np.random.randn(actions_values.shape[0], actions_values.shape[1])\n", |
|||
" ).astype(np.float32)\n", |
|||
" # Pick the best action using argmax\n", |
|||
" actions = np.argmax(actions_values, axis=1)\n", |
|||
" actions.resize((len(decision_steps), 1))\n", |
|||
" # Store the action that was picked, it will be put in the trajectory later\n", |
|||
" for agent_index, agent_id in enumerate(decision_steps.agent_id):\n", |
|||
" dict_last_action_from_agent[agent_id] = actions[agent_index]\n", |
|||
"\n", |
|||
" # Set the actions in the environment\n", |
|||
" # Unity Environments expect ActionTuple instances.\n", |
|||
" action_tuple = ActionTuple()\n", |
|||
" action_tuple.add_discrete(actions)\n", |
|||
" env.set_actions(behavior_name, action_tuple)\n", |
|||
" # Perform a step in the simulation\n", |
|||
" env.step()\n", |
|||
" return buffer, np.mean(cumulative_rewards)\n", |
|||
"\n", |
|||
" @staticmethod\n", |
|||
" def update_q_net(\n", |
|||
" q_net: VisualQNetwork, \n", |
|||
" optimizer: torch.optim, \n", |
|||
" buffer: Buffer, \n", |
|||
" action_size: int\n", |
|||
" ):\n", |
|||
" \"\"\"\n", |
|||
" Performs an update of the Q-Network using the provided optimizer and buffer\n", |
|||
" \"\"\"\n", |
|||
" BATCH_SIZE = 1000\n", |
|||
" NUM_EPOCH = 3\n", |
|||
" GAMMA = 0.9\n", |
|||
" batch_size = min(len(buffer), BATCH_SIZE)\n", |
|||
" random.shuffle(buffer)\n", |
|||
" # Split the buffer into batches\n", |
|||
" batches = [\n", |
|||
" buffer[batch_size * start : batch_size * (start + 1)]\n", |
|||
" for start in range(int(len(buffer) / batch_size))\n", |
|||
" ]\n", |
|||
" for _ in range(NUM_EPOCH):\n", |
|||
" for batch in batches:\n", |
|||
" # Create the Tensors that will be fed in the network\n", |
|||
" obs = torch.from_numpy(np.stack([ex.obs for ex in batch]))\n", |
|||
" reward = torch.from_numpy(\n", |
|||
" np.array([ex.reward for ex in batch], dtype=np.float32).reshape(-1, 1)\n", |
|||
" )\n", |
|||
" done = torch.from_numpy(\n", |
|||
" np.array([ex.done for ex in batch], dtype=np.float32).reshape(-1, 1)\n", |
|||
" )\n", |
|||
" action = torch.from_numpy(np.stack([ex.action for ex in batch]))\n", |
|||
" next_obs = torch.from_numpy(np.stack([ex.next_obs for ex in batch]))\n", |
|||
"\n", |
|||
" # Use the Bellman equation to update the Q-Network\n", |
|||
" target = (\n", |
|||
" reward\n", |
|||
" + (1.0 - done)\n", |
|||
" * GAMMA\n", |
|||
" * torch.max(q_net(next_obs).detach(), dim=1, keepdim=True).values\n", |
|||
" )\n", |
|||
" mask = torch.zeros((len(batch), action_size))\n", |
|||
" mask.scatter_(1, action, 1)\n", |
|||
" prediction = torch.sum(qnet(obs) * mask, dim=1, keepdim=True)\n", |
|||
" criterion = torch.nn.MSELoss()\n", |
|||
" loss = criterion(prediction, target)\n", |
|||
"\n", |
|||
" # Perform the backpropagation\n", |
|||
" optimizer.zero_grad()\n", |
|||
" loss.backward()\n", |
|||
" optimizer.step()\n" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "vcU4ZMAEWCvX" |
|||
}, |
|||
"source": [ |
|||
"### Run Training" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "_lIHijQfbYjh" |
|||
}, |
|||
"source": [ |
|||
"# -----------------\n", |
|||
"# This code is used to close an env that might not have been closed before\n", |
|||
"try:\n", |
|||
" env.close()\n", |
|||
"except:\n", |
|||
" pass\n", |
|||
"# -----------------\n", |
|||
"\n", |
|||
"from mlagents_envs.registry import default_registry\n", |
|||
"from mlagents_envs.environment import UnityEnvironment\n", |
|||
"import matplotlib.pyplot as plt\n", |
|||
"%matplotlib inline\n", |
|||
"\n", |
|||
"# Create the GridWorld Environment from the registry\n", |
|||
"env = default_registry[\"GridWorld\"].make()\n", |
|||
"print(\"GridWorld environment created.\")\n", |
|||
"\n", |
|||
"# Create a new Q-Network. \n", |
|||
"qnet = VisualQNetwork((64, 84, 3), 126, 5)\n", |
|||
"\n", |
|||
"experiences: Buffer = []\n", |
|||
"optim = torch.optim.Adam(qnet.parameters(), lr= 0.001)\n", |
|||
"\n", |
|||
"cumulative_rewards: List[float] = []\n", |
|||
"\n", |
|||
"# The number of training steps that will be performed\n", |
|||
"NUM_TRAINING_STEPS = 70\n", |
|||
"# The number of experiences to collect per training step\n", |
|||
"NUM_NEW_EXP = 1000\n", |
|||
"# The maximum size of the Buffer\n", |
|||
"BUFFER_SIZE = 10000\n", |
|||
"\n", |
|||
"for n in range(NUM_TRAINING_STEPS):\n", |
|||
" new_exp,_ = Trainer.generate_trajectories(env, qnet, NUM_NEW_EXP, epsilon=0.1)\n", |
|||
" random.shuffle(experiences)\n", |
|||
" if len(experiences) > BUFFER_SIZE:\n", |
|||
" experiences = experiences[:BUFFER_SIZE]\n", |
|||
" experiences.extend(new_exp)\n", |
|||
" Trainer.update_q_net(qnet, optim, experiences, 5)\n", |
|||
" _, rewards = Trainer.generate_trajectories(env, qnet, 100, epsilon=0)\n", |
|||
" cumulative_rewards.append(rewards)\n", |
|||
" print(\"Training step \", n+1, \"\\treward \", rewards)\n", |
|||
"\n", |
|||
"\n", |
|||
"env.close()\n", |
|||
"\n", |
|||
"# Show the training graph\n", |
|||
"plt.plot(range(NUM_TRAINING_STEPS), cumulative_rewards)\n" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "2ihb_gmYLUsH" |
|||
}, |
|||
"source": [ |
|||
"" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
} |
|||
] |
|||
} |
|
|||
{ |
|||
"nbformat": 4, |
|||
"nbformat_minor": 0, |
|||
"metadata": { |
|||
"colab": { |
|||
"name": "Colab-UnityEnvironment-3-SideChannel.ipynb", |
|||
"private_outputs": true, |
|||
"provenance": [], |
|||
"collapsed_sections": [], |
|||
"toc_visible": true |
|||
}, |
|||
"kernelspec": { |
|||
"name": "python3", |
|||
"display_name": "Python 3" |
|||
} |
|||
}, |
|||
"cells": [ |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "pbVXrmEsLXDt" |
|||
}, |
|||
"source": [ |
|||
"# ML-Agents Use SideChannels\n", |
|||
"<img src=\"https://raw.githubusercontent.com/Unity-Technologies/ml-agents/release_1/docs/images/3dball_big.png\" align=\"middle\" width=\"435\"/>" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "WNKTwHU3d2-l" |
|||
}, |
|||
"source": [ |
|||
"## Setup" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "htb-p1hSNX7D" |
|||
}, |
|||
"source": [ |
|||
"#@title Install Rendering Dependencies { display-mode: \"form\" }\n", |
|||
"#@markdown (You only need to run this code when using Colab's hosted runtime)\n", |
|||
"\n", |
|||
"from IPython.display import HTML, display\n", |
|||
"\n", |
|||
"def progress(value, max=100):\n", |
|||
" return HTML(\"\"\"\n", |
|||
" <progress\n", |
|||
" value='{value}'\n", |
|||
" max='{max}',\n", |
|||
" style='width: 100%'\n", |
|||
" >\n", |
|||
" {value}\n", |
|||
" </progress>\n", |
|||
" \"\"\".format(value=value, max=max))\n", |
|||
"\n", |
|||
"pro_bar = display(progress(0, 100), display_id=True)\n", |
|||
"\n", |
|||
"try:\n", |
|||
" import google.colab\n", |
|||
" IN_COLAB = True\n", |
|||
"except ImportError:\n", |
|||
" IN_COLAB = False\n", |
|||
"\n", |
|||
"if IN_COLAB:\n", |
|||
" with open('frame-buffer', 'w') as writefile:\n", |
|||
" writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n", |
|||
"XVFB=/usr/bin/Xvfb\n", |
|||
"XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n", |
|||
"PIDFILE=./frame-buffer.pid\n", |
|||
"case \"$1\" in\n", |
|||
" start)\n", |
|||
" echo -n \"Starting virtual X frame buffer: Xvfb\"\n", |
|||
" /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n", |
|||
" echo \".\"\n", |
|||
" ;;\n", |
|||
" stop)\n", |
|||
" echo -n \"Stopping virtual X frame buffer: Xvfb\"\n", |
|||
" /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n", |
|||
" rm $PIDFILE\n", |
|||
" echo \".\"\n", |
|||
" ;;\n", |
|||
" restart)\n", |
|||
" $0 stop\n", |
|||
" $0 start\n", |
|||
" ;;\n", |
|||
" *)\n", |
|||
" echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n", |
|||
" exit 1\n", |
|||
"esac\n", |
|||
"exit 0\n", |
|||
" \"\"\")\n", |
|||
" pro_bar.update(progress(5, 100))\n", |
|||
" !apt-get install daemon >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(10, 100))\n", |
|||
" !apt-get install wget >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(20, 100))\n", |
|||
" !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(30, 100))\n", |
|||
" !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(40, 100))\n", |
|||
" !dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(50, 100))\n", |
|||
" !dpkg -i xvfb.deb >/dev/null 2>&1\n", |
|||
" pro_bar.update(progress(70, 100))\n", |
|||
" !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n", |
|||
" pro_bar.update(progress(80, 100))\n", |
|||
" !rm xvfb.deb\n", |
|||
" pro_bar.update(progress(90, 100))\n", |
|||
" !bash frame-buffer start\n", |
|||
" import os\n", |
|||
" os.environ[\"DISPLAY\"] = \":1\"\n", |
|||
"pro_bar.update(progress(100, 100))" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "Pzj7wgapAcDs" |
|||
}, |
|||
"source": [ |
|||
"### Installing ml-agents" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "N8yfQqkbebQ5" |
|||
}, |
|||
"source": [ |
|||
"try:\n", |
|||
" import mlagents\n", |
|||
" print(\"ml-agents already installed\")\n", |
|||
"except ImportError:\n", |
|||
" !pip install -q mlagents==0.25.1\n", |
|||
" print(\"Installed ml-agents\")" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "_u74YhSmW6gD" |
|||
}, |
|||
"source": [ |
|||
"## Side Channel\n", |
|||
"\n", |
|||
"SideChannels are objects that can be passed to the constructor of a UnityEnvironment or the `make()` method of a registry entry to send non Reinforcement Learning related data. \n", |
|||
"More information available [here](https://github.com/Unity-Technologies/ml-agents/blob/release_1/docs/Python-API.md#communicating-additional-information-with-the-environment)\n", |
|||
"\n", |
|||
"\n", |
|||
"\n" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "U4RXnhLRk7Uc" |
|||
}, |
|||
"source": [ |
|||
"### Engine Configuration SideChannel\n", |
|||
"The [Engine Configuration Side Channel](https://github.com/Unity-Technologies/ml-agents/blob/release_1/docs/Python-API.md#engineconfigurationchannel) is used to configure how the Unity Engine should run. \n", |
|||
"We will use the GridWorld environment to demonstrate how to use the EngineConfigurationChannel. " |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "YSf-WhxbqtLw" |
|||
}, |
|||
"source": [ |
|||
"# -----------------\n", |
|||
"# This code is used to close an env that might not have been closed before\n", |
|||
"try:\n", |
|||
" env.close()\n", |
|||
"except:\n", |
|||
" pass\n", |
|||
"# -----------------\n", |
|||
"\n", |
|||
"from mlagents_envs.registry import default_registry\n", |
|||
"env_id = \"GridWorld\"\n", |
|||
"\n", |
|||
"# Import the EngineConfigurationChannel class\n", |
|||
"from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel\n", |
|||
"\n", |
|||
"# Create the side channel\n", |
|||
"engine_config_channel = EngineConfigurationChannel()\n", |
|||
"\n", |
|||
"# Pass the side channel to the make method\n", |
|||
"# Note, the make method takes a LIST of SideChannel as input\n", |
|||
"env = default_registry[env_id].make(side_channels = [engine_config_channel])\n", |
|||
"\n", |
|||
"# Configure the Unity Engine\n", |
|||
"engine_config_channel.set_configuration_parameters(target_frame_rate = 30)\n", |
|||
"\n", |
|||
"env.reset()\n", |
|||
"\n", |
|||
"# ... \n", |
|||
"# Perform experiment on environment\n", |
|||
"# ...\n", |
|||
"\n", |
|||
"env.close()" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "h1lIx3_l24OP" |
|||
}, |
|||
"source": [ |
|||
"### Environment Parameters Channel\n", |
|||
"The [Environment Parameters Side Channel](https://github.com/Unity-Technologies/ml-agents/blob/release_1/docs/Python-API.md#environmentparameters) is used to modify environment parameters during the simulation. \n", |
|||
"We will use the GridWorld environment to demonstrate how to use the EngineConfigurationChannel. " |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"metadata": { |
|||
"id": "dhtl0mpeqxYi" |
|||
}, |
|||
"source": [ |
|||
"import matplotlib.pyplot as plt\n", |
|||
"%matplotlib inline\n", |
|||
"\n", |
|||
"# -----------------\n", |
|||
"# This code is used to close an env that might not have been closed before\n", |
|||
"try:\n", |
|||
" env.close()\n", |
|||
"except:\n", |
|||
" pass\n", |
|||
"# -----------------\n", |
|||
"\n", |
|||
"from mlagents_envs.registry import default_registry\n", |
|||
"env_id = \"GridWorld\"\n", |
|||
"\n", |
|||
"# Import the EngineConfigurationChannel class\n", |
|||
"from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel\n", |
|||
"\n", |
|||
"# Create the side channel\n", |
|||
"env_parameters = EnvironmentParametersChannel()\n", |
|||
"\n", |
|||
"# Pass the side channel to the make method\n", |
|||
"# Note, the make method takes a LIST of SideChannel as input\n", |
|||
"env = default_registry[env_id].make(side_channels = [env_parameters])\n", |
|||
"\n", |
|||
"env.reset()\n", |
|||
"behavior_name = list(env.behavior_specs)[0] \n", |
|||
"\n", |
|||
"print(\"Observation without changing the environment parameters\")\n", |
|||
"decision_steps, terminal_steps = env.get_steps(behavior_name)\n", |
|||
"plt.imshow(decision_steps.obs[0][0,:,:,:])\n", |
|||
"plt.show()\n", |
|||
"\n", |
|||
"print(\"Increasing the dimensions of the grid from 5 to 7\")\n", |
|||
"env_parameters.set_float_parameter(\"gridSize\", 7)\n", |
|||
"print(\"Increasing the number of X from 1 to 5\")\n", |
|||
"env_parameters.set_float_parameter(\"numObstacles\", 5)\n", |
|||
"\n", |
|||
"# Any change to a SideChannel will only be effective after a step or reset\n", |
|||
"# In the GridWorld Environment, the grid's dimensions can only change at reset\n", |
|||
"env.reset()\n", |
|||
"\n", |
|||
"\n", |
|||
"decision_steps, terminal_steps = env.get_steps(behavior_name)\n", |
|||
"plt.imshow(decision_steps.obs[0][0,:,:,:])\n", |
|||
"plt.show()\n", |
|||
"\n", |
|||
"\n", |
|||
"\n", |
|||
"env.close()" |
|||
], |
|||
"execution_count": null, |
|||
"outputs": [] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": { |
|||
"id": "k1rwnVq2qyoO" |
|||
}, |
|||
"source": [ |
|||
"### Creating your own Side Channels\n", |
|||
"You can send various kinds of data between a Unity Environment and Python but you will need to [create your own implementation of a Side Channel](https://github.com/Unity-Technologies/ml-agents/blob/release_1/docs/Custom-SideChannels.md#custom-side-channels) for advanced use cases.\n" |
|||
] |
|||
} |
|||
] |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue