浏览代码

plotting

/develop/bisim-sac-transfer
yanchaosun 5 年前
当前提交
8fc18e5d
共有 2 个文件被更改,包括 281 次插入47 次删除
  1. 229
      ml-agents/mlagents/trainers/tests/encoder_plot.ipynb
  2. 99
      ml-agents/mlagents/trainers/tests/reward_plot.ipynb

229
ml-agents/mlagents/trainers/tests/encoder_plot.ipynb
文件差异内容过多而无法显示
查看文件

99
ml-agents/mlagents/trainers/tests/reward_plot.ipynb


{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def read_reward(dirname):\n",
" reward = []\n",
" with open(dirname+'/Simple.csv', newline='') as csvfile:\n",
" reader = csv.DictReader(csvfile)\n",
" for row in reader:\n",
" reward.append(float(row['Environment/Cumulative Reward']))\n",
" return reward"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def smooth(scalars, weight: float): # Weight between 0 and 1\n",
" last = scalars[0] # First value in the plot (first timestep)\n",
" smoothed = list()\n",
" for point in scalars:\n",
" smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value\n",
" smoothed.append(smoothed_val) # Save it\n",
" last = smoothed_val # Anchor the last smoothed value\n",
"\n",
" return smoothed"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"seeds = [0,1,2,3,4]\n",
"obs_types = [\"long\", \"longpre\"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"rewards_ppo = {} \n",
"for obs in obs_types:\n",
" rewards_ppo[obs] = []\n",
" for s in seeds:\n",
" rewards_ppo[obs].append(read_reward(\"transfer_results/ppo_{}_s{}\".format(obs, s)))\n",
" rewards_ppo[obs] = np.array(rewards_ppo[obs])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mlagents-env",
"language": "python",
"name": "mlagents-env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
正在加载...
取消
保存