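{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Report02 - Titanic Survival Prediction\n", "\n", "* 沈键\n", "* 2021200082\n" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Task Overview\n", "\n", "The sinking of the Titanic is one of the most infamous shipwrecks in history. On April 15, 1912, during her maiden voyage, the Titanic sank after colliding with an iceberg, killing 1502 of the 2224 passengers and crew. The tragedy shocked the international community and led to better safety regulations for ships. One reason the sinking cost so many lives was that there were not enough lifeboats for the passengers and crew. Although some luck was involved in surviving, some groups were more likely to survive than others, such as women, children, and the upper class. In this task we analyze which kinds of people were more likely to survive, and then use machine learning to predict which passengers survived the disaster.\n" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Data Analysis" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The task provides a dataset file named train.csv with records for 891 passengers. We first read the file with pandas and look at the variables it contains." ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "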
<div>\n",
  "<table border=\"1\" class=\"dataframe\">\n",
  "  <thead>\n",
  "    <tr style=\"text-align: right;\"><th></th><th>PassengerId</th><th>Survived</th><th>Pclass</th><th>Name</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Ticket</th><th>Fare</th><th>Cabin</th><th>Embarked</th></tr>\n",
  "  </thead>\n",
  "  <tbody>\n",
  "    <tr><th>0</th><td>1</td><td>0</td><td>3</td><td>Braund, Mr. Owen Harris</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>A/5 21171</td><td>7.2500</td><td>NaN</td><td>S</td></tr>\n",
  "    <tr><th>1</th><td>2</td><td>1</td><td>1</td><td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>PC 17599</td><td>71.2833</td><td>C85</td><td>C</td></tr>\n",
  "    <tr><th>2</th><td>3</td><td>1</td><td>3</td><td>Heikkinen, Miss. Laina</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>STON/O2. 3101282</td><td>7.9250</td><td>NaN</td><td>S</td></tr>\n",
  "    <tr><th>3</th><td>4</td><td>1</td><td>1</td><td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>113803</td><td>53.1000</td><td>C123</td><td>S</td></tr>\n",
  "    <tr><th>4</th><td>5</td><td>0</td><td>3</td><td>Allen, Mr. William Henry</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>373450</td><td>8.0500</td><td>NaN</td><td>S</td></tr>\n",
  "    <tr><th>5</th><td>6</td><td>0</td><td>3</td><td>Moran, Mr. James</td><td>male</td><td>NaN</td><td>0</td><td>0</td><td>330877</td><td>8.4583</td><td>NaN</td><td>Q</td></tr>\n",
  "    <tr><th>6</th><td>7</td><td>0</td><td>1</td><td>McCarthy, Mr. Timothy J</td><td>male</td><td>54.0</td><td>0</td><td>0</td><td>17463</td><td>51.8625</td><td>E46</td><td>S</td></tr>\n",
  "    <tr><th>7</th><td>8</td><td>0</td><td>3</td><td>Palsson, Master. Gosta Leonard</td><td>male</td><td>2.0</td><td>3</td><td>1</td><td>349909</td><td>21.0750</td><td>NaN</td><td>S</td></tr>\n",
  "    <tr><th>8</th><td>9</td><td>1</td><td>3</td><td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td><td>female</td><td>27.0</td><td>0</td><td>2</td><td>347742</td><td>11.1333</td><td>NaN</td><td>S</td></tr>\n",
  "    <tr><th>9</th><td>10</td><td>1</td><td>2</td><td>Nasser, Mrs. Nicholas (Adele Achem)</td><td>female</td><td>14.0</td><td>1</td><td>0</td><td>237736</td><td>30.0708</td><td>NaN</td><td>C</td></tr>\n",
  "  </tbody>\n",
  "</table>\n",
  "</div>"
 ], "text/plain": [
  "   PassengerId  Survived  Pclass  \\\n",
  "0            1         0       3   \n",
  "1            2         1       1   \n",
  "2            3         1       3   \n",
  "3            4         1       1   \n",
  "4            5         0       3   \n",
  "5            6         0       3   \n",
  "6            7         0       1   \n",
  "7            8         0       3   \n",
  "8            9         1       3   \n",
  "9           10         1       2   \n",
  "\n",
  "                                                Name     Sex   Age  SibSp  \\\n",
  "0                            Braund, Mr. Owen Harris    male  22.0      1   \n",
  "1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   \n",
  "2                             Heikkinen, Miss. Laina  female  26.0      0   \n",
  "3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   \n",
  "4                           Allen, Mr. William Henry    male  35.0      0   \n",
  "5                                   Moran, Mr. James    male   NaN      0   \n",
  "6                            McCarthy, Mr. Timothy J    male  54.0      0   \n",
  "7                     Palsson, Master. Gosta Leonard    male   2.0      3   \n",
  "8  Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)  female  27.0      0   \n",
  "9                Nasser, Mrs. Nicholas (Adele Achem)  female  14.0      1   \n",
  "\n",
  "   Parch            Ticket     Fare Cabin Embarked  \n",
  "0      0         A/5 21171   7.2500   NaN        S  \n",
  "1      0          PC 17599  71.2833   C85        C  \n",
  "2      0  STON/O2. 3101282   7.9250   NaN        S  \n",
  "3      0            113803  53.1000  C123        S  \n",
  "4      0            373450   8.0500   NaN        S  \n",
  "5      0            330877   8.4583   NaN        Q  \n",
  "6      0             17463  51.8625   E46        S  \n",
  "7      1            349909  21.0750   NaN        S  \n",
  "8      2            347742  11.1333   NaN        S  \n",
  "9      0            237736  30.0708   NaN        C  "
 ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "import pandas as pd\n",
  "\n",
  "train_df = pd.read_csv(\"./data/train.csv\")\n",
  "train_df.head(10)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The first few rows show that each record contains the passenger ID, passenger class, name, sex, age, the number of siblings and spouses travelling with the passenger, the number of parents and children travelling with the passenger, ticket number, fare, cabin number, port of embarkation (S = Southampton, England, the departure port; C = Cherbourg, France, a port of call; Q = Queenstown, Ireland, a port of call), and whether the passenger survived. Some columns contain missing values, which pandas shows as NaN after loading the file. A heatmap shows the distribution of the missing values more intuitively." ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "import seaborn as sns\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
  "%matplotlib inline\n",
  "\n",
  "# Highlight the positions of missing values\n",
  "sns.heatmap(train_df.isnull(), yticklabels=False, cbar=False, cmap='CMRmap')\n",
  "plt.tight_layout()\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "With the missing values highlighted in the heatmap, we can see that the Age, Cabin, and Embarked columns contain missing values, and that Age and Cabin have many of them." ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "<class 'pandas.core.frame.DataFrame'>\n",
  "RangeIndex: 891 entries, 0 to 890\n",
  "Data columns (total 12 columns):\n",
  " #   Column       Non-Null Count  Dtype  \n",
  "---  ------       --------------  -----  \n",
  " 0   PassengerId  891 non-null    int64  \n",
  " 1   Survived     891 non-null    int64  \n",
  " 2   Pclass       891 non-null    int64  \n",
  " 3   Name         891 non-null    object \n",
  " 4   Sex          891 non-null    object \n",
  " 5   Age          714 non-null    float64\n",
  " 6   SibSp        891 non-null    int64  \n",
  " 7   Parch        891 non-null    int64  \n",
  " 8   Ticket       891 non-null    object \n",
  " 9   Fare         891 non-null    float64\n",
  " 10  Cabin        204 non-null    object \n",
  " 11  Embarked     889 non-null    object \n",
  "dtypes: float64(2), int64(5), object(5)\n",
  "memory usage: 83.7+ KB\n" ] } ], "source": [ "train_df.info()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The missing rate is (891 - 714) / 891 × 100% = 19.9% for Age, (891 - 204) / 891 × 100% = 77.1% for Cabin, and (891 - 889) / 891 × 100% = 0.2% for Embarked. The Cabin column is missing far too often to be useful, so we drop it. The passenger ID and ticket number carry little information, so we drop them as well. Next, we examine how each of the remaining factors relates to the survival rate." ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
  "train_df.drop('PassengerId', axis=1, inplace=True)\n",
  "train_df.drop('Ticket', axis=1, inplace=True)\n",
  "train_df.drop('Cabin', axis=1, inplace=True)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Effect of Passenger Class on Survival" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "import matplotlib as mpl\n",
  "\n",
  "# SimHei is only needed when labels contain Chinese characters\n",
  "mpl.rcParams['font.sans-serif'] = ['SimHei']\n",
  "mpl.rcParams['axes.unicode_minus'] = False\n",
  "\n",
  "fig = plt.figure()\n",
  "fig.set(alpha=0.2)\n",
  "\n",
  "# Count passengers per class, split by survival\n",
  "survived_0 = train_df.Pclass[train_df.Survived == 0].value_counts()\n",
  "survived_1 = train_df.Pclass[train_df.Survived == 1].value_counts()\n",
  "df = pd.DataFrame({\"Survived\": survived_1, \"Did not survive\": survived_0})\n",
  "df.plot(kind='bar', stacked=True)\n",
  "plt.xticks(rotation=0)\n",
  "plt.title('Survival by passenger class')\n",
  "plt.xlabel('Passenger class')\n",
  "plt.ylabel('Number of passengers')\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Passengers in class 1 have the highest survival rate, class 2 comes next, and class 3 has the lowest. Wealthier passengers were thus more likely to survive than poorer ones, plausibly because higher-class cabins had better access to rescue equipment than lower-class cabins." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Effect of Name on Survival" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the Name column, the middle part of each name is a title that reflects the passenger's social status or marital status." ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "Mr          517\n",
  "Miss        182\n",
  "Mrs         125\n",
  "Master       40\n",
  "Dr            7\n",
  "Rev           6\n",
  "Major         2\n",
  "Mlle          2\n",
  "Col           2\n",
  "Countess      1\n",
  "Lady          1\n",
  "Mme           1\n",
  "Jonkheer      1\n",
  "Don           1\n",
  "Ms            1\n",
  "Capt          1\n",
  "Sir           1\n",
  "Name: Name, dtype: int64\n" ] } ], "source": [
  "import re\n",
  "\n",
  "def get_title(name):\n",
  "    # The title is the word immediately followed by a period, e.g. \"Mr.\"\n",
  "    title_search = re.search(r'([A-Za-z]+)\\.', name)\n",
  "    if title_search:\n",
  "        return title_search.group(1)\n",
  "    return None\n",
  "\n",
  "titles = train_df[\"Name\"].apply(get_title)\n",
  "print(titles.value_counts())" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "For example, Miss denotes an unmarried woman and Mrs a married woman. Capt, Col, Major, Dr, and Rev are grouped as Officer; Don, Sir, Countess, and Lady are grouped as Royalty; Master (the title used for young boys) and Jonkheer are grouped as Master. Any other title would go into an Others class, although none occurs in train.csv." ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
  "# Rename the Name column to Title\n",
  "train_df.rename(columns={'Name': 'Title'}, inplace=True)\n",
  "train_df['Title'] = train_df['Title'].apply(get_title)\n",
  "title_classification = {'Officer': ['Capt', 'Col', 'Major', 'Dr', 'Rev'],\n",
  "                        'Royalty': ['Don', 'Sir', 'Countess', 'Lady'],\n",
  "                        'Mrs': ['Mme', 'Ms', 'Mrs'],\n",
  "                        'Miss': ['Mlle', 'Miss'],\n",
  "                        'Mr': ['Mr'],\n",
  "                        'Master': ['Master', 'Jonkheer']}\n",
  "# Invert the grouping: raw title -> class name\n",
  "title_map = {}\n",
  "for title in title_classification.keys():\n",
  "    title_map.update(dict.fromkeys(title_classification[title], title))\n",
  "\n",
  "train_df['Title'] = train_df['Title'].map(title_map)" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "survived_0 = train_df.Title[train_df.Survived == 0].value_counts()\n",
  "survived_1 = train_df.Title[train_df.Survived == 1].value_counts()\n",
  "df = pd.DataFrame({\"Survived\": survived_1, \"Did not survive\": survived_0})\n",
  "df.plot(kind='bar', stacked=True)\n",
  "plt.xticks(rotation=30)\n",
  "plt.title('Survival by title')\n",
  "plt.xlabel('Title')\n",
  "plt.ylabel('Number of passengers')\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Survival differs across titles. Married and unmarried women (Mrs, Miss) have high survival rates while men (Mr) have a low one, which is consistent with the 'women first' convention. The Royalty and Master groups also have high survival rates: the former is consistent with people of high social status being rescued first, the latter with Master being the title for young boys." ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
  "<table border=\"1\" class=\"dataframe\">\n",
  "  <thead>\n",
  "    <tr style=\"text-align: right;\"><th></th><th>Survived</th><th>Pclass</th><th>Title</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Fare</th><th>Embarked</th></tr>\n",
  "  </thead>\n",
  "  <tbody>\n",
  "    <tr><th>0</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>7.2500</td><td>S</td></tr>\n",
  "    <tr><th>1</th><td>1</td><td>1</td><td>Mrs</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>71.2833</td><td>C</td></tr>\n",
  "    <tr><th>2</th><td>1</td><td>3</td><td>Miss</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>7.9250</td><td>S</td></tr>\n",
  "    <tr><th>3</th><td>1</td><td>1</td><td>Mrs</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>53.1000</td><td>S</td></tr>\n",
  "    <tr><th>4</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>8.0500</td><td>S</td></tr>\n",
  "    <tr><th>5</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>NaN</td><td>0</td><td>0</td><td>8.4583</td><td>Q</td></tr>\n",
  "    <tr><th>6</th><td>0</td><td>1</td><td>Mr</td><td>male</td><td>54.0</td><td>0</td><td>0</td><td>51.8625</td><td>S</td></tr>\n",
  "    <tr><th>7</th><td>0</td><td>3</td><td>Master</td><td>male</td><td>2.0</td><td>3</td><td>1</td><td>21.0750</td><td>S</td></tr>\n",
  "    <tr><th>8</th><td>1</td><td>3</td><td>Mrs</td><td>female</td><td>27.0</td><td>0</td><td>2</td><td>11.1333</td><td>S</td></tr>\n",
  "    <tr><th>9</th><td>1</td><td>2</td><td>Mrs</td><td>female</td><td>14.0</td><td>1</td><td>0</td><td>30.0708</td><td>C</td></tr>\n",
  "  </tbody>\n",
  "</table>\n",
  "</div>"
 ], "text/plain": [
  "   Survived  Pclass   Title     Sex   Age  SibSp  Parch     Fare Embarked\n",
  "0         0       3      Mr    male  22.0      1      0   7.2500        S\n",
  "1         1       1     Mrs  female  38.0      1      0  71.2833        C\n",
  "2         1       3    Miss  female  26.0      0      0   7.9250        S\n",
  "3         1       1     Mrs  female  35.0      1      0  53.1000        S\n",
  "4         0       3      Mr    male  35.0      0      0   8.0500        S\n",
  "5         0       3      Mr    male   NaN      0      0   8.4583        Q\n",
  "6         0       1      Mr    male  54.0      0      0  51.8625        S\n",
  "7         0       3  Master    male   2.0      3      1  21.0750        S\n",
  "8         1       3     Mrs  female  27.0      0      2  11.1333        S\n",
  "9         1       2     Mrs  female  14.0      1      0  30.0708        C"
 ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head(10)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Effect of Sex on Survival" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "survived_0 = train_df.Sex[train_df.Survived == 0].value_counts()\n",
  "survived_1 = train_df.Sex[train_df.Survived == 1].value_counts()\n",
  "df = pd.DataFrame({\"Survived\": survived_1, \"Did not survive\": survived_0})\n",
  "df.plot(kind='bar', stacked=True)\n",
  "plt.xticks(rotation=30)\n",
  "plt.title('Survival by sex')\n",
  "plt.xlabel('Sex')\n",
  "plt.ylabel('Number of passengers')\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "As noted in the title analysis, women have a higher survival rate than men, which is consistent with the 'women first' convention." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.4 Effect of Age on Survival" ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "# Kernel density estimate of age, one curve per survival outcome\n",
  "facet = sns.FacetGrid(train_df, hue=\"Survived\", aspect=2)\n",
  "facet.map(sns.kdeplot, 'Age', shade=True)\n",
  "facet.set(xlim=(0, train_df['Age'].max()))\n",
  "facet.add_legend()\n",
  "plt.xlabel('Age')\n",
  "plt.ylabel('Density')\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The plot shows that young passengers (age <= 15) were more likely to survive, while older passengers (age >= 60) were less likely to. This is consistent with children being protected first and with the elderly being physically weaker." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.5 Effect of Family Size on Survival" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Family size counts the siblings and spouses and the parents and children travelling with the passenger, plus the passenger." ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "train_df['FamilySize'] = train_df['SibSp'] + train_df['Parch'] + 1\n",
  "survived_0 = train_df.FamilySize[train_df.Survived == 0].value_counts()\n",
  "survived_1 = train_df.FamilySize[train_df.Survived == 1].value_counts()\n",
  "df = pd.DataFrame({\"Survived\": survived_1, \"Did not survive\": survived_0})\n",
  "df.plot(kind='bar', stacked=True)\n",
  "plt.xticks(rotation=30)\n",
  "plt.title('Survival by family size')\n",
  "plt.xlabel('Family size')\n",
  "plt.ylabel('Number of passengers')\n",
  "plt.show()" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
  "train_df.drop('SibSp', axis=1, inplace=True)\n",
  "train_df.drop('Parch', axis=1, inplace=True)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Passengers travelling in families of about 3 to 4 have the highest survival rate; both smaller and larger families do worse. Family members generally help each other first, but in a very large family the time spent helping relatives may keep a passenger from leaving in time." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.6 Effect of Fare on Survival" ] },
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "# Kernel density estimate of fare, one curve per survival outcome\n",
  "facet = sns.FacetGrid(train_df, hue=\"Survived\", aspect=2)\n",
  "facet.map(sns.kdeplot, 'Fare', shade=True)\n",
  "facet.set(xlim=(0, train_df['Fare'].max()))\n",
  "facet.add_legend()\n",
  "plt.xlabel('Fare')\n",
  "plt.ylabel('Density')\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "In general, higher fares correspond to better locations on the ship and better access to rescue facilities, and thus to higher survival rates." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.7 Effect of Port of Embarkation on Survival" ] },
 { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "survived_0 = train_df.Embarked[train_df.Survived == 0].value_counts()\n",
  "survived_1 = train_df.Embarked[train_df.Survived == 1].value_counts()\n",
  "df = pd.DataFrame({\"Survived\": survived_1, \"Did not survive\": survived_0})\n",
  "df.plot(kind='bar', stacked=True)\n",
  "plt.xticks(rotation=30)\n",
  "plt.title('Survival by port of embarkation')\n",
  "plt.xlabel('Port of embarkation')\n",
  "plt.ylabel('Number of passengers')\n",
  "plt.show()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The survival rate differs by port of embarkation: it is highest for passengers who boarded at C and lowest for those who boarded at S." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Logistic Regression Model" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "This is a binary classification problem, so a logistic regression model can be used for learning and prediction. A linear model assumes that the output and the input satisfy:" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "$$ Y = X W^T + b $$" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "where $Y$ is the output, $X$ is the input features, $W$ is the weight matrix of the input features, and $b$ is the bias term." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Once the weight matrix $W$ and the bias $b$ are fixed, feeding in a feature vector yields a value $z = X W^T + b$. This value is continuous and may be very large or very small, whereas classification needs a value between 0 and 1. Logistic regression squashes the prediction into the interval (0, 1); its regression function and curve are shown in the figure below. The logistic curve is very sensitive around $z = 0$ and insensitive for $z \\gg 0$ or $z \\ll 0$. The usual logistic function is the sigmoid:" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "g(z)=\\frac{1}{1+e^{-z}}\n", "$$" ] },
 { "attachments": { "linear_logistic_regression.png": { "image/png": "" } }, "cell_type": "markdown", "metadata": {}, "source": [ "![linear_logistic_regression.png](attachment:linear_logistic_regression.png)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Next, a loss function is defined to measure the error between the predictions and the true values: the larger the error, the larger the loss. Gradient descent gives the direction in which to update the weight matrix and the bias; iterating with a suitable learning rate yields weights and a bias with a small loss, at which point training is complete." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Model Training" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.1 Handling Missing Values" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "After dropping the cabin number, the Age and Embarked columns still contain missing values that must be filled. Age is filled with the mean age, and Embarked with the most common port (S). In addition, with this little data individual ages are not very informative, so we bin ages into ranges; fares are binned the same way." ] },
 { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
  "def handle_age(age):\n",
  "    # 0: child (<= 15), 1: adult (<= 60), 2: senior\n",
  "    if age <= 15:\n",
  "        return 0\n",
  "    elif age <= 60:\n",
  "        return 1\n",
  "    else:\n",
  "        return 2\n",
  "\n",
  "def handle_fare(fare):\n",
  "    # Four fare bands: <= 32, <= 100, <= 200, > 200\n",
  "    if fare <= 32:\n",
  "        return 0\n",
  "    elif fare <= 100:\n",
  "        return 1\n",
  "    elif fare <= 200:\n",
  "        return 2\n",
  "    else:\n",
  "        return 3\n",
  "\n",
  "train_df['Age'] = train_df['Age'].fillna(train_df['Age'].mean()).map(handle_age)\n",
  "train_df['Embarked'] = train_df['Embarked'].fillna('S')\n",
  "train_df['Fare'] = train_df['Fare'].map(handle_fare)" ] },
 { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "<class 'pandas.core.frame.DataFrame'>\n",
  "RangeIndex: 891 entries, 0 to 890\n",
  "Data columns (total 8 columns):\n",
  " #   Column      Non-Null Count  Dtype \n",
  "---  ------      --------------  ----- \n",
  " 0   Survived    891 non-null    int64 \n",
  " 1   Pclass      891 non-null    int64 \n",
  " 2   Title       891 non-null    object\n",
  " 3   Sex         891 non-null    object\n",
  " 4   Age         891 non-null    int64 \n",
  " 5   Fare        891 non-null    int64 \n",
  " 6   Embarked    891 non-null    object\n",
  " 7   FamilySize  891 non-null    int64 \n",
  "dtypes: int64(5), object(3)\n",
  "memory usage: 55.8+ KB\n" ] } ], "source": [ "train_df.info()" ] },
 { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
  "<table border=\"1\" class=\"dataframe\">\n",
  "  <thead>\n",
  "    <tr style=\"text-align: right;\"><th></th><th>Survived</th><th>Pclass</th><th>Title</th><th>Sex</th><th>Age</th><th>Fare</th><th>Embarked</th><th>FamilySize</th></tr>\n",
  "  </thead>\n",
  "  <tbody>\n",
  "    <tr><th>0</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>1</td><td>0</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>1</th><td>1</td><td>1</td><td>Mrs</td><td>female</td><td>1</td><td>1</td><td>C</td><td>2</td></tr>\n",
  "    <tr><th>2</th><td>1</td><td>3</td><td>Miss</td><td>female</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "    <tr><th>3</th><td>1</td><td>1</td><td>Mrs</td><td>female</td><td>1</td><td>1</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>4</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "  </tbody>\n",
  "</table>\n",
  "</div>"
 ], "text/plain": [
  "   Survived  Pclass Title     Sex  Age  Fare Embarked  FamilySize\n",
  "0         0       3    Mr    male    1     0        S           2\n",
  "1         1       1   Mrs  female    1     1        C           2\n",
  "2         1       3  Miss  female    1     0        S           1\n",
  "3         1       1   Mrs  female    1     1        S           2\n",
  "4         0       3    Mr    male    1     0        S           1"
 ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.2 Encoding Features as Numbers" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset cannot be used for training yet, because some columns are stored as strings and cannot take part in numerical computation. We therefore convert those strings to numbers." ] },
 { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
  "<table border=\"1\" class=\"dataframe\">\n",
  "  <thead>\n",
  "    <tr style=\"text-align: right;\"><th></th><th>Survived</th><th>Pclass</th><th>Title</th><th>Sex</th><th>Age</th><th>Fare</th><th>Embarked</th><th>FamilySize</th></tr>\n",
  "  </thead>\n",
  "  <tbody>\n",
  "    <tr><th>0</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>1</td><td>0</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>1</th><td>1</td><td>1</td><td>Mrs</td><td>female</td><td>1</td><td>1</td><td>C</td><td>2</td></tr>\n",
  "    <tr><th>2</th><td>1</td><td>3</td><td>Miss</td><td>female</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "    <tr><th>3</th><td>1</td><td>1</td><td>Mrs</td><td>female</td><td>1</td><td>1</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>4</th><td>0</td><td>3</td><td>Mr</td><td>male</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "  </tbody>\n",
  "</table>\n",
  "</div>"
 ], "text/plain": [
  "   Survived  Pclass Title     Sex  Age  Fare Embarked  FamilySize\n",
  "0         0       3    Mr    male    1     0        S           2\n",
  "1         1       1   Mrs  female    1     1        C           2\n",
  "2         1       3  Miss  female    1     0        S           1\n",
  "3         1       1   Mrs  female    1     1        S           2\n",
  "4         0       3    Mr    male    1     0        S           1"
 ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "In the title analysis we grouped the titles into 'Officer', 'Royalty', 'Mrs', 'Miss', 'Mr', and 'Master'. We encode these classes as the integers 1 to 6, in that order." ] },
 { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
  "title_map2num = {'Officer': 1, 'Royalty': 2, 'Mrs': 3, 'Miss': 4, 'Mr': 5, 'Master': 6}\n",
  "train_df['Title'] = train_df['Title'].map(title_map2num)" ] },
 { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
  "<table border=\"1\" class=\"dataframe\">\n",
  "  <thead>\n",
  "    <tr style=\"text-align: right;\"><th></th><th>Survived</th><th>Pclass</th><th>Title</th><th>Sex</th><th>Age</th><th>Fare</th><th>Embarked</th><th>FamilySize</th></tr>\n",
  "  </thead>\n",
  "  <tbody>\n",
  "    <tr><th>0</th><td>0</td><td>3</td><td>5</td><td>male</td><td>1</td><td>0</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>1</th><td>1</td><td>1</td><td>3</td><td>female</td><td>1</td><td>1</td><td>C</td><td>2</td></tr>\n",
  "    <tr><th>2</th><td>1</td><td>3</td><td>4</td><td>female</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "    <tr><th>3</th><td>1</td><td>1</td><td>3</td><td>female</td><td>1</td><td>1</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>4</th><td>0</td><td>3</td><td>5</td><td>male</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "  </tbody>\n",
  "</table>\n",
  "</div>"
 ], "text/plain": [
  "   Survived  Pclass  Title     Sex  Age  Fare Embarked  FamilySize\n",
  "0         0       3      5    male    1     0        S           2\n",
  "1         1       1      3  female    1     1        C           2\n",
  "2         1       3      4  female    1     0        S           1\n",
  "3         1       1      3  female    1     1        S           2\n",
  "4         0       3      5    male    1     0        S           1"
 ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "For the Sex column, female is mapped to 0 and male to 1." ] },
 { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
  "sex_map2num = {'female': 0, 'male': 1}\n",
  "train_df['Sex'] = train_df['Sex'].map(sex_map2num)" ] },
 { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
  "<table border=\"1\" class=\"dataframe\">\n",
  "  <thead>\n",
  "    <tr style=\"text-align: right;\"><th></th><th>Survived</th><th>Pclass</th><th>Title</th><th>Sex</th><th>Age</th><th>Fare</th><th>Embarked</th><th>FamilySize</th></tr>\n",
  "  </thead>\n",
  "  <tbody>\n",
  "    <tr><th>0</th><td>0</td><td>3</td><td>5</td><td>1</td><td>1</td><td>0</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>1</th><td>1</td><td>1</td><td>3</td><td>0</td><td>1</td><td>1</td><td>C</td><td>2</td></tr>\n",
  "    <tr><th>2</th><td>1</td><td>3</td><td>4</td><td>0</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "    <tr><th>3</th><td>1</td><td>1</td><td>3</td><td>0</td><td>1</td><td>1</td><td>S</td><td>2</td></tr>\n",
  "    <tr><th>4</th><td>0</td><td>3</td><td>5</td><td>1</td><td>1</td><td>0</td><td>S</td><td>1</td></tr>\n",
  "  </tbody>\n",
  "</table>\n",
  "</div>"
 ], "text/plain": [
  "   Survived  Pclass  Title  Sex  Age  Fare Embarked  FamilySize\n",
  "0         0       3      5    1    1     0        S           2\n",
  "1         1       1      3    0    1     1        C           2\n",
  "2         1       3      4    0    1     0        S           1\n",
  "3         1       1      3    0    1     1        S           2\n",
  "4         0       3      5    1    1     0        S           1"
 ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The three ports of embarkation are mapped to 0, 1, and 2 (S = 0, C = 1, Q = 2)." ] },
 { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [
  "embarked_map2num = {'S': 0, 'C': 1, 'Q': 2}\n",
  "train_df['Embarked'] = train_df['Embarked'].map(embarked_map2num)" ] },
 { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SurvivedPclassTitleSexAgeFareEmbarkedFamilySize
003511002
111301112
213401001
311301102
403511001
\n", "
" ], "text/plain": [ " Survived Pclass Title Sex Age Fare Embarked FamilySize\n", "0 0 3 5 1 1 0 0 2\n", "1 1 1 3 0 1 1 1 2\n", "2 1 3 4 0 1 0 0 1\n", "3 1 1 3 0 1 1 0 2\n", "4 0 3 5 1 1 0 0 1" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.3 使用Pytorch搭建逻辑回归模型" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "import logging\n", "import pickle\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import TensorDataset\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data import random_split\n", "\n", "class Titanic_Model(nn.Module):\n", " def __init__(self, input_dim, num_classes):\n", " super().__init__()\n", " self.input_dim = input_dim\n", " self.num_classes = num_classes\n", "\n", " self.linear_layer = nn.Linear(input_dim, num_classes)\n", "\n", " def forward(self, inputs):\n", " outputs = self.linear_layer(inputs)\n", " return outputs\n", "\n", " @staticmethod\n", " def compute_accuracy(outputs, labels):\n", " _, preds = torch.max(outputs, dim=1)\n", " return torch.tensor(torch.sum(preds == labels).item() / len(preds))\n", "\n", " @staticmethod\n", " def log_epoch_loss_and_acc(prefix, epoch, epoch_loss, epoch_acc, interval=5):\n", " if epoch % interval == 0:\n", " logging.info(f'{prefix}_Epoch [{epoch}], loss: {epoch_loss:.4f},'\n", " f' acc: {epoch_acc:.4f}.')\n", "\n", " def evaluate(self, batch, loss_func, need_acc=False, no_grad=False):\n", " if no_grad:\n", " with torch.no_grad():\n", " inputs, labels = batch\n", " outputs = self(inputs)\n", " loss = loss_func(outputs, labels)\n", " else:\n", " inputs, labels = batch\n", " outputs = self(inputs)\n", " loss = loss_func(outputs, labels)\n", "\n", " if need_acc:\n", " acc = self.compute_accuracy(outputs, labels)\n", " return {'loss': loss, 'acc': acc}\n", " 
else:\n", " return {'loss': loss}\n", "\n", " def compute_epoch_loss_and_acc(self, dataloader, loss_func):\n", " results = [self.evaluate(batch, loss_func, need_acc=True, no_grad=True)\n", " for batch in dataloader]\n", " batch_losses = [r['loss'] for r in results]\n", " epoch_loss = torch.stack(batch_losses).mean()\n", " batch_accs = [r['acc'] for r in results]\n", " epoch_acc = torch.stack(batch_accs).mean()\n", " return {'epoch_loss': epoch_loss, 'epoch_acc': epoch_acc}\n", "\n", " def epoch_postprocess(self, prefix, data_loader, epoch,\n", " history, loss_func, log_interval):\n", " loss_and_acc = self.compute_epoch_loss_and_acc(data_loader, loss_func)\n", " epoch_loss = loss_and_acc['epoch_loss']\n", " epoch_acc = loss_and_acc['epoch_acc']\n", " history.append({'epoch_loss': epoch_loss,\n", " 'epoch_acc': epoch_acc})\n", " self.log_epoch_loss_and_acc(prefix, epoch,\n", " epoch_loss,\n", " epoch_acc,\n", " log_interval)\n", "\n", " def train(self, train_loader, val_loader, num_epochs, lr,\n", " loss_func=F.cross_entropy, opt_func=torch.optim.SGD,\n", " log_interval=5):\n", " optimizer = opt_func(self.parameters(), lr)\n", " self.history_train = [] # history of train set\n", " self.history_val = [] # history of validation set\n", "\n", " # initial loss and accuracy of training dataset\n", " self.epoch_postprocess('Train', train_loader, 0,\n", " self.history_train, loss_func, log_interval)\n", "\n", " # initial loss and accuracy of validation dataset\n", " self.epoch_postprocess('Val', val_loader, 0,\n", " self.history_val, loss_func, log_interval)\n", "\n", " # iteration\n", " for epoch in range(num_epochs):\n", " for batch in train_loader:\n", " loss = self.evaluate(batch, loss_func, need_acc=False)['loss']\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " # training dataset loss and accuracy\n", " self.epoch_postprocess('Train', train_loader, epoch+1,\n", " self.history_train, loss_func, log_interval)\n", "\n", " # validation 
dataset loss and accuracy\n", " self.epoch_postprocess('Val', val_loader, epoch+1,\n", " self.history_val, loss_func, log_interval)\n", "\n", " def predict(self, inputs):\n", " outputs = self(inputs)\n", " _, preds = torch.max(outputs, dim=1)\n", " return [preds[i].item() for i in range(len(preds))]\n", "\n", " def save_model(self, save_file):\n", " torch.save(self.state_dict(), save_file)\n", " pickle.dump(self.history_train, open('titanic_history_train.pkl', 'wb'))\n", " pickle.dump(self.history_val, open('titanic_history_val.pkl', 'wb'))\n", "\n", "\n", " def recover_model(self, save_file):\n", " self.load_state_dict(torch.load(save_file))\n", " self.history_train = pickle.load(open('titanic_history_train.pkl', 'rb'))\n", " self.history_val = pickle.load(open('titanic_history_val.pkl', 'rb'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在训练前,还需要将测试集划分为训练集和验证集,当前采用5:1的形式进行划分。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# convert pandas dataframe to numpy array\n", "train_data = train_df.to_numpy()\n", "# convert numpy array to tensor\n", "inputs = torch.from_numpy(train_data[:, 1:]).type(torch.float)\n", "labels = torch.from_numpy(train_data[:, 0]).type(torch.long)\n", "dataset = TensorDataset(inputs, labels)\n", "train_ds, val_ds = random_split(dataset, [742, 149])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用gpu加速计算,Pytorch中使用gpu计算十分简单,只需要将训练数据和模型参数转移到显存中即可(前提是配置好cuda驱动)。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "def to_device(data, device):\n", " \"\"\"Move tensor(s) to chosen device\"\"\"\n", " if isinstance(data, (list,tuple)):\n", " return [to_device(x, device) for x in data]\n", " return data.to(device, non_blocking=True)\n", "\n", "class DeviceDataLoader():\n", " \"\"\"Wrap a dataloader to move data to a device (default: cpu)\"\"\"\n", " def __init__(self, dl, device):\n", " self.dl = dl\n", " self.device = 
device\n", "\n", " def __iter__(self):\n", " \"\"\"Yield a batch of data after moving it to device\"\"\"\n", " for b in self.dl:\n", " yield to_device(b, self.device)\n", "\n", " def __len__(self):\n", " \"\"\"Number of batches\"\"\"\n", " return len(self.dl)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "03/06/2022 10:43:48 PM INFO:Initializing linear regression model.\n", "03/06/2022 10:43:50 PM INFO:Start training...\n", "03/06/2022 10:43:51 PM INFO:Train_Epoch [0], loss: 0.7585, acc: 0.6144.\n", "03/06/2022 10:43:51 PM INFO:Val_Epoch [0], loss: 0.7717, acc: 0.5962.\n", "03/06/2022 10:43:51 PM INFO:Train_Epoch [10], loss: 0.5908, acc: 0.6755.\n", "03/06/2022 10:43:51 PM INFO:Val_Epoch [10], loss: 0.6045, acc: 0.6525.\n", "03/06/2022 10:43:52 PM INFO:Train_Epoch [20], loss: 0.5441, acc: 0.7176.\n", "03/06/2022 10:43:52 PM INFO:Val_Epoch [20], loss: 0.5625, acc: 0.7412.\n", "03/06/2022 10:43:52 PM INFO:Train_Epoch [30], loss: 0.5153, acc: 0.7699.\n", "03/06/2022 10:43:52 PM INFO:Val_Epoch [30], loss: 0.5365, acc: 0.7412.\n", "03/06/2022 10:43:53 PM INFO:Train_Epoch [40], loss: 0.5048, acc: 0.7699.\n", "03/06/2022 10:43:53 PM INFO:Val_Epoch [40], loss: 0.5287, acc: 0.7475.\n", "03/06/2022 10:43:53 PM INFO:Train_Epoch [50], loss: 0.4907, acc: 0.7996.\n", "03/06/2022 10:43:53 PM INFO:Val_Epoch [50], loss: 0.5201, acc: 0.7600.\n", "03/06/2022 10:43:54 PM INFO:Train_Epoch [60], loss: 0.4833, acc: 0.8001.\n", "03/06/2022 10:43:54 PM INFO:Val_Epoch [60], loss: 0.5168, acc: 0.7862.\n", "03/06/2022 10:43:54 PM INFO:Train_Epoch [70], loss: 0.4777, acc: 0.8036.\n", "03/06/2022 10:43:54 PM INFO:Val_Epoch [70], loss: 0.5135, acc: 0.7738.\n", "03/06/2022 10:43:55 PM INFO:Train_Epoch [80], loss: 0.4720, acc: 0.7983.\n", "03/06/2022 10:43:55 PM INFO:Val_Epoch [80], loss: 0.5098, acc: 0.7800.\n", "03/06/2022 10:43:55 PM INFO:Train_Epoch [90], loss: 0.4705, acc: 0.7948.\n", "03/06/2022 
10:43:55 PM INFO:Val_Epoch [90], loss: 0.5083, acc: 0.7862.\n", "03/06/2022 10:43:56 PM INFO:Train_Epoch [100], loss: 0.4694, acc: 0.8001.\n", "03/06/2022 10:43:56 PM INFO:Val_Epoch [100], loss: 0.5098, acc: 0.7800.\n", "03/06/2022 10:43:56 PM INFO:Training finished.\n", "03/06/2022 10:43:56 PM INFO:Save model.\n" ] } ], "source": [ "logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s',\n", "                    level=logging.INFO, datefmt='%m/%d/%Y %I:%M:%S %p')\n", "x_dim = 7 # input dimension\n", "y_dim = 2 # label dimension\n", "train_sz = 742\n", "val_sz = 149\n", "batch_size = 16\n", "num_epochs = 100\n", "learning_rate = 0.005\n", "# fall back to the CPU when CUDA is unavailable\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "train_loader = DataLoader(train_ds, batch_size, shuffle=True)\n", "val_loader = DataLoader(val_ds, batch_size)\n", "# wrap the dataloaders so each batch is moved to the device\n", "train_loader = DeviceDataLoader(train_loader, device)\n", "val_loader = DeviceDataLoader(val_loader, device)\n", "\n", "# initialize logistic regression model\n", "logging.info(\"Initializing linear regression model.\")\n", "titanic_model = Titanic_Model(x_dim, y_dim)\n", "# move model parameters to the device\n", "to_device(titanic_model, device)\n", "logging.info(\"Start training...\")\n", "titanic_model.train(train_loader, val_loader, num_epochs,\n", "                    learning_rate, log_interval=10, opt_func=torch.optim.SGD\n", ")\n", "logging.info(\"Training finished.\")\n", "\n", "logging.info(\"Save model.\")\n", "titanic_model.save_model('report02-titanic_model.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The log shows that after 100 epochs the trained model reaches an accuracy of 80.01% on the training set and 78.00% on the validation set. Next, we plot how the loss and the accuracy change over the course of training:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "with open('titanic_history_train.pkl', 'rb') as f:\n", "    history_train = pickle.load(f)\n", "with open('titanic_history_val.pkl', 'rb') as f:\n", "    history_val = pickle.load(f)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Loss
vs. No. of epochs')" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train_losses = [float(x['epoch_loss']) for x in history_train]\n", "val_losses = [float(x['epoch_loss']) for x in history_val]\n", "plt.plot(train_losses, '-x', val_losses, '-x')\n", "plt.xlabel('epoch')\n", "plt.ylabel('loss')\n", "plt.title('Loss vs. No. of epochs')" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train_accs = [float(x['epoch_acc']) for x in history_train]\n", "val_accs = [float(x['epoch_acc']) for x in history_val]\n", "plt.plot(train_accs, '-x', val_accs, '-x')\n", "plt.xlabel('epoch')\n", "plt.ylabel('accuracy')\n", "plt.title('Accuracy vs. No. of epochs')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到,模型在学习速率为0.005的情况下,训练到80步的时候,准确率在79%左右不再上升。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 小结" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "    在这份报告中,我们使用Pytorch搭建逻辑回归模型进行了泰坦尼克号生存率预测。首先,我们分析并提炼了数据集的特征数据,在这个过程中,熟悉了pandas中的dataframe数据结构的基本操作,发现其在批量处理数据时十分方便,并将一些连续值的变量变成分段变量以提高数据的凝练度。此外,该任务中,由于样本中存在大量缺失值,所以还对如何处理缺失值进行了学习。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "main_language": "python" }, "nbformat": 4, "nbformat_minor": 2 }