- 1、原创力文档(book118)网站文档一经付费(服务费),不意味着购买了该文档的版权,仅供个人/单位学习、研究之用,不得用于商业用途,未经授权,严禁复制、发行、汇编、翻译或者网络传播等,侵权必究。。
- 2、本站所有内容均由合作方或网友上传,本站不对文档的完整性、权威性及其观点立场正确性做任何保证或承诺!文档内容仅供研究参考,付费前请自行鉴别。如您付费,意味着您自己接受本站规则且自行承担风险,本站不退款、不进行额外附加服务;查看《如何避免下载的几个坑》。如果您已付费下载过本站文档,您可以点击 这里二次下载。
- 3、如文档侵犯商业秘密、侵犯著作权、侵犯人身权等,请点击“版权申诉”(推荐),也可以打举报电话:400-050-0827(电话支持时间:9:00-18:30)。
- 4、该文档为VIP文档,如果想要下载,成为VIP会员后,下载免费。
- 5、成为VIP后,下载本文档将扣除1次下载权益。下载后,不支持退款、换文档。如有疑问请联系我们。
- 6、成为VIP后,您将拥有八大权益,权益包括:VIP文档下载权益、阅读免打扰、文档格式转换、高级专利检索、专属身份标志、高级客服、多端互通、版权登记。
- 7、VIP文档为合作方或网友上传,每下载1次, 网站将根据用户上传文档的质量评分、类型等,对文档贡献者给予高额补贴、流量扶持。如果你也想贡献VIP文档。上传文档
查看更多
python卷积神经⽹络多元分类_卷积神经⽹络分类MNIST
# 卷积神经⽹络分类MNIST
之前两期简单介绍了神经⽹络的基础知识(由于我懒,⼀张插图都没有上)。
这⼀期我们来介绍⼀个简单的实现。基于Pytorch,训练⼀个简单的卷积神经⽹络⽤于分类MNIST数据集。同样ipynb⽂件到我的群⾥下
载。后边写得⽐较多了以后考虑整理⼀下放到Github上。
数据集的导⼊之前已经介绍过,所以就不重复了。
## 神经⽹络的搭建
Pytorch中的神经⽹络搭建都要构造成类。类中确定了神经⽹络的结构。训练的时候就要构造⼀个具体的实例。训练好实例后,还要⽤这个
实例去预测。
```python
# 定义⽹络
import torch.nn as nn
import torch.nn.functional as F
# 卷积之后,全连接层输⼊向量的长度。
# 由卷积结果的通道数乘以卷积结果的长再乘以卷积结果的宽得到。
# 这⾥没有⽤sigmoid函数,⽽是使⽤了softmax层
FC_SIZE = 16*2*2
class ClassificationMNIST(nn.Module):
def __init__(self):
super(ClassificationMNIST,self).__init__()
self.conv=nn.Sequential(nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1,padding=2),nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(4, 8, 3, 1),nn.ReLU(), nn.MaxPool2d(kernel_size=2),
nn.Conv2d(8, 16, 3, 1),nn.ReLU(), nn.MaxPool2d(kernel_size=2))
#self.fc = nn.Linear(FC_SIZE,10)
self.fc = nn.Sequential(nn.Linear(FC_SIZE, 10), nn.Softmax())
def forward(self, x):
x = self.conv(x)
x = x.view(-1,FC_SIZE)
#output = torch.sigmoid(x)
output = self.fc(x)
return output
```
基本的神经⽹络需要写两个⽅法。在`__init__`中列出所有要⽤到的卷积层。注意这⾥的卷积层,全连接层可以⾃⼰写。但是更推荐直接调
⽤`torch.nn`中的。
这⾥的`nn.Conv2d`等函数,的返回值实际上也是⼀个可调⽤对象。⽽且你构造的神经⽹络类的实例本⾝也是可调⽤对象。
第⼆个必须写的⽅法是`forward`⽅法。这个⽅法表⽰当你的神经⽹络对象被调⽤的时候,要执⾏的⽅法。
## 神经⽹络的训练
在训练开始之前要设置⼏个参数。
```python
# 初始化⽹络,及相关函数
cnn=ClassificationMNIST().to(device)
EPOCH = 10
# 学习率是个超参数,这东西是实验出来的
LR = 0.001
# 定义损失函数和优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = torch.nn.CrossEntropyLoss(size_average=False)
```
其中的`loss_func`是损失函数。`optimizer`是优化器。损失函数就如前边所说是,判定我们的⽹络和⽬标函数近似程度的。`optimizer`
则确定梯度更新的策略。它们都有很多选择,这⾥我们就先直接⽤,以后再讨论有哪些选择,以及不同选择的优劣。
接下来的训练过程其实就是⼀个⼤循环。由于⼀般图像数据量巨⼤,我们不能⼀次⽤上所有的数据。我们就把整个图像数据集分为若⼲
batch。通过内层循环遍历所有数据。循环体中则是每次正向求出⽹络的分类结果,然后求和正确结果之间的loss,然后逆向求导,更新参
数。
可以看到我的代码中,出了正向计算和反向传播,还多了⼀些内容。这些是⽤于隔⼀段时间,计算⼀下在测试集上的预测正确率,然后输
出的。防⽌我们在等神经⽹络出结果的过程中太过⽆聊。
```python
# 训练
for epoch in range(EPOCH):
for i,data in enumerate(train_loader):
# 获取数据
您可能关注的文档
- 人教版五年级数学2019-2020学年度第一学期期末调研试卷(无答案).pdf
- 比例的基本性质练习题6729.pdf
- Java职业生涯规划.pdf
- (完整word版)七年级上册英语词汇专项训练.pdf
- 2021年住院医师规范化培训师资考核方案(精华版).pdf
- java后端语言,后端开发语言哪一种比较好?后端开发语言比较.pdf
- 七年级上册生物知识与能力训练答案.pdf
- JupyterNotebook超好用的扩展之代码自动补全、自动目录等.pdf
- 8种常见的版式设计,让你的UI设计作品美炸天!.pdf
- c语言程序设计冯志红pdf,C语言程序设计:现代方法(第2版)中文pdf扫描版[219MB]....pdf
- 2020六年级道德与法治全册 第三单元 第七课 亲情之爱 第2框《爱在家人间》同步练习 新人教版.pdf
- 2020-2021学年第二学期期末教学质量检测 五年级英语试卷.pdf
- 【微知识】10种不同测量摩天大楼高度的方法.pdf
- 【Javaweb】javaweb必备学习路线图-专为初学者打造.pdf
- web前端实训总结.pdf
- js和java前后端传递Date类型数据的问题.pdf
- 2022年九年级道德与法治上册第三单元文明与家园第六课建设美丽中国第2框共筑生命家园教案新人教版.pdf
- python第二版答案第六章_Python语言程序设计基础(第2版)课后题第六章.pdf
- java调用webservice传字符串参数.pdf
- 【木马免杀教程-零基础免杀第一课】.pdf
最近下载
- 通风防排是烟工程合同.doc VIP
- 不同层级护士核心能力的培养.pptx VIP
- 微型计算机原理与接口技术第二版邹逢兴部分习题答案.doc VIP
- 幼儿园课件::认识少数民族.pptx VIP
- DELIXI德力西CJX2s说明书.pdf
- YV100XG机器FAMF校正培训教材.docx VIP
- 2025至2030中国硫酸钙晶须行业市场发展现状及竞争格局与投资发展报告.docx
- (高清版)DB62∕T 3237-2023 建筑钢结构防火技术标准.docx VIP
- (四检)厦门市2025届高三第四次质量检测 生物试卷(含答案).docx
- 2025年220KV输电线路施工组织措施及施工方案1.pdf VIP
原创力文档


文档评论(0)