yanchaosun
4 年前
当前提交
8fc18e5d
共有 2 个文件被更改,包括 281 次插入 和 47 次删除
-
229ml-agents/mlagents/trainers/tests/encoder_plot.ipynb
-
99ml-agents/mlagents/trainers/tests/reward_plot.ipynb
229
ml-agents/mlagents/trainers/tests/encoder_plot.ipynb
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
|
|||
{ |
|||
"cells": [ |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": 1, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"import csv\n", |
|||
"import numpy as np\n", |
|||
"import matplotlib.pyplot as plt" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": 2, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"def read_reward(dirname):\n", |
|||
" reward = []\n", |
|||
" with open(dirname+'/Simple.csv', newline='') as csvfile:\n", |
|||
" reader = csv.DictReader(csvfile)\n", |
|||
" for row in reader:\n", |
|||
" reward.append(float(row['Environment/Cumulative Reward']))\n", |
|||
" return reward" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": 3, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"def smooth(scalars, weight: float): # Weight between 0 and 1\n", |
|||
" last = scalars[0] # First value in the plot (first timestep)\n", |
|||
" smoothed = list()\n", |
|||
" for point in scalars:\n", |
|||
" smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value\n", |
|||
" smoothed.append(smoothed_val) # Save it\n", |
|||
" last = smoothed_val # Anchor the last smoothed value\n", |
|||
"\n", |
|||
" return smoothed" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": 8, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"seeds = [0,1,2,3,4]\n", |
|||
"obs_types = [\"long\", \"longpre\"]" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": 9, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"rewards_ppo = {} \n", |
|||
"for obs in obs_types:\n", |
|||
" rewards_ppo[obs] = []\n", |
|||
" for s in seeds:\n", |
|||
" rewards_ppo[obs].append(read_reward(\"transfer_results/ppo_{}_s{}\".format(obs, s)))\n", |
|||
" rewards_ppo[obs] = np.array(rewards_ppo[obs])" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [] |
|||
} |
|||
], |
|||
"metadata": { |
|||
"kernelspec": { |
|||
"display_name": "mlagents-env", |
|||
"language": "python", |
|||
"name": "mlagents-env" |
|||
}, |
|||
"language_info": { |
|||
"codemirror_mode": { |
|||
"name": "ipython", |
|||
"version": 3 |
|||
}, |
|||
"file_extension": ".py", |
|||
"mimetype": "text/x-python", |
|||
"name": "python", |
|||
"nbconvert_exporter": "python", |
|||
"pygments_lexer": "ipython3", |
|||
"version": "3.7.6" |
|||
} |
|||
}, |
|||
"nbformat": 4, |
|||
"nbformat_minor": 4 |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue