{ "cells": [ { "cell_type": "markdown", "id": "c1084315", "metadata": {}, "source": [ "# 1-1, Structured Data Modeling Workflow Example" ] }, { "cell_type": "code", "execution_count": null, "id": "3922d854-7272-41fa-b70b-15fe0bdb01fc", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "!pip install torch\n", "!pip install 'torchkeras>=4.0.0'" ] }, { "cell_type": "code", "execution_count": 1, "id": "83c3b159", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.__version__ = 2.4.0\n", "torchkeras.__version__ = 4.0.0\n" ] } ], "source": [ "import torch \n", "import torchkeras \n", "print(\"torch.__version__ = \", torch.__version__)\n", "print(\"torchkeras.__version__ = \", torchkeras.__version__) " ] }, { "cell_type": "markdown", "id": "29bd80bf", "metadata": {}, "source": [ "
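\n", "\n", "\n", " \n", "Reply with the keyword **pytorch** to the WeChat official account **算法美食屋** to get this project's source code and the Baidu Netdisk download link for the datasets it uses.\n", " \n", " \n" ] }, { "cell_type": "markdown", "id": "021a78c0", "metadata": {}, "source": [ "### 1. Prepare the Data" ] }, { "cell_type": "markdown", "id": "22985769", "metadata": {}, "source": [ "The goal of the Titanic dataset is to predict, from passenger information, whether a passenger survived after the Titanic struck an iceberg and sank.\n", "\n", "Structured data is usually preprocessed with a Pandas DataFrame.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "af13fa86", "metadata": {}, "outputs": [ { "data": { "text/html": [ "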
" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 493 0 1 \n", "1 53 1 1 \n", "2 388 1 2 \n", "3 192 0 2 \n", "4 687 0 3 \n", "5 16 1 2 \n", "6 228 0 3 \n", "7 884 0 2 \n", "8 168 0 3 \n", "9 752 1 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Molson, Mr. Harry Markland male 55.0 0 \n", "1 Harper, Mrs. Henry Sleeper (Myna Haxtun) female 49.0 1 \n", "2 Buss, Miss. Kate female 36.0 0 \n", "3 Carbines, Mr. William male 19.0 0 \n", "4 Panula, Mr. Jaako Arnold male 14.0 4 \n", "5 Hewlett, Mrs. (Mary D Kingcome) female 55.0 0 \n", "6 Lovell, Mr. John Hall (\"Henry\") male 20.5 0 \n", "7 Banfield, Mr. Frederick James male 28.0 0 \n", "8 Skoog, Mrs. William (Anna Bernhardina Karlsson) female 45.0 1 \n", "9 Moor, Master. Meier male 6.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 113787 30.5000 C30 S \n", "1 0 PC 17572 76.7292 D33 C \n", "2 0 27849 13.0000 NaN S \n", "3 0 28424 13.0000 NaN S \n", "4 1 3101295 39.6875 NaN S \n", "5 0 248706 16.0000 NaN S \n", "6 0 A/5 21173 7.2500 NaN S \n", "7 0 C.A./SOTON 34068 10.5000 NaN S \n", "8 4 347088 27.9000 NaN S \n", "9 1 392096 12.4750 E121 S " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np \n", "import pandas as pd \n", "import matplotlib.pyplot as plt\n", "import torch \n", "from torch import nn \n", "from torch.utils.data import Dataset,DataLoader,TensorDataset\n", "\n", "dftrain_raw = pd.read_csv('./eat_pytorch_datasets/titanic/train.csv')\n", "dftest_raw = pd.read_csv('./eat_pytorch_datasets/titanic/test.csv')\n", "dftrain_raw.head(10)\n" ] }, { "cell_type": "markdown", "id": "a7edf254", "metadata": {}, "source": [ "Field descriptions:\n", "\n", "* Survived: 0 means died, 1 means survived [the label y]\n", "* Pclass: ticket class, one of three values (1, 2, 3) [converted to one-hot encoding]\n", "* Name: passenger name [dropped]\n", "* Sex: passenger sex [converted to a boolean feature]\n", "* Age: passenger age (has missing values) [numeric feature, with an added \"age is missing\" indicator as an auxiliary feature]\n", "* SibSp: number of siblings/spouses of the passenger (integer) [numeric feature]\n", "* Parch: number of parents/children of the passenger (integer) [numeric feature]\n", "* Ticket: ticket number (string) [dropped]\n", "* Fare: ticket fare (float, ranging from 0 to 500) [numeric feature]\n", "* Cabin: passenger cabin (has missing values) [an added \"cabin is missing\" indicator as an auxiliary feature]\n", "* Embarked: port of embarkation: S, C, or Q (has missing values) [converted to one-hot encoding with four dimensions: S, C, Q, nan]\n" ] }, { "cell_type": "markdown", "id": "9a726f38", "metadata": {}, "source": [ "With the plotting features built into Pandas we can do some simple exploratory data analysis (EDA).\n", "\n", "Label distribution" ] }, { "cell_type": "code", "execution_count": 3, "id": "78e13165", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'png'\n", "ax = dftrain_raw['Survived'].value_counts().plot(kind = 'bar',\n", " figsize = (12,8),fontsize=15,rot = 0)\n", "ax.set_ylabel('Counts',fontsize = 15)\n", "ax.set_xlabel('Survived',fontsize = 15)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "2f6fa85c", "metadata": {}, "source": [ "年龄分布情况" ] }, { "cell_type": "code", "execution_count": 4, "id": "2284f17f", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'png'\n", "ax = dftrain_raw['Age'].plot(kind = 'hist',bins = 20,color= 'purple',\n", " figsize = (12,8),fontsize=15)\n", "\n", "ax.set_ylabel('Frequency',fontsize = 15)\n", "ax.set_xlabel('Age',fontsize = 15)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "25ab6193", "metadata": {}, "source": [ "年龄和label的相关性" ] }, { "cell_type": "code", "execution_count": 5, "id": "e8a0979c", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'png'\n", "ax = dftrain_raw.query('Survived == 0')['Age'].plot(kind = 'density',\n", " figsize = (12,8),fontsize=15)\n", "dftrain_raw.query('Survived == 1')['Age'].plot(kind = 'density',\n", " figsize = (12,8),fontsize=15)\n", "ax.legend(['Survived==0','Survived==1'],fontsize = 12)\n", "ax.set_ylabel('Density',fontsize = 15)\n", "ax.set_xlabel('Age',fontsize = 15)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "d4ba0563", "metadata": {}, "source": [ "下面为正式的数据预处理" ] }, { "cell_type": "code", "execution_count": 6, "id": "1d64c292-33ed-406d-9b00-5b47ba6bff2f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " PassengerId Survived Pclass Name \\\n", "0 493 0 1 Molson, Mr. Harry Markland \n", "1 53 1 1 Harper, Mrs. Henry Sleeper (Myna Haxtun) \n", "2 388 1 2 Buss, Miss. Kate \n", "3 192 0 2 Carbines, Mr. William \n", "4 687 0 3 Panula, Mr. Jaako Arnold \n", ".. ... ... ... ... \n", "707 859 1 3 Baclini, Mrs. Solomon (Latifa Qurban) \n", "708 65 0 1 Stewart, Mr. Albert A \n", "709 130 0 3 Ekstrom, Mr. Johan \n", "710 21 0 2 Fynney, Mr. Joseph J \n", "711 476 0 1 Clifford, Mr. George Quincy \n", "\n", " Sex Age SibSp Parch Ticket Fare Cabin Embarked \n", "0 male 55.0 0 0 113787 30.5000 C30 S \n", "1 female 49.0 1 0 PC 17572 76.7292 D33 C \n", "2 female 36.0 0 0 27849 13.0000 NaN S \n", "3 male 19.0 0 0 28424 13.0000 NaN S \n", "4 male 14.0 4 1 3101295 39.6875 NaN S \n", ".. ... ... ... ... ... ... ... ... \n", "707 female 24.0 0 3 2666 19.2583 NaN C \n", "708 male NaN 0 0 PC 17605 27.7208 NaN C \n", "709 male 45.0 0 0 347061 6.9750 NaN S \n", "710 male 35.0 0 0 239865 26.0000 NaN S \n", "711 male NaN 0 0 110465 52.0000 A14 S \n", "\n", "[712 rows x 12 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dftrain_raw " ] }, { "cell_type": "code", "execution_count": 7, "id": "a28dcead", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train.shape = (712, 15)\n", "x_test.shape = (179, 15)\n", "y_train.shape = (712, 1)\n", "y_test.shape = (179, 1)\n" ] } ], "source": [ "def preprocessing(dfdata):\n", "\n", "    dfresult= pd.DataFrame()\n", "\n", "    #Pclass\n", "    dfPclass = pd.get_dummies(dfdata['Pclass']).astype(float)\n", "    dfPclass.columns = ['Pclass_' +str(x) for x in dfPclass.columns ]\n", "    dfresult = pd.concat([dfresult,dfPclass],axis = 1)\n", "\n", "    #Sex\n", "    dfSex = pd.get_dummies(dfdata['Sex']).astype(float)\n", "    dfresult = pd.concat([dfresult,dfSex],axis = 1)\n", "\n", "    #Age\n", "    dfresult['Age'] = dfdata['Age'].fillna(0)\n", "    dfresult['Age_null'] = pd.isna(dfdata['Age']).astype(float)\n", "\n", "    #SibSp,Parch,Fare\n", "    dfresult['SibSp'] = dfdata['SibSp']\n", "    dfresult['Parch'] = dfdata['Parch']\n", "    dfresult['Fare'] = dfdata['Fare']\n", "\n", "    #Cabin\n", "    dfresult['Cabin_null'] = pd.isna(dfdata['Cabin']).astype(float)\n", "\n", "    #Embarked\n", "    dfEmbarked = pd.get_dummies(dfdata['Embarked'],dummy_na=True).astype(float)\n", "    dfEmbarked.columns = ['Embarked_' + str(x) for x in dfEmbarked.columns]\n", "    dfresult = pd.concat([dfresult,dfEmbarked],axis = 1)\n", "\n", "    return dfresult\n", "\n", "x_train = preprocessing(dftrain_raw).values\n", "y_train = dftrain_raw[['Survived']].values\n", "\n", "x_test = preprocessing(dftest_raw).values\n", "y_test = dftest_raw[['Survived']].values\n", "\n", "print(\"x_train.shape =\", x_train.shape )\n", "print(\"x_test.shape =\", x_test.shape )\n", "\n", "print(\"y_train.shape =\", y_train.shape )\n", "print(\"y_test.shape =\", y_test.shape )\n" ] }, { "cell_type": "markdown", "id": "95d6c6d9", "metadata": {}, "source": [ "Next, wrap the arrays with TensorDataset and DataLoader to form an iterable data pipeline." ] }, { "cell_type": "code", "execution_count": 8, "id": "1d744935", "metadata": {}, "outputs": [], "source": [ "dl_train = DataLoader(TensorDataset(torch.tensor(x_train).float(),torch.tensor(y_train).float()),\n", "                      shuffle = True, batch_size = 8)\n", "dl_val = DataLoader(TensorDataset(torch.tensor(x_test).float(),torch.tensor(y_test).float()),\n", "                    shuffle = False, batch_size = 8)\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "5ec2fc6b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000,\n", " 0.0000, 7.2292, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 45.0000, 0.0000, 0.0000,\n", " 1.0000, 14.4542, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 28.0000, 0.0000, 0.0000,\n", " 0.0000, 10.5000, 1.0000, 
0.0000, 0.0000, 1.0000, 0.0000],\n", " [ 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 15.0000, 0.0000, 1.0000,\n", " 0.0000, 14.4542, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 36.0000, 0.0000, 1.0000,\n", " 0.0000, 26.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000],\n", " [ 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 1.0000, 0.0000, 5.0000,\n", " 2.0000, 46.9000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000],\n", " [ 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 33.0000, 0.0000, 0.0000,\n", " 0.0000, 86.5000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],\n", " [ 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.6700, 0.0000, 1.0000,\n", " 1.0000, 14.5000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000]]) tensor([[0.],\n", " [0.],\n", " [0.],\n", " [1.],\n", " [1.],\n", " [0.],\n", " [1.],\n", " [1.]])\n" ] } ], "source": [ "# Test the data pipeline\n", "for features,labels in dl_train:\n", "    print(features,labels)\n", "    break" ] }, { "cell_type": "code", "execution_count": null, "id": "0008e20c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "25126768", "metadata": {}, "source": [ "### 2. Define the Model" ] }, { "cell_type": "markdown", "id": "3a275e98", "metadata": {}, "source": [ "There are typically three ways to build a model in PyTorch: building it layer by layer with nn.Sequential, subclassing nn.Module to build a custom model, or subclassing nn.Module and using model containers to help with the encapsulation.\n", "\n", "Here we use the simplest option, nn.Sequential, and build the model layer by layer." ] }, { "cell_type": "code", "execution_count": 10, "id": "617186ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (linear1): Linear(in_features=15, out_features=20, bias=True)\n", " (relu1): ReLU()\n", " (linear2): Linear(in_features=20, out_features=15, bias=True)\n", " (relu2): ReLU()\n", " (linear3): Linear(in_features=15, out_features=1, bias=True)\n", ")\n" ] } ], "source": [ "def create_net():\n", "    net = nn.Sequential()\n", "    net.add_module(\"linear1\",nn.Linear(15,20))\n", "    net.add_module(\"relu1\",nn.ReLU())\n", "    net.add_module(\"linear2\",nn.Linear(20,15))\n", "    net.add_module(\"relu2\",nn.ReLU())\n", "    net.add_module(\"linear3\",nn.Linear(15,1))\n", "    return net\n", "\n", "net = create_net()\n", "print(net)" ] }, { "cell_type": "code", "execution_count": null, "id": "cdef0374", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "6761227a", "metadata": {}, "source": [ "### 3. Train the Model" ] }, { "cell_type": "markdown", "id": "af89d5af", "metadata": {}, "source": [ "PyTorch usually requires users to write their own training loop, and the coding style of training loops varies from person to person.\n", "\n", "There are three typical styles of training loop: script-style, function-style, and class-style.\n", "\n", "Here we present a fairly general script-style training loop modeled on Keras.\n", "\n", "This script-style training code is essentially the same as the core code of the torchkeras library.\n", "\n", "torchkeras details: https://github.com/lyhue1991/torchkeras \n" ] }, { "cell_type": "code", "execution_count": 11, "id": "87c5039d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2024-09-11 10:45:33\n", "Epoch 1 / 20\n", "\n", "100%|██████████| 89/89 [00:00<00:00, 1092.33it/s, train_acc=0.612, train_loss=0.662]\n", "100%|██████████| 23/23 [00:00<00:00, 1775.68it/s, val_acc=0.687, val_loss=0.616]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.6871508359909058 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "================================================================================2024-09-11 10:45:33\n", "Epoch 2 / 20\n", "\n", "100%|██████████| 89/89 [00:00<00:00, 1377.66it/s, train_acc=0.691, train_loss=0.596]\n", "100%|██████████| 23/23 [00:00<00:00, 1653.62it/s, 
val_acc=0.709, val_loss=0.548]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.7094972133636475 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 3 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1352.33it/s, train_acc=0.729, train_loss=0.554]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1625.89it/s, val_acc=0.749, val_loss=0.509]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.748603343963623 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 4 / 20\n", "\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1293.46it/s, train_acc=0.77, train_loss=0.515]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1669.97it/s, val_acc=0.732, val_loss=0.483]\n", "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 5 / 20\n", "\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 768.35it/s, train_acc=0.772, train_loss=0.518]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1374.77it/s, val_acc=0.777, val_loss=0.475]\n", "\n", 
"================================================================================2024-09-11 10:45:34" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.7765362858772278 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 6 / 20\n", "\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1144.28it/s, train_acc=0.784, train_loss=0.51]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1443.71it/s, val_acc=0.782, val_loss=0.451]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.7821229100227356 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 7 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1058.84it/s, train_acc=0.791, train_loss=0.475]\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1315.87it/s, val_acc=0.777, val_loss=0.42]\n", "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 8 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1061.54it/s, train_acc=0.785, train_loss=0.482]\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1328.48it/s, val_acc=0.793, val_loss=0.42]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.7932960987091064 >>>>>>\n" 
] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 9 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1077.93it/s, train_acc=0.801, train_loss=0.462]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1285.21it/s, val_acc=0.799, val_loss=0.413]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.7988826632499695 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 10 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1118.12it/s, train_acc=0.795, train_loss=0.466]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1670.20it/s, val_acc=0.793, val_loss=0.447]\n", "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 11 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1322.75it/s, train_acc=0.805, train_loss=0.461]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1569.32it/s, val_acc=0.771, val_loss=0.434]\n", "\n", "================================================================================2024-09-11 10:45:34\n", "Epoch 12 / 20\n", "\n", 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1313.24it/s, train_acc=0.798, train_loss=0.446]\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1565.37it/s, val_acc=0.81, val_loss=0.426]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< reach best val_acc : 0.8100558519363403 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2024-09-11 10:45:35\n", "Epoch 13 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1258.56it/s, train_acc=0.812, train_loss=0.443]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1679.09it/s, val_acc=0.804, val_loss=0.407]\n", "\n", "================================================================================2024-09-11 10:45:35\n", "Epoch 14 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1312.92it/s, train_acc=0.801, train_loss=0.456]\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1661.83it/s, val_acc=0.81, val_loss=0.405]\n", "\n", "================================================================================2024-09-11 10:45:35\n", "Epoch 15 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1346.36it/s, train_acc=0.805, train_loss=0.449]\n", 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1544.44it/s, val_acc=0.804, val_loss=0.453]\n", "\n", "================================================================================2024-09-11 10:45:35\n", "Epoch 16 / 20\n", "\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 1286.41it/s, train_acc=0.803, train_loss=0.445]\n", "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1599.58it/s, val_acc=0.799, val_loss=0.451]\n", "\n", "================================================================================2024-09-11 10:45:35\n", "Epoch 17 / 20\n", "\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 785.39it/s, train_acc=0.794, train_loss=0.457]\n", "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1554.25it/s, val_acc=0.793, val_loss=0.4]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "<<<<<< val_acc without improvement in 5 epoch, early stopping >>>>>>\n" ] } ], "source": [ "import os,sys,time\n", "import numpy as np\n", "import pandas as pd\n", "import datetime \n", "from tqdm import tqdm \n", "\n", "import torch\n", "from torch import nn \n", "from copy import deepcopy\n", "from torchkeras.metrics import Accuracy\n", "\n", "\n", "def printlog(info):\n", " nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n", " print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n", " print(str(info)+\"\\n\")\n", " \n", "\n", "loss_fn = nn.BCEWithLogitsLoss()\n", "optimizer= torch.optim.Adam(net.parameters(),lr = 0.005) \n", "metrics_dict = {\"acc\":Accuracy()}\n", "\n", "epochs = 20 \n", 
"ckpt_path='checkpoint.pt'\n", "\n", "#early_stopping相关设置\n", "monitor=\"val_acc\"\n", "patience=5\n", "mode=\"max\"\n", "\n", "history = {}\n", "\n", "for epoch in range(1, epochs+1):\n", " printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n", "\n", " # 1,train ------------------------------------------------- \n", " net.train()\n", " \n", " total_loss,step = 0,0\n", " \n", " loop = tqdm(enumerate(dl_train), total =len(dl_train),file = sys.stdout)\n", " train_metrics_dict = deepcopy(metrics_dict) \n", " \n", " for i, batch in loop: \n", " \n", " features,labels = batch\n", " #forward\n", " preds = net(features)\n", " loss = loss_fn(preds,labels)\n", " \n", " #backward\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " \n", " #metrics\n", " step_metrics = {\"train_\"+name:metric_fn(preds, labels).item() \n", " for name,metric_fn in train_metrics_dict.items()}\n", " \n", " step_log = dict({\"train_loss\":loss.item()},**step_metrics)\n", "\n", " total_loss += loss.item()\n", " \n", " step+=1\n", " if i!=len(dl_train)-1:\n", " loop.set_postfix(**step_log)\n", " else:\n", " epoch_loss = total_loss/step\n", " epoch_metrics = {\"train_\"+name:metric_fn.compute().item() \n", " for name,metric_fn in train_metrics_dict.items()}\n", " epoch_log = dict({\"train_loss\":epoch_loss},**epoch_metrics)\n", " loop.set_postfix(**epoch_log)\n", "\n", " for name,metric_fn in train_metrics_dict.items():\n", " metric_fn.reset()\n", " \n", " for name, metric in epoch_log.items():\n", " history[name] = history.get(name, []) + [metric]\n", " \n", "\n", " # 2,validate -------------------------------------------------\n", " net.eval()\n", " \n", " total_loss,step = 0,0\n", " loop = tqdm(enumerate(dl_val), total =len(dl_val),file = sys.stdout)\n", " \n", " val_metrics_dict = deepcopy(metrics_dict) \n", " \n", " with torch.no_grad():\n", " for i, batch in loop: \n", "\n", " features,labels = batch\n", " \n", " #forward\n", " preds = net(features)\n", " loss = 
loss_fn(preds,labels)\n", "\n", " #metrics\n", " step_metrics = {\"val_\"+name:metric_fn(preds, labels).item() \n", " for name,metric_fn in val_metrics_dict.items()}\n", "\n", " step_log = dict({\"val_loss\":loss.item()},**step_metrics)\n", "\n", " total_loss += loss.item()\n", " step+=1\n", " if i!=len(dl_val)-1:\n", " loop.set_postfix(**step_log)\n", " else:\n", " epoch_loss = (total_loss/step)\n", " epoch_metrics = {\"val_\"+name:metric_fn.compute().item() \n", " for name,metric_fn in val_metrics_dict.items()}\n", " epoch_log = dict({\"val_loss\":epoch_loss},**epoch_metrics)\n", " loop.set_postfix(**epoch_log)\n", "\n", " for name,metric_fn in val_metrics_dict.items():\n", " metric_fn.reset()\n", " \n", " epoch_log[\"epoch\"] = epoch \n", " for name, metric in epoch_log.items():\n", " history[name] = history.get(name, []) + [metric]\n", "\n", " # 3,early-stopping -------------------------------------------------\n", " arr_scores = history[monitor]\n", " best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n", " if best_score_idx==len(arr_scores)-1:\n", " torch.save(net.state_dict(),ckpt_path)\n", " print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n", " arr_scores[best_score_idx]),file=sys.stderr)\n", " if len(arr_scores)-best_score_idx>patience:\n", " print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n", " monitor,patience),file=sys.stderr)\n", " break \n", " net.load_state_dict(torch.load(ckpt_path,weights_only=True))\n", " \n", "dfhistory = pd.DataFrame(history)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6c5ba3ad", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ec78a930", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "70f7f3b3", "metadata": {}, "source": [ "### 四,评估模型" ] }, { "cell_type": "markdown", "id": "cefbce96", "metadata": {}, "source": [ "我们首先评估一下模型在训练集和验证集上的效果。" ] }, { 
"cell_type": "code", "execution_count": 13, "id": "2a731173", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " train_loss train_acc val_loss val_acc epoch\n", "0 0.661594 0.612360 0.616407 0.687151 1\n", "1 0.596448 0.691011 0.548176 0.709497 2\n", "2 0.553593 0.728933 0.508982 0.748603 3\n", "3 0.514501 0.769663 0.482726 0.731844 4\n", "4 0.517688 0.772472 0.474983 0.776536 5\n", "5 0.509755 0.783708 0.450612 0.782123 6\n", "6 0.474683 0.790730 0.420094 0.776536 7\n", "7 0.482496 0.785112 0.420083 0.793296 8\n", "8 0.461551 0.800562 0.413204 0.798883 9\n", "9 0.465986 0.794944 0.446665 0.793296 10\n", "10 0.461325 0.804775 0.434290 0.770950 11\n", "11 0.445943 0.797753 0.425559 0.810056 12\n", "12 0.442802 0.811798 0.407031 0.804469 13\n", "13 0.455615 0.800562 0.404679 0.810056 14\n", "14 0.449465 0.804775 0.452611 0.804469 15\n", "15 0.444932 0.803371 0.450696 0.798883 16\n", "16 0.457184 0.793539 0.399576 0.793296 17" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dfhistory " ] }, { "cell_type": "code", "execution_count": 12, "id": "10ab56d5", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "def plot_metric(dfhistory, metric):\n", " train_metrics = dfhistory[\"train_\"+metric]\n", " val_metrics = dfhistory['val_'+metric]\n", " epochs = range(1, len(train_metrics) + 1)\n", " plt.plot(epochs, train_metrics, 'bo--')\n", " plt.plot(epochs, val_metrics, 'ro-')\n", " plt.title('Training and validation '+ metric)\n", " plt.xlabel(\"Epochs\")\n", " plt.ylabel(metric)\n", " plt.legend([\"train_\"+metric, 'val_'+metric])\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 14, "id": "e3b47c77", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-11T10:45:51.760370\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_metric(dfhistory,\"loss\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "4bdfcf5a", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-11T10:45:53.781421\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_metric(dfhistory,\"acc\")" ] }, { "cell_type": "code", "execution_count": null, "id": "c1f1cb9e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "17d7e80c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "05c4494b", "metadata": {}, "source": [ "### 五,使用模型" ] }, { "cell_type": "code", "execution_count": 16, "id": "da16f9ab", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0771],\n", " [0.6915],\n", " [0.3397],\n", " [0.9527],\n", " [0.6116],\n", " [0.8747],\n", " [0.1023],\n", " [0.8377],\n", " [0.5713],\n", " [0.0841]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#预测概率\n", "\n", "y_pred_probs = torch.sigmoid(net(torch.tensor(x_test[0:10]).float())).data\n", "y_pred_probs" ] }, { "cell_type": "code", "execution_count": 17, "id": "a31cb281", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.],\n", " [1.],\n", " [0.],\n", " [1.],\n", " [1.],\n", " [1.],\n", " [0.],\n", " [1.],\n", " [1.],\n", " [0.]])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#预测类别\n", "y_pred = torch.where(y_pred_probs>0.5,\n", " torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs))\n", "y_pred" ] }, { "cell_type": "code", "execution_count": null, "id": "48f6cacc", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "57fe8bba", "metadata": {}, "source": [ "### 六,保存模型" ] }, { "cell_type": "markdown", "id": "89eb7506", "metadata": {}, "source": [ "Pytorch 有两种保存模型的方式,都是通过调用pickle序列化方法实现的。\n", "\n", "第一种方法只保存模型参数。\n", "\n", "第二种方法保存完整模型。\n", "\n", "推荐使用第一种,第二种方法可能在切换设备和目录的时候出现各种问题。\n" ] }, { "cell_type": "markdown", "id": "9113eb43", "metadata": {}, "source": [ "**1,保存模型参数(推荐)**" ] }, { "cell_type": "code", "execution_count": 18, "id": "e6098000", "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])\n" ] } ], "source": [ "print(net.state_dict().keys())\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "4cfa68ac", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0771],\n", " [0.6915],\n", " [0.3397],\n", " [0.9527],\n", " [0.6116],\n", " [0.8747],\n", " [0.1023],\n", " [0.8377],\n", " [0.5713],\n", " [0.0841]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 保存模型参数\n", "\n", "torch.save(net.state_dict(), \"./data/net_parameter.pt\")\n", "\n", "net_clone = create_net()\n", "net_clone.load_state_dict(torch.load(\"./data/net_parameter.pt\",weights_only=True))\n", "\n", "torch.sigmoid(net_clone.forward(torch.tensor(x_test[0:10]).float())).data\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1fcfbadf", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "bee51ce5", "metadata": {}, "source": [ "**2,保存完整模型(不推荐)**" ] }, { "cell_type": "code", "execution_count": 20, "id": "4c969c33", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0771],\n", " [0.6915],\n", " [0.3397],\n", " [0.9527],\n", " [0.6116],\n", " [0.8747],\n", " [0.1023],\n", " [0.8377],\n", " [0.5713],\n", " [0.0841]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.save(net, './data/net_model.pt')\n", "net_loaded = torch.load('./data/net_model.pt',weights_only=False)\n", "torch.sigmoid(net_loaded(torch.tensor(x_test[0:10]).float())).data\n" ] }, { "cell_type": "markdown", "id": "52eacb75", "metadata": {}, "source": [ "**如果本书对你有所帮助,想鼓励一下作者,记得给本项目加一颗星星star⭐️,并分享给你的朋友们喔😊!** \n", "\n", "如果对本书内容理解上有需要进一步和作者交流的地方,欢迎在公众号\"算法美食屋\"下留言。作者时间和精力有限,会酌情予以回复。\n", "\n", "也可以在公众号后台回复关键字:**加群**,加入读者交流群和大家讨论。\n", "\n", 
"![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "formats": "ipynb,md" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }