Pytorch:神经网络过程代码详解

文章目录

  • 一、基本概念
    • 1、epoch
    • 2、遍历DataLoader
  • 二、神经网络训练过程代码详解
      • 步骤一:选择并初始化优化器
      • 步骤二:计算损失
      • 步骤三:反向传播
      • 步骤四:更新模型参数
      • 步骤五:清空梯度
      • 组合到训练循环中
      • 步骤六:保存模型
  • 三、神经网络评估过程代码详解
      • 步骤一:加载模型
      • 步骤二:切换至评估模式
      • 步骤三:进行评估(这里计算分类问题的准确率)
  • 四、经典数据集——鸢尾花数据集


一、基本概念

for epoch in range(total_epoch):
	for label_x,label_y in dataloader:
		pass

1、epoch

  epoch 指的是整个数据集在训练过程中被完整地遍历一次。如果数据集被分成多个批次输入模型,则一个 epoch 完成后意味着所有的批次已被模型处理一次。epoch 的数目通常根据训练数据的大小、模型复杂度和任务需求来决定。每个 epoch 结束后,模型学到的知识会更加深入,但也存在过度学习(过拟合)的风险,特别是当 epoch 数目过多时。
  即每一个epoch会处理所有的batchepoch也被称为训练周期

2、遍历DataLoader

  遍历DataLoader,实际上就是每次取出一个batch的数据。

二、神经网络训练过程代码详解

建议先理解:Module模块

步骤一:选择并初始化优化器

首先,根据模型的需求选择一个合适的优化器。不同的优化器可能适合不同类型的数据和网络架构。一旦选择了优化器,需要将模型的参数传递给它,并设置一些特定的参数,如学习率、权重衰减等。

import torch.optim as optim

# 假设 model 是你的网络模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

在这个例子中,选择了随机梯度下降(SGD)作为优化器,并设置了学习率和动量。

步骤二:计算损失

在训练循环中,每次迭代都会处理一批数据,模型会根据这些数据进行预测,并计算损失。

criterion = torch.nn.CrossEntropyLoss()  # 选择合适的损失函数
outputs = model(inputs)                 # 前向传播
loss = criterion(outputs, labels)       # 计算损失

步骤三:反向传播

一旦有了损失,就可以使用 .backward() 方法来自动计算模型中所有可训练参数的梯度。

loss.backward()

这一步将计算损失函数相对于每个参数的梯度,并将它们存储在各个参数的 .grad 属性中。

步骤四:更新模型参数

使用优化器的 .step() 方法来根据计算得到的梯度更新参数。

optimizer.step()

这个调用会更新模型的参数,具体的更新方式取决于你选择的优化算法。

步骤五:清空梯度

在每次迭代后,需要手动清空梯度,以便下一次迭代。如果不清空梯度,梯度会累积,导致不正确的参数更新。

optimizer.zero_grad()

组合到训练循环中

将上述步骤组合到一个训练循环中,我们得到了完整的训练过程:

model = MyModel() #实例化神经网络层,调用继承自Module类的MyModel类的构造函数
criterion = torch.nn.CrossEntropyLoss()  # 选择合适的损失函数,这里是交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 定义优化器,传入模型参数
model.train()#切换至训练模式
for epoch in range(total_epochs):
    for inputs, labels in dataloader:  # 从数据加载器获取数据
        inputs, labels = inputs.to(device), labels.to(device)
        
        # 前向传播
        outputs = model(inputs)
        # 计算损失
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()  # 清空之前的梯度
        loss.backward()        # 反向传播
        # 优化参数
        optimizer.step()       # 更新参数
        
        print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item()}')
  • loss.item()

    • 在 PyTorch 中,loss 是一个 torch.Tensor 对象。当计算模型的损失时,这个对象通常只包含一个元素(一个标量值),它代表了当前批次数据的损失值。loss.item() 方法是从包含单个元素的张量中提取出那个标量值作为 Python 数值。这是很有用的,因为它允许你将损失值脱离张量的形式进行进一步的处理或输出,比如打印、记录或做条件判断。
  • print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item()}')

    • 这行代码是用来在训练过程中输出当前 epoch 的编号和该 epoch 的损失值。这对于监控训练进程和调试模型非常有帮助。具体来说:
      • epoch+1:由于计数通常从 0 开始,所以 +1 是为了更自然地显示(从 1 开始而不是从 0 开始)。
      • {total_epochs}:这是训练过程中总的 epoch 数。
      • {loss.item()}:如前所述,这表示当前批次的损失值,作为一个标量数值输出。

