{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Colab-UnityEnvironment-2-Train.ipynb",
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "pbVXrmEsLXDt"
},
"source": [
"# ML-Agents Q-Learning with GridWorld\n",
"<img src=\"https://github.com/Unity-Technologies/ml-agents/blob/release_2/docs/images/gridworld.png?raw=true\" align=\"middle\" width=\"435\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WNKTwHU3d2-l"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "htb-p1hSNX7D"
},
"source": [
"#@title Install Rendering Dependencies { display-mode: \"form\" }\n",
"#@markdown (You only need to run this code when using Colab's hosted runtime)\n",
"\n",
"from IPython.display import HTML, display\n",
"\n",
"def progress(value, max=100):\n",
"  return HTML(\"\"\"\n",
"      <progress\n",
"          value='{value}'\n",
"          max='{max}',\n",
"          style='width: 100%'\n",
"      >\n",
"          {value}\n",
"      </progress>\n",
"  \"\"\".format(value=value, max=max))\n",
"\n",
"pro_bar = display(progress(0, 100), display_id=True)\n",
"\n",
"try:\n",
"  import google.colab\n",
"  IN_COLAB = True\n",
"except ImportError:\n",
"  IN_COLAB = False\n",
"\n",
"if IN_COLAB:\n",
"  with open('frame-buffer', 'w') as writefile:\n",
"    writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n",
"XVFB=/usr/bin/Xvfb\n",
"XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n",
"PIDFILE=./frame-buffer.pid\n",
"case \"$1\" in\n",
"  start)\n",
"    echo -n \"Starting virtual X frame buffer: Xvfb\"\n",
"    /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n",
"    echo \".\"\n",
"    ;;\n",
"  stop)\n",
"    echo -n \"Stopping virtual X frame buffer: Xvfb\"\n",
"    /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n",
"    rm $PIDFILE\n",
"    echo \".\"\n",
"    ;;\n",
"  restart)\n",
"    $0 stop\n",
"    $0 start\n",
"    ;;\n",
"  *)\n",
"    echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n",
"    exit 1\n",
"esac\n",
"exit 0\n",
"  \"\"\")\n",
"  pro_bar.update(progress(5, 100))\n",
"  !apt-get install daemon >/dev/null 2>&1\n",
"  pro_bar.update(progress(10, 100))\n",
"  !apt-get install wget >/dev/null 2>&1\n",
"  pro_bar.update(progress(20, 100))\n",
"  !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n",
"  pro_bar.update(progress(30, 100))\n",
"  !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb >/dev/null 2>&1\n",
"  pro_bar.update(progress(40, 100))\n",
"  !dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb >/dev/null 2>&1\n",
"  pro_bar.update(progress(50, 100))\n",
"  !dpkg -i xvfb.deb >/dev/null 2>&1\n",
"  pro_bar.update(progress(70, 100))\n",
"  !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n",
"  pro_bar.update(progress(80, 100))\n",
"  !rm xvfb.deb\n",
"  pro_bar.update(progress(90, 100))\n",
"  !bash frame-buffer start\n",
"  import os\n",
"  os.environ[\"DISPLAY\"] = \":1\"\n",
"pro_bar.update(progress(100, 100))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pzj7wgapAcDs"
},
"source": [
"### Installing ml-agents"
]
},
{
"cell_type": "code",
"metadata": {
"id": "N8yfQqkbebQ5"
},
"source": [
"try:\n",
"  import mlagents\n",
"  print(\"ml-agents already installed\")\n",
"except ImportError:\n",
"  !pip install -q mlagents==0.25.1\n",
"  print(\"Installed ml-agents\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "jz81TWAkbuFY"
},
"source": [
"## Train the GridWorld Environment with Q-Learning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "29n3dt1Zx5ty"
},
"source": [
"### What is the GridWorld Environment"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZhVRfdoyPmv"
},
"source": [
"The [GridWorld](https://github.com/Unity-Technologies/ml-agents/blob/release_2/docs/Learning-Environment-Examples.md#gridworld) Environment is a simple Unity visual environment. The Agent is a blue square in a 3x3 grid that is trying to reach a green __`+`__ while avoiding a red __`x`__.\n",
"\n",
"The observation is an image obtained by a camera on top of the grid.\n",
"\n",
"The Action can be one of 5:\n",
" - Do not move\n",
" - Move up\n",
" - Move down\n",
" - Move right\n",
" - Move left\n",
"\n",
"The Agent receives a reward of _1.0_ if it reaches the green __`+`__, a penalty of _-1.0_ if it touches the red __`x`__, and a penalty of _-0.01_ at every step (to force the Agent to solve the task as fast as possible).\n",
"\n",
"__Note__: There are 9 Agents in the environment at once, each in its own grid. This allows for faster data collection.\n",
"\n",
"\n"
]
},
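{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, you can optionally sanity-check the environment. The cell below is a minimal sketch (not part of the original training flow): it opens GridWorld from the same `default_registry` entry used later in this Notebook and prints the shape of the visual observation batch and the discrete action branches, assuming the `spec.action_spec` attributes exposed by the `mlagents_envs` version pinned above. The variable names (`inspect_env`, etc.) are only for this example."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optional sketch: inspect the GridWorld specs before training.\n",
"from mlagents_envs.registry import default_registry\n",
"\n",
"inspect_env = default_registry[\"GridWorld\"].make()\n",
"inspect_env.reset()\n",
"behavior_name = list(inspect_env.behavior_specs)[0]\n",
"spec = inspect_env.behavior_specs[behavior_name]\n",
"decision_steps, _ = inspect_env.get_steps(behavior_name)\n",
"print(\"Behavior name:\", behavior_name)\n",
"# One visual observation per Agent: (num_agents, height, width, channels)\n",
"print(\"Visual observation batch shape:\", decision_steps.obs[0].shape)\n",
"print(\"Discrete action branches:\", spec.action_spec.discrete_branches)\n",
"inspect_env.close()"
],
"execution_count": null,
"outputs": []
},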
{
"cell_type": "markdown",
"metadata": {
"id": "4Gt-ZydJyJWD"
},
"source": [
"### The Q-Learning Algorithm\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KA1qOgfq0Xdv"
},
"source": [
"In this Notebook, we will implement a very simple Q-Learning algorithm. We will use [PyTorch](https://pytorch.org/) to do so.\n",
"\n",
"Below is the code to create the neural network we will use in the Notebook."
]
},
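{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick reference (this note only restates what the training code further down computes), the Q-Network is regressed toward the one-step Bellman target\n",
"\n",
"$$y = r + \\gamma \\, (1 - d) \\, \\max_{a'} Q(s', a')$$\n",
"\n",
"where $r$ is the reward, $d$ the done flag, $\\gamma$ the discount factor (0.9 in this Notebook) and $s'$ the next observation. The loss is the mean squared error between $Q(s, a)$ for the action that was taken and the target $y$."
]
},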
{
"cell_type": "code",
"metadata": {
"id": "q79rUp_Sx6A_"
},
"source": [
"import torch\n",
"from typing import Tuple\n",
"from math import floor\n",
"\n",
"\n",
"class VisualQNetwork(torch.nn.Module):\n",
"  def __init__(\n",
"    self,\n",
"    input_shape: Tuple[int, int, int],\n",
"    encoding_size: int,\n",
"    output_size: int\n",
"  ):\n",
"    \"\"\"\n",
"    Creates a neural network that takes as input a batch of images (3\n",
"    dimensional tensors) and outputs a batch of action values (1 dimensional\n",
"    tensors)\n",
"    \"\"\"\n",
"    super(VisualQNetwork, self).__init__()\n",
"    height = input_shape[0]\n",
"    width = input_shape[1]\n",
"    initial_channels = input_shape[2]\n",
"    conv_1_hw = self.conv_output_shape((height, width), 8, 4)\n",
"    conv_2_hw = self.conv_output_shape(conv_1_hw, 4, 2)\n",
"    self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32\n",
"    self.conv1 = torch.nn.Conv2d(initial_channels, 16, [8, 8], [4, 4])\n",
"    self.conv2 = torch.nn.Conv2d(16, 32, [4, 4], [2, 2])\n",
"    self.dense1 = torch.nn.Linear(self.final_flat, encoding_size)\n",
"    self.dense2 = torch.nn.Linear(encoding_size, output_size)\n",
"\n",
"  def forward(self, visual_obs: torch.tensor):\n",
"    visual_obs = visual_obs.permute(0, 3, 1, 2)\n",
"    conv_1 = torch.relu(self.conv1(visual_obs))\n",
"    conv_2 = torch.relu(self.conv2(conv_1))\n",
"    hidden = self.dense1(conv_2.reshape([-1, self.final_flat]))\n",
"    hidden = torch.relu(hidden)\n",
"    hidden = self.dense2(hidden)\n",
"    return hidden\n",
"\n",
"  @staticmethod\n",
"  def conv_output_shape(\n",
"    h_w: Tuple[int, int],\n",
"    kernel_size: int = 1,\n",
"    stride: int = 1,\n",
"    pad: int = 0,\n",
"    dilation: int = 1,\n",
"  ):\n",
"    \"\"\"\n",
"    Computes the height and width of the output of a convolution layer.\n",
"    \"\"\"\n",
"    h = floor(\n",
"      ((h_w[0] + (2 * pad) - (dilation * (kernel_size - 1)) - 1) / stride) + 1\n",
"    )\n",
"    w = floor(\n",
"      ((h_w[1] + (2 * pad) - (dilation * (kernel_size - 1)) - 1) / stride) + 1\n",
"    )\n",
"    return h, w\n"
],
"execution_count": null,
"outputs": []
},
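{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional check (a sketch, not part of the original Notebook), the cell below instantiates the network with the same arguments used in the training section (`(64, 84, 3)` observations, an encoding size of 126 and 5 actions) and runs a dummy batch through it to confirm the output has one value per action. The names `test_q_net` and `dummy_obs` are only for this example."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical smoke test for VisualQNetwork.\n",
"test_q_net = VisualQNetwork((64, 84, 3), 126, 5)\n",
"# A dummy batch of 2 observations, shaped (batch, height, width, channels)\n",
"dummy_obs = torch.zeros((2, 64, 84, 3))\n",
"print(test_q_net(dummy_obs).shape)  # expected: torch.Size([2, 5])"
],
"execution_count": null,
"outputs": []
},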
{
"cell_type": "markdown",
"metadata": {
"id": "EZoaEBAo2L0F"
},
"source": [
"We will now create a few classes to help us store the data we will use to train the Q-Learning algorithm."
]
},
{
"cell_type": "code",
"metadata": {
"id": "L772fe2q39DO"
},
"source": [
"import numpy as np\n",
"from typing import NamedTuple, List\n",
"\n",
"\n",
"class Experience(NamedTuple):\n",
"  \"\"\"\n",
"  An experience contains the data of one Agent transition.\n",
"  - Observation\n",
"  - Action\n",
"  - Reward\n",
"  - Done flag\n",
"  - Next Observation\n",
"  \"\"\"\n",
"\n",
"  obs: np.ndarray\n",
"  action: np.ndarray\n",
"  reward: float\n",
"  done: bool\n",
"  next_obs: np.ndarray\n",
"\n",
"# A Trajectory is an ordered sequence of Experiences\n",
"Trajectory = List[Experience]\n",
"\n",
"# A Buffer is an unordered list of Experiences from multiple Trajectories\n",
"Buffer = List[Experience]"
],
"execution_count": null,
"outputs": []
},
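{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration only, a single `Experience` could be built by hand as in the sketch below. The observation shape matches how the network is instantiated later and the reward matches the per-step penalty described above; the values themselves (and the names `dummy_obs`, `dummy_exp`) are made up for this example."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical example of one Experience (dummy values).\n",
"dummy_obs = np.zeros((64, 84, 3), dtype=np.float32)\n",
"dummy_exp = Experience(\n",
"  obs=dummy_obs,\n",
"  action=np.array([2]),  # one of the 5 discrete actions\n",
"  reward=-0.01,  # the per-step penalty\n",
"  done=False,\n",
"  next_obs=dummy_obs,\n",
")\n",
"print(dummy_exp.reward, dummy_exp.done)"
],
"execution_count": null,
"outputs": []
},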
{
"cell_type": "markdown",
"metadata": {
"id": "6HsM1d5I3_Tj"
},
"source": [
"Now, we can create our trainer class. The role of this trainer is to collect data from the Environment according to a Policy, and then train the Q-Network with that data."
]
},
{
"cell_type": "code",
"metadata": {
"id": "KkzBoRJCb18t"
},
"source": [
"from mlagents_envs.environment import ActionTuple, BaseEnv\n",
"from typing import Dict\n",
"import random\n",
"\n",
"\n",
"class Trainer:\n",
"  @staticmethod\n",
"  def generate_trajectories(\n",
"    env: BaseEnv, q_net: VisualQNetwork, buffer_size: int, epsilon: float\n",
"  ):\n",
"    \"\"\"\n",
"    Given a Unity Environment and a Q-Network, this method will generate a\n",
"    buffer of Experiences obtained by running the Environment with the Policy\n",
"    derived from the Q-Network.\n",
"    :param env: The UnityEnvironment used.\n",
"    :param q_net: The Q-Network used to collect the data.\n",
"    :param buffer_size: The minimum size of the buffer this method will return.\n",
"    :param epsilon: Will add a random normal variable with standard deviation\n",
"      epsilon to the value heads of the Q-Network to encourage exploration.\n",
"    :returns: a Tuple containing the created buffer and the average cumulative\n",
"      reward the Agents obtained.\n",
"    \"\"\"\n",
"    # Create an empty Buffer\n",
"    buffer: Buffer = []\n",
"\n",
"    # Reset the environment\n",
"    env.reset()\n",
"    # Read and store the Behavior Name of the Environment\n",
"    behavior_name = list(env.behavior_specs)[0]\n",
"    # Read and store the Behavior Specs of the Environment\n",
"    spec = env.behavior_specs[behavior_name]\n",
"\n",
"    # Create a Mapping from AgentId to Trajectories. This will help us create\n",
"    # trajectories for each Agent\n",
"    dict_trajectories_from_agent: Dict[int, Trajectory] = {}\n",
"    # Create a Mapping from AgentId to the last observation of the Agent\n",
"    dict_last_obs_from_agent: Dict[int, np.ndarray] = {}\n",
"    # Create a Mapping from AgentId to the last action of the Agent\n",
"    dict_last_action_from_agent: Dict[int, np.ndarray] = {}\n",
"    # Create a Mapping from AgentId to cumulative reward (Only for reporting)\n",
"    dict_cumulative_reward_from_agent: Dict[int, float] = {}\n",
"    # Create a list to store the cumulative rewards obtained so far\n",
"    cumulative_rewards: List[float] = []\n",
"\n",
"    while len(buffer) < buffer_size:  # While not enough data in the buffer\n",
"      # Get the Decision Steps and Terminal Steps of the Agents\n",
"      decision_steps, terminal_steps = env.get_steps(behavior_name)\n",
"\n",
"      # For all Agents with a Terminal Step:\n",
"      for agent_id_terminated in terminal_steps:\n",
"        # Create its last experience (it is last because the Agent terminated)\n",
"        last_experience = Experience(\n",
"          obs=dict_last_obs_from_agent[agent_id_terminated].copy(),\n",
"          reward=terminal_steps[agent_id_terminated].reward,\n",
"          done=not terminal_steps[agent_id_terminated].interrupted,\n",
"          action=dict_last_action_from_agent[agent_id_terminated].copy(),\n",
"          next_obs=terminal_steps[agent_id_terminated].obs[0],\n",
"        )\n",
"        # Clear its last observation and action (Since the trajectory is over)\n",
"        dict_last_obs_from_agent.pop(agent_id_terminated)\n",
"        dict_last_action_from_agent.pop(agent_id_terminated)\n",
"        # Report the cumulative reward\n",
"        cumulative_reward = (\n",
"          dict_cumulative_reward_from_agent.pop(agent_id_terminated)\n",
"          + terminal_steps[agent_id_terminated].reward\n",
"        )\n",
"        cumulative_rewards.append(cumulative_reward)\n",
"        # Add the Trajectory and the last experience to the buffer\n",
"        buffer.extend(dict_trajectories_from_agent.pop(agent_id_terminated))\n",
"        buffer.append(last_experience)\n",
"\n",
"      # For all Agents with a Decision Step:\n",
"      for agent_id_decisions in decision_steps:\n",
"        # If the Agent does not have a Trajectory, create an empty one\n",
"        if agent_id_decisions not in dict_trajectories_from_agent:\n",
"          dict_trajectories_from_agent[agent_id_decisions] = []\n",
"          dict_cumulative_reward_from_agent[agent_id_decisions] = 0\n",
"\n",
"        # If the Agent requesting a decision has a \"last observation\"\n",
"        if agent_id_decisions in dict_last_obs_from_agent:\n",
"          # Create an Experience from the last observation and the Decision Step\n",
"          exp = Experience(\n",
"            obs=dict_last_obs_from_agent[agent_id_decisions].copy(),\n",
"            reward=decision_steps[agent_id_decisions].reward,\n",
"            done=False,\n",
"            action=dict_last_action_from_agent[agent_id_decisions].copy(),\n",
"            next_obs=decision_steps[agent_id_decisions].obs[0],\n",
"          )\n",
"          # Update the Trajectory of the Agent and its cumulative reward\n",
"          dict_trajectories_from_agent[agent_id_decisions].append(exp)\n",
"          dict_cumulative_reward_from_agent[agent_id_decisions] += (\n",
"            decision_steps[agent_id_decisions].reward\n",
"          )\n",
"        # Store the observation as the new \"last observation\"\n",
"        dict_last_obs_from_agent[agent_id_decisions] = (\n",
"          decision_steps[agent_id_decisions].obs[0]\n",
"        )\n",
"\n",
"      # Generate an action for all the Agents that requested a decision\n",
"      # Compute the values for each action given the observation\n",
"      actions_values = (\n",
"        q_net(torch.from_numpy(decision_steps.obs[0])).detach().numpy()\n",
"      )\n",
"      # Add some noise with epsilon to the values\n",
"      actions_values += epsilon * (\n",
"        np.random.randn(actions_values.shape[0], actions_values.shape[1])\n",
"      ).astype(np.float32)\n",
"      # Pick the best action using argmax\n",
"      actions = np.argmax(actions_values, axis=1)\n",
"      actions.resize((len(decision_steps), 1))\n",
"      # Store the action that was picked; it will be put in the trajectory later\n",
"      for agent_index, agent_id in enumerate(decision_steps.agent_id):\n",
"        dict_last_action_from_agent[agent_id] = actions[agent_index]\n",
"\n",
"      # Set the actions in the environment\n",
"      # Unity Environments expect ActionTuple instances.\n",
"      action_tuple = ActionTuple()\n",
"      action_tuple.add_discrete(actions)\n",
"      env.set_actions(behavior_name, action_tuple)\n",
"      # Perform a step in the simulation\n",
"      env.step()\n",
"    return buffer, np.mean(cumulative_rewards)\n",
"\n",
" @staticmethod\n",
|
|
" def update_q_net(\n",
|
|
" q_net: VisualQNetwork, \n",
|
|
" optimizer: torch.optim, \n",
|
|
" buffer: Buffer, \n",
|
|
" action_size: int\n",
|
|
" ):\n",
|
|
" \"\"\"\n",
|
|
" Performs an update of the Q-Network using the provided optimizer and buffer\n",
|
|
" \"\"\"\n",
|
|
" BATCH_SIZE = 1000\n",
|
|
" NUM_EPOCH = 3\n",
|
|
" GAMMA = 0.9\n",
|
|
" batch_size = min(len(buffer), BATCH_SIZE)\n",
|
|
" random.shuffle(buffer)\n",
|
|
" # Split the buffer into batches\n",
|
|
" batches = [\n",
|
|
" buffer[batch_size * start : batch_size * (start + 1)]\n",
|
|
" for start in range(int(len(buffer) / batch_size))\n",
|
|
" ]\n",
|
|
" for _ in range(NUM_EPOCH):\n",
|
|
" for batch in batches:\n",
|
|
" # Create the Tensors that will be fed in the network\n",
|
|
" obs = torch.from_numpy(np.stack([ex.obs for ex in batch]))\n",
|
|
" reward = torch.from_numpy(\n",
|
|
" np.array([ex.reward for ex in batch], dtype=np.float32).reshape(-1, 1)\n",
|
|
" )\n",
|
|
" done = torch.from_numpy(\n",
|
|
" np.array([ex.done for ex in batch], dtype=np.float32).reshape(-1, 1)\n",
|
|
" )\n",
|
|
" action = torch.from_numpy(np.stack([ex.action for ex in batch]))\n",
|
|
" next_obs = torch.from_numpy(np.stack([ex.next_obs for ex in batch]))\n",
|
|
"\n",
|
|
" # Use the Bellman equation to update the Q-Network\n",
|
|
" target = (\n",
|
|
" reward\n",
|
|
" + (1.0 - done)\n",
|
|
" * GAMMA\n",
|
|
" * torch.max(q_net(next_obs).detach(), dim=1, keepdim=True).values\n",
|
|
" )\n",
|
|
" mask = torch.zeros((len(batch), action_size))\n",
|
|
" mask.scatter_(1, action, 1)\n",
|
|
" prediction = torch.sum(qnet(obs) * mask, dim=1, keepdim=True)\n",
|
|
" criterion = torch.nn.MSELoss()\n",
|
|
" loss = criterion(prediction, target)\n",
|
|
"\n",
|
|
" # Perform the backpropagation\n",
|
|
" optimizer.zero_grad()\n",
|
|
" loss.backward()\n",
|
|
" optimizer.step()\n"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
"cell_type": "markdown",
"metadata": {
"id": "vcU4ZMAEWCvX"
},
"source": [
"### Run Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_lIHijQfbYjh"
},
"source": [
"# -----------------\n",
"# This code is used to close an env that might not have been closed before\n",
"try:\n",
"  env.close()\n",
"except:\n",
"  pass\n",
"# -----------------\n",
"\n",
"from mlagents_envs.registry import default_registry\n",
"from mlagents_envs.environment import UnityEnvironment\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# Create the GridWorld Environment from the registry\n",
"env = default_registry[\"GridWorld\"].make()\n",
"print(\"GridWorld environment created.\")\n",
"\n",
"# Create a new Q-Network.\n",
"qnet = VisualQNetwork((64, 84, 3), 126, 5)\n",
"\n",
"experiences: Buffer = []\n",
"optim = torch.optim.Adam(qnet.parameters(), lr=0.001)\n",
"\n",
"cumulative_rewards: List[float] = []\n",
"\n",
"# The number of training steps that will be performed\n",
"NUM_TRAINING_STEPS = 70\n",
"# The number of experiences to collect per training step\n",
"NUM_NEW_EXP = 1000\n",
"# The maximum size of the Buffer\n",
"BUFFER_SIZE = 10000\n",
"\n",
"for n in range(NUM_TRAINING_STEPS):\n",
"  new_exp, _ = Trainer.generate_trajectories(env, qnet, NUM_NEW_EXP, epsilon=0.1)\n",
"  random.shuffle(experiences)\n",
"  if len(experiences) > BUFFER_SIZE:\n",
"    experiences = experiences[:BUFFER_SIZE]\n",
"  experiences.extend(new_exp)\n",
"  Trainer.update_q_net(qnet, optim, experiences, 5)\n",
"  _, rewards = Trainer.generate_trajectories(env, qnet, 100, epsilon=0)\n",
"  cumulative_rewards.append(rewards)\n",
"  print(\"Training step \", n+1, \"\\treward \", rewards)\n",
"\n",
"\n",
"env.close()\n",
"\n",
"# Show the training graph\n",
"plt.plot(range(NUM_TRAINING_STEPS), cumulative_rewards)\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2ihb_gmYLUsH"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}