您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
132 行
15 KiB
132 行
15 KiB
{
|
|
"cells": [
|
|
{
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Standard library\n",
  "import csv\n",
  "\n",
  "# Third-party: numerics and plotting\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
  "def read_reward(dirname, filename='Simple.csv',\n",
  "                column='Environment/Cumulative Reward'):\n",
  "    \"\"\"Read one numeric column from a CSV file inside ``dirname``.\n",
  "\n",
  "    Args:\n",
  "        dirname: Directory that contains the CSV file.\n",
  "        filename: Name of the CSV file inside ``dirname``.\n",
  "        column: Header of the column to read.\n",
  "\n",
  "    Returns:\n",
  "        A list of floats, one per data row, in file order.\n",
  "\n",
  "    Raises:\n",
  "        FileNotFoundError: If the CSV file does not exist.\n",
  "        KeyError: If ``column`` is not one of the CSV headers.\n",
  "        ValueError: If a value in ``column`` cannot be parsed as a float.\n",
  "    \"\"\"\n",
  "    import os.path  # local import keeps this cell self-contained\n",
  "\n",
  "    # Portable path construction instead of manual '/' concatenation.\n",
  "    path = os.path.join(dirname, filename)\n",
  "    with open(path, newline='') as csvfile:\n",
  "        return [float(row[column]) for row in csv.DictReader(csvfile)]"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
  "def smooth(scalars, weight: float):\n",
  "    \"\"\"Exponentially smooth a sequence of values.\n",
  "\n",
  "    Each output is ``weight * previous_output + (1 - weight) * point``,\n",
  "    with the running value seeded by the first input element.\n",
  "\n",
  "    Args:\n",
  "        scalars: Sequence of numbers (list or 1-D NumPy array).\n",
  "        weight: Smoothing factor between 0 and 1; larger values smooth more.\n",
  "\n",
  "    Returns:\n",
  "        A list of smoothed values, the same length as ``scalars``.\n",
  "        An empty input yields an empty list.\n",
  "    \"\"\"\n",
  "    # len() rather than truthiness: bool() of a NumPy array is ambiguous.\n",
  "    if len(scalars) == 0:\n",
  "        return []\n",
  "    last = scalars[0]  # seed with the first value (first timestep)\n",
  "    smoothed = []\n",
  "    for point in scalars:\n",
  "        last = last * weight + (1 - weight) * point\n",
  "        smoothed.append(last)\n",
  "    return smoothed"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Seed suffixes of the run directories (transfer_results/ppo_<obs>_s<seed>).\n",
  "seeds = list(range(5))\n",
  "# Observation-type labels that appear in those directory names.\n",
  "obs_types = ['long', 'longpre']"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 9,
 "metadata": {},
 "outputs": [],
 "source": [
  "# For each observation type, stack the per-seed reward curves into a\n",
  "# NumPy array with one row per seed (assumes every run logged the same\n",
  "# number of rows).\n",
  "rewards_ppo = {\n",
  "    obs: np.array([\n",
  "        read_reward('transfer_results/ppo_{}_s{}'.format(obs, s))\n",
  "        for s in seeds\n",
  "    ])\n",
  "    for obs in obs_types\n",
  "}"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Each rewards_ppo[obs] array has one row per seed: average across seeds\n",
  "# (axis=0) to get a single curve over the logged rows, then smooth it.\n",
  "for obs in obs_types:\n",
  "    mean_reward = np.mean(rewards_ppo[obs], axis=0)\n",
  "    xs = np.arange(len(mean_reward))\n",
  "    plt.plot(xs, smooth(mean_reward, 0.8), label=obs)\n",
  "plt.xlabel('Logged row')\n",
  "plt.ylabel('Cumulative reward (mean over seeds, smoothed)')\n",
  "plt.legend()\n",
  "plt.show()"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Fill rewards_model (not rewards_ppo) so the PPO results loaded above\n",
  "# are not overwritten.\n",
  "# NOTE(review): this directory pattern is identical to the PPO loader and\n",
  "# still points at the PPO runs; update it to the model-based runs'\n",
  "# directory names before comparing against rewards_ppo.\n",
  "rewards_model = {}\n",
  "for obs in obs_types:\n",
  "    rewards_model[obs] = np.array([\n",
  "        read_reward('transfer_results/ppo_{}_s{}'.format(obs, s))\n",
  "        for s in seeds\n",
  "    ])"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "mlagents-env",
|
|
"language": "python",
|
|
"name": "mlagents-env"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|