步骤六:保存模型

torch.save(model.state_dict(),"model.pth")

三、神经网络评估过程代码详解

步骤一:加载模型

实例化对应模型,使用该模型对象的.load_state_dict()方法导入之前存储的模型参数。

model=MyModel()
model.to(device)
model.load_state_dict(torch.load("model.pth"))

步骤二:切换至评估模式

model.eval()

步骤三:进行评估(这里计算分类问题的准确率)

  • 定义评估时需要的变量
  • 使用torch.no_grad()指定上下文Pytorch不追踪梯度信息
total_correct=0 #正确的样本数
total_samples=0 #样本数
with torch.no_grad():# 该局部化区域内的张量不再计算梯度
	for batch_x,batch_y in dataloader:
		batch_x.to(device)
		batch_y.to(device)
		batch_x = batch_x.to(torch.float) #转换成浮点

		output = model(batch_x) # 前向传播,得到分类结果 形状为[batch_size,num_classes]
		_,predicted = torch.max(output,dim=1)  # 不考虑dim=0的batch_size,从第一维开始考虑,沿着dim=1的方向寻找最大值,实际上就是找分类得分最高的分类,predicted接收的是max的索引,因此predicted的形状是[batch_size,1] 就是每一个样本的分类
		total_correct += (predicted == batch_y).sum().item()
		total_samples += predicted.size(dim=0)
accuracy=total_correct / total_samples
print(f"accuracy:{accuracy}")

_,predicted=tensor.max(output,dim=1)

  • output形状是[batch_size,classes_num],沿着列求最大值, 得到列中的最大值索引,相当于得到的predicted的形状是[batch_size,1],这里的1 的数值 是一行中 最大值的索引 也就是预测的类别。然后batch_y 的形状是[batch_size,1] predicted进行对比,就是对比类别是否相同。所以,我们在考虑问题的时候,由于batch_size的存在,第0维忽略掉考虑也行,然后就好理解了

四、经典数据集——鸢尾花数据集

代码来源

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader


# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):
    x = torch.tensor(load_iris().data)
    y = torch.tensor(load_iris().target)

    # 数据归一化
    x_min = torch.min(x, dim=0).values
    x_max = torch.max(x, dim=0).values
    x = (x - x_min) / (x_max - x_min)

    if shuffle:
        idx = torch.randperm(x.shape[0])
        x = x[idx]
        y = y[idx]
    return x, y


# 自定义鸢尾花数据类
class IrisDataset(Dataset):
    def __init__(self, mode='train', num_train=120, num_dev=15):
        super(IrisDataset, self).__init__()
        x, y = load_data(shuffle=True)  # 将x转换为浮点型数据
        y = y.long()  # 将y转换为长整型数据
        # x, y = load_data(shuffle=True)
        if mode == 'train':
            self.x, self.y = x[:num_train], y[:num_train]
        elif mode == 'dev':
            self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
        else:
            self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)


# 创建一个模型类来定义神经网络模型
class IrisModel(nn.Module):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)


# 加载数据
batch_size = 16

train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# 实例化神经网络模型
model = IrisModel()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        batch_x = batch_x.to(torch.float)  # 使用float32数据类型
        # RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_loss}")

# 保存模型
torch.save(model.state_dict(), 'iris_model.pth')

#%%
# 加载模型
model = IrisModel()  # 先实例化一个模型
model.to(device)
model.load_state_dict(torch.load('iris_model.pth'))

