您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
494 行
15 KiB
494 行
15 KiB
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "Colab-UnityEnvironment-1-Run.ipynb",
|
|
"private_outputs": true,
|
|
"provenance": [],
|
|
"collapsed_sections": [],
|
|
"toc_visible": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "pbVXrmEsLXDt"
|
|
},
|
|
"source": [
|
|
"# ML-Agents Open a UnityEnvironment\n",
|
|
"<img src=\"https://github.com/Unity-Technologies/ml-agents/blob/release_1/docs/images/image-banner.png?raw=true\" align=\"middle\" width=\"435\"/>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WNKTwHU3d2-l"
|
|
},
|
|
"source": [
|
|
"## Setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "htb-p1hSNX7D"
|
|
},
|
|
"source": [
|
|
"#@title Install Rendering Dependencies { display-mode: \"form\" }\n",
|
|
"#@markdown (You only need to run this code when using Colab's hosted runtime)\n",
|
|
"\n",
|
|
"from IPython.display import HTML, display\n",
|
|
"\n",
|
|
"def progress(value, max=100):\n",
|
|
"    \"\"\"Return an HTML <progress> bar filled to `value` out of `max`.\"\"\"\n",
|
|
"    return HTML(\"\"\"\n",
|
|
"        <progress\n",
|
|
"            value='{value}'\n",
|
|
"            max='{max}'\n",
|
|
"            style='width: 100%'\n",
|
|
"        >\n",
|
|
"            {value}\n",
|
|
"        </progress>\n",
|
|
"    \"\"\".format(value=value, max=max))\n",
|
|
"\n",
|
|
"pro_bar = display(progress(0, 100), display_id=True)\n",
|
|
"\n",
|
|
"try:\n",
|
|
"    import google.colab\n",
|
|
"    IN_COLAB = True\n",
|
|
"except ImportError:\n",
|
|
"    IN_COLAB = False\n",
|
|
"\n",
|
|
"if IN_COLAB:\n",
|
|
"    # Write a start/stop/restart script for an Xvfb virtual X display on :1.\n",
|
|
"    with open('frame-buffer', 'w') as writefile:\n",
|
|
"        writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n",
|
|
"XVFB=/usr/bin/Xvfb\n",
|
|
"XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n",
|
|
"PIDFILE=./frame-buffer.pid\n",
|
|
"case \"$1\" in\n",
|
|
"  start)\n",
|
|
"    echo -n \"Starting virtual X frame buffer: Xvfb\"\n",
|
|
"    /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n",
|
|
"    echo \".\"\n",
|
|
"    ;;\n",
|
|
"  stop)\n",
|
|
"    echo -n \"Stopping virtual X frame buffer: Xvfb\"\n",
|
|
"    /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n",
|
|
"    rm $PIDFILE\n",
|
|
"    echo \".\"\n",
|
|
"    ;;\n",
|
|
"  restart)\n",
|
|
"    $0 stop\n",
|
|
"    $0 start\n",
|
|
"    ;;\n",
|
|
"  *)\n",
|
|
"    echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n",
|
|
"    exit 1\n",
|
|
"esac\n",
|
|
"exit 0\n",
|
|
"        \"\"\")\n",
|
|
"    pro_bar.update(progress(5, 100))\n",
|
|
"    # -y: nobody can answer apt-get's confirmation prompt, since its output is discarded.\n",
|
|
"    !apt-get install -y daemon >/dev/null 2>&1\n",
|
|
"    pro_bar.update(progress(10, 100))\n",
|
|
"    !apt-get install -y wget >/dev/null 2>&1\n",
|
|
"    pro_bar.update(progress(20, 100))\n",
|
|
"    !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n",
|
|
"    pro_bar.update(progress(30, 100))\n",
|
|
"    !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb >/dev/null 2>&1\n",
|
|
"    pro_bar.update(progress(40, 100))\n",
|
|
"    !dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n",
|
|
"    pro_bar.update(progress(50, 100))\n",
|
|
"    !dpkg -i xvfb.deb >/dev/null 2>&1\n",
|
|
"    pro_bar.update(progress(70, 100))\n",
|
|
"    !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n",
|
|
"    pro_bar.update(progress(80, 100))\n",
|
|
"    !rm xvfb.deb\n",
|
|
"    pro_bar.update(progress(90, 100))\n",
|
|
"    !bash frame-buffer start\n",
|
|
"    import os\n",
|
|
"    # Point X clients at the virtual display started on :1 above.\n",
|
|
"    os.environ[\"DISPLAY\"] = \":1\"\n",
|
|
"pro_bar.update(progress(100, 100))"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Pzj7wgapAcDs"
|
|
},
|
|
"source": [
|
|
"### Installing ml-agents"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "N8yfQqkbebQ5"
|
|
},
|
|
"source": [
|
|
"try:\n",
|
|
"    import mlagents\n",
|
|
"    print(\"ml-agents already installed\")\n",
|
|
"except ImportError:\n",
|
|
"    # %pip (unlike !pip) installs into the environment of the running kernel.\n",
|
|
"    %pip install -q mlagents==0.25.1\n",
|
|
"    print(\"Installed ml-agents\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "_u74YhSmW6gD"
|
|
},
|
|
"source": [
|
|
"## Run the Environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "DpZPbRvRuLZv"
|
|
},
|
|
"source": [
|
|
"#@title Select Environment { display-mode: \"form\" }\n",
|
|
"# Registry name of the environment; a later cell loads it with default_registry[env_id].make()\n",
|
|
"env_id = \"GridWorld\" #@param ['Basic', '3DBall', '3DBallHard', 'GridWorld', 'Hallway', 'VisualHallway', 'CrawlerDynamicTarget', 'CrawlerStaticTarget', 'Bouncer', 'SoccerTwos', 'PushBlock', 'VisualPushBlock', 'WallJump', 'Tennis', 'Reacher', 'Pyramids', 'VisualPyramids', 'Walker', 'FoodCollector', 'VisualFoodCollector', 'StrikersVsGoalie', 'WormStaticTarget', 'WormDynamicTarget']\n"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "P-r_cB2rqp5x"
|
|
},
|
|
"source": [
|
|
"### Start Environment from the registry"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "YSf-WhxbqtLw"
|
|
},
|
|
"source": [
|
|
"# -----------------\n",
|
|
"# Close an env left open by an earlier run of this cell, if there is one.\n",
|
|
"# Best effort: `env` may not be defined yet (NameError) or may fail to close,\n",
|
|
"# but KeyboardInterrupt and SystemExit are no longer swallowed.\n",
|
|
"try:\n",
|
|
"    env.close()\n",
|
|
"except Exception:\n",
|
|
"    pass\n",
|
|
"# -----------------\n",
|
|
"\n",
|
|
"from mlagents_envs.registry import default_registry\n",
|
|
"\n",
|
|
"env = default_registry[env_id].make()"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "h1lIx3_l24OP"
|
|
},
|
|
"source": [
|
|
"### Reset the environment\n",
|
|
"To reset the environment, simply call `env.reset()`. This method takes no arguments and returns nothing, but it sends a signal to the simulation to reset."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "dhtl0mpeqxYi"
|
|
},
|
|
"source": [
|
|
"# Send a reset signal to the simulation (takes no arguments, returns nothing)\n",
|
|
"env.reset()"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "k1rwnVq2qyoO"
|
|
},
|
|
"source": [
|
|
"### Behavior Specs\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "TrD0rSv92T8A"
|
|
},
|
|
"source": [
|
|
"#### Get the Behavior Specs from the Environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "a7KatdThq7OV"
|
|
},
|
|
"source": [
|
|
"# Only the first Behavior exposed by the environment is used from here on\n",
|
|
"behavior_name = [*env.behavior_specs][0]\n",
|
|
"print(f\"Name of the behavior : {behavior_name}\")\n",
|
|
"spec = env.behavior_specs[behavior_name]"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "m1L8DHADrAbe"
|
|
},
|
|
"source": [
|
|
"#### Get the Observation Space from the Behavior Specs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "PqDTV5mSrJF5"
|
|
},
|
|
"source": [
|
|
"# How many observations does each Agent receive?\n",
|
|
"print(\"Number of observations : \", len(spec.observation_specs))\n",
|
|
"\n",
|
|
"# A visual observation is one with 3 dimensions: Height, Width and number of channels\n",
|
|
"vis_obs = 3 in (len(obs_spec.shape) for obs_spec in spec.observation_specs)\n",
|
|
"print(\"Is there a visual observation ?\", vis_obs)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "yVLN_wbG1G5-"
|
|
},
|
|
"source": [
|
|
"#### Get the Action Space from the Behavior Specs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "M9zk1-az1L-G"
|
|
},
|
|
"source": [
|
|
"# Is the Action continuous, discrete, or both (hybrid)?\n",
|
|
"# Test each part by its size: ActionSpec.is_discrete() describes a purely\n",
|
|
"# discrete spec, so it is False when continuous actions are also present.\n",
|
|
"if spec.action_spec.continuous_size > 0:\n",
|
|
"    print(f\"There are {spec.action_spec.continuous_size} continuous actions\")\n",
|
|
"\n",
|
|
"# For discrete actions: how many different options does each action (branch) have?\n",
|
|
"if spec.action_spec.discrete_size > 0:\n",
|
|
"    print(f\"There are {spec.action_spec.discrete_size} discrete actions\")\n",
|
|
"    for action, branch_size in enumerate(spec.action_spec.discrete_branches):\n",
|
|
"        print(f\"Action number {action} has {branch_size} different options\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "3cX07SGw22Lm"
|
|
},
|
|
"source": [
|
|
"### Stepping the environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xO5p0s0prfsQ"
|
|
},
|
|
"source": [
|
|
"#### Get the steps from the Environment\n",
|
|
"You can do this with the `env.get_steps(behavior_name)` method. If there are multiple behaviors in the Environment, you can call this method once with each behavior's name.\n",
|
|
"_Note:_ This will not move the simulation forward."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "ePZtcHXUrjyf"
|
|
},
|
|
"source": [
|
|
"# Read the current decision and terminal steps of the behavior; this does not step the simulation\n",
|
|
"decision_steps, terminal_steps = env.get_steps(behavior_name)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "f-Oj3ix530mx"
|
|
},
|
|
"source": [
|
|
"#### Set actions for each behavior\n",
|
|
"You can set the actions for the Agents of a Behavior by calling `env.set_actions()`. You will need to specify the behavior name and pass a tensor of dimension 2. The first dimension of the action must be equal to the number of Agents that requested a decision during the step."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "KB-nxfbw337g"
|
|
},
|
|
"source": [
|
|
"# One empty action for each of the len(decision_steps) Agents that requested a decision\n",
|
|
"env.set_actions(behavior_name, spec.action_spec.empty_action(len(decision_steps)))"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "MQCybRs84cmq"
|
|
},
|
|
"source": [
|
|
"#### Move the simulation forward\n",
|
|
"Call `env.step()` to move the simulation forward. The simulation will progress until an Agent requests a decision or terminates."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "nl3K40ZR4bh2"
|
|
},
|
|
"source": [
|
|
"# Advance the simulation until an Agent requests a decision or terminates\n",
|
|
"env.step()"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "i9gdextn2vJy"
|
|
},
|
|
"source": [
|
|
"### Observations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "iAMqnnddr8Xo"
|
|
},
|
|
"source": [
|
|
"#### Show the observations for one of the Agents\n",
|
|
"`DecisionSteps.obs` is a tuple containing all of the observations for all of the Agents with the provided Behavior name.\n",
|
|
"Each value in the tuple is an observation tensor containing the observation data for all of the agents."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "OJpta61TsBiO"
|
|
},
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"%matplotlib inline\n",
|
|
"\n",
|
|
"# decision_steps holds the Agents that requested a decision at the last `get_steps`\n",
|
|
"# call; if there are none, there is no first Agent whose observations can be shown.\n",
|
|
"if len(decision_steps) == 0:\n",
|
|
"    print(\"No Agent requested a decision, so there are no observations to show\")\n",
|
|
"else:\n",
|
|
"    for index, obs_spec in enumerate(spec.observation_specs):\n",
|
|
"        if len(obs_spec.shape) == 3:\n",
|
|
"            print(\"Here is the first visual observation\")\n",
|
|
"            image = decision_steps.obs[index][0, :, :, :]\n",
|
|
"            # Visual observations are (Height, Width, Channels); imshow rejects a\n",
|
|
"            # single-channel (H, W, 1) array, so draw it as a 2D grayscale image.\n",
|
|
"            if image.shape[-1] == 1:\n",
|
|
"                plt.imshow(image[:, :, 0], cmap=\"gray\")\n",
|
|
"            else:\n",
|
|
"                plt.imshow(image)\n",
|
|
"            plt.show()\n",
|
|
"\n",
|
|
"    for index, obs_spec in enumerate(spec.observation_specs):\n",
|
|
"        if len(obs_spec.shape) == 1:\n",
|
|
"            print(\"First vector observations : \", decision_steps.obs[index][0, :])"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "y60u21sys8kA"
|
|
},
|
|
"source": [
|
|
"### Run the Environment for a few episodes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "a2uQUsoMtIUK"
|
|
},
|
|
"source": [
|
|
"# Play 3 episodes with random actions, reporting the total reward of one tracked agent per episode\n",
|
|
"for episode in range(3):\n",
|
|
"    env.reset()\n",
|
|
"    decision_steps, terminal_steps = env.get_steps(behavior_name)\n",
|
|
"    tracked_agent = -1  # -1 indicates not yet tracking\n",
|
|
"    done = False  # For the tracked_agent\n",
|
|
"    episode_rewards = 0  # For the tracked_agent\n",
|
|
"    while not done:\n",
|
|
"        # Track the first agent we see if not tracking\n",
|
|
"        # Note : len(decision_steps) = [number of agents that requested a decision]\n",
|
|
"        if tracked_agent == -1 and len(decision_steps) >= 1:\n",
|
|
"            tracked_agent = decision_steps.agent_id[0]\n",
|
|
"\n",
|
|
"        # Generate an action for all agents\n",
|
|
"        action = spec.action_spec.random_action(len(decision_steps))\n",
|
|
"\n",
|
|
"        # Set the actions\n",
|
|
"        env.set_actions(behavior_name, action)\n",
|
|
"\n",
|
|
"        # Move the simulation forward\n",
|
|
"        env.step()\n",
|
|
"\n",
|
|
"        # Get the new simulation results\n",
|
|
"        decision_steps, terminal_steps = env.get_steps(behavior_name)\n",
|
|
"        # Check termination first. If the agent shows up in both terminal_steps and\n",
|
|
"        # decision_steps (it ended this episode and already requested a decision\n",
|
|
"        # for the next one), only the terminal reward belongs to this episode.\n",
|
|
"        if tracked_agent in terminal_steps:  # The agent terminated its episode\n",
|
|
"            episode_rewards += terminal_steps[tracked_agent].reward\n",
|
|
"            done = True\n",
|
|
"        elif tracked_agent in decision_steps:  # The agent requested a decision\n",
|
|
"            episode_rewards += decision_steps[tracked_agent].reward\n",
|
|
"    print(f\"Total rewards for episode {episode} is {episode_rewards}\")\n"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "R-3grXNEtJPa"
|
|
},
|
|
"source": [
|
|
"### Close the Environment to free the port it is using"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "vdWG6_SqtNtv"
|
|
},
|
|
"source": [
|
|
"# Shut down the environment to free the port it is using\n",
|
|
"env.close()\n",
|
|
"print(\"Closed environment\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
}
|