# 评估模型
model.eval()
total_correct = 0
total_samples = 0

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        batch_x = batch_x.to(torch.float)  # 使用float32数据类型
        output = model(batch_x)
        _, predicted = torch.max(output, dim=1)
        total_correct += (predicted == batch_y).sum().item()
        total_samples += batch_y.size(0)

accuracy = total_correct / total_samples
print(f"Test Accuracy: {accuracy}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/575193.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mysql 、Redis 数据双写一致性 更新策略与应用

零、important point 1. 缓存双写一致性问题 2. java实现逻辑&#xff08;对于 QPS < 1000 可以使用&#xff09; public class UserService {public static final String CACHE_KEY_USER "user:";Resourceprivate UserMapper userMapper;Resourceprivate Re…

javascript使用setTimeout函数来实现仅执行最后一次操作

在JavaScript中&#xff0c;setTimeout函数用于在指定的毫秒数后执行一个函数或计算表达式。它的主要用途是允许开发者延迟执行某些代码&#xff0c;而不是立即执行。 当我们想要确保仅最后一次更新UI时&#xff0c;我们可以使用setTimeout来合并多次连续的更新请求。具体做法…

C++11 数据结构7 队列的链式存储,实现,测试

前期考虑 队列是两边都有开口&#xff0c;那么在链式情况下&#xff0c;线性表的链式那一边作为对头好呢&#xff1f; 从线性表的核心的插入和删除算法来看&#xff0c;如果在线性表链表的头部插入&#xff0c;每次循环都不会走&#xff0c;但是删除的时候&#xff0c;要删除线…

回归与聚类——K-Means(六)

什么是无监督学习 一家广告平台需要根据相似的人口学特征和购买习惯将美国人口分成不同的小 组&#xff0c;以便广告客户可以通过有关联的广告接触到他们的目标客户。Airbnb 需要将自己的房屋清单分组成不同的社区&#xff0c;以便用户能更轻松地查阅这些清单。一个数据科学团队…

Python爱心代码

爱心效果图&#xff1a; 完整代码&#xff1a; import random from math import sin, cos, pi, log from tkinter import *# 定义画布尺寸和颜色 CANVAS_WIDTH 640 CANVAS_HEIGHT 480 CANVAS_CENTER_X CANVAS_WIDTH / 2 CANVAS_CENTER_Y CANVAS_HEIGHT / 2 IMAGE_ENLARG…

C#实现TFTP客户端

1、文件结构 2、TftpConfig.cs using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace TftpTest {public class TftpConfig{}/// <summary>/// 模式/// </summary>public enum Modes{…

大模型都在用的:旋转位置编码

写在前面 这篇文章提到了绝对位置编码和相对位置编码&#xff0c;但是他们都有局限性&#xff0c;比如绝对位置编码不能直接表征token的相对位置关系&#xff1b;相对位置编码过于复杂&#xff0c;影响效率。于是诞生了一种用绝对位置编码的方式实现相对位置编码的编码方式——…

LS2K1000LA基础教程

基于LS2K1000LA的基础教程 by 南京工业大学 孙冬梅 于 2024.4.25 文章目录 基于LS2K1000LA的基础教程一、目的二、平台1.硬件平台2.软件平台 三、测试0.开发板开机及编译器配置0.1 开发板控制台0.2 虚拟机编译器配置 1. 简单应用编程1.helloworld.c2. fileio 文件操作3.proce…

Scrapy 爬虫教程:从原理到实战

Scrapy 爬虫教程&#xff1a;从原理到实战 一、Scrapy框架简介 Scrapy是一个由Python开发的高效网络爬虫框架&#xff0c;用于从网站上抓取数据并提取结构化信息。它采用异步IO处理请求&#xff0c;能够同时发送多个请求&#xff0c;极大地提高了爬虫效率。 二、Scrapy运行原…

入坑 Java

原文&#xff1a;https://blog.iyatt.com/?p11305 前言 今天&#xff08;2023.8.31&#xff09;有个学长问我接不接一个单子&#xff0c;奈何没学过 Java&#xff0c;本来不打算接的。只是报酬感觉还不错&#xff0c;就接了。 要求的完成时间是在10月初&#xff0c;总共有一…

Spring Boost + Elasticsearch 实现检索查询

需求&#xff1a;对“昵称”进行“全文检索查询”&#xff0c;对“账号”进行“精确查询”。 认识 Elasticsearch 1. ES 的倒排索引 正向索引 对 id 进行检索速度很快。对其他字段即使加了索引&#xff0c;只能满足精确查询。模糊查询时&#xff0c;逐条数据扫描&#xff0c…

编译原理实验课

本人没咋学编译原理&#xff0c;能力有限&#xff0c;写的不好轻点喷&#xff0c;大佬路过的话&#xff0c;那你就路过就好 东大编译原理实验课原题&#xff0c;22年 1. 基本题&#xff1a;简单的扫描器设计 【问题描述】 熟悉并实现一个简单的扫描器&#xff0c;设计扫描器…

C++ | Leetcode C++题解之第49题字母异位词分组

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<vector<string>> groupAnagrams(vector<string>& strs) {// 自定义对 array<int, 26> 类型的哈希函数auto arrayHash [fn hash<int>{}] (const array<int, 26>&…

黑马点评(十二) -- UV统计

一 . UV统计-HyperLogLog 首先我们搞懂两个概念&#xff1a; UV&#xff1a;全称Unique Visitor&#xff0c;也叫独立访客量&#xff0c;是指通过互联网访问、浏览这个网页的自然人。1天内同一个用户多次访问该网站&#xff0c;只记录1次。 PV&#xff1a;全称Page View&…

linux权限维持(四)

6.inetd服务后门 inetd 是一个监听外部网络请求 ( 就是一个 socket) 的系统守护进程&#xff0c;默认情况下为 13 端口。当 inetd 接收到 一个外部请求后&#xff0c;它会根据这个请求到自己的配置文件中去找到实际处理它的程序&#xff0c;然后再把接收到的 这个socket 交给那…

机器学习 -- 分类问题

场景 探讨了一个回归任务——预测住房价格&#xff0c;用到了线性回归、决策树以及随机森林等各种算法。本次中我们将把注意力转向分类系统。我们曾经对MNIST进行了分类任务&#xff0c;这次我们重新回到这里&#xff0c;细致的再来一次。 开始 获取数据 Scikit-Learn提供了…

力扣爆刷第127天之动态规划五连刷(整数拆分、一和零、背包)

力扣爆刷第127天之动态规划五连刷&#xff08;整数拆分、一和零、背包&#xff09; 文章目录 力扣爆刷第127天之动态规划五连刷&#xff08;整数拆分、一和零、背包&#xff09;关于0 1 背包问题的总结01背包遍历顺序&#xff1a;完全背包遍历顺序&#xff1a; 一、343. 整数拆…

Lock-It for Mac(应用程序加密工具)

OSXBytes Lock-It for Mac是一款功能强大的应用程序加密工具&#xff0c;专为Mac用户设计。该软件具有多种功能&#xff0c;旨在保护用户的隐私和数据安全。 Lock-It for Mac v1.3.0激活版下载 首先&#xff0c;Lock-It for Mac能够完全隐藏应用程序&#xff0c;使其不易被他人…

【Pytorch】(十四)C++ 加载TorchScript 模型

文章目录 &#xff08;十四&#xff09;C 加载TorchScript 模型Step 1: 将PyTorch模型转换为TorchScriptStep 2: 将TorchScript序列化为文件Step 3: C程序中加载TorchScript模型Step 4: C程序中运行TorchScript模型 【Pytorch】&#xff08;十三&#xff09;PyTorch模型部署: T…

平衡二叉树、红黑树、B树、B+树

Tree 1、前言2、平衡二叉树和红黑树3、B树和B树3.1、B树的构建3.2、B树和B树的区别3.3、数据的存储方式 1、前言 本文侧重在理论方面对平衡二叉树、红黑树、B树和B树的各方面性能进行比较。不涉及编程方面的实现。而关于于平衡二叉树在C中的实现&#xff0c;我的上一篇文章平衡…
最新文章