pytorch入门

PyTorch 为深度学习提供了强大的数据加载和可视化工具。掌握 DatasetDataLoaderTransformsTensorBoard 是高效训练模型的关键。本文将带你快速入门这四大核心组件。

📚 核心组件速记

组件 一句话记忆
Dataset __getitem____len__ → 封装你的数据源,按索引取数据
DataLoader batch_size, shuffle → 从 Dataset 中批量、随机地取数据
Transforms Compose, ToTensor → 数据预处理与增强的“流水线”
TensorBoard SummaryWriter → 训练过程可视化,调参与Debug的神器

🎨 数据流转示意图

在 PyTorch 中,数据通常经历从磁盘文件到模型输入的转换过程。理解这个流程是掌握数据加载机制的第一步。



graph TD
    subgraph "硬盘"
        A["图片文件 (cat.jpg)"]
    end
    subgraph "CPU 内存"
        B["PIL Image / OpenCV (numpy array)"]
    end
    subgraph "GPU/CPU 内存"
        C["Torch Tensor<br/>[C, H, W]"]
        D["Batch Tensor<br/>[B, C, H, W]"]
        C -- "DataLoader 批量化" --> D
    end

    A -- "Image.open()" --> B
    B -- "transforms.ToTensor()" --> C
    D ==> E[模型输入]

1️⃣ torch.utils.data.Dataset:封装你的数据

Dataset 是一个抽象类,用于表示数据集。所有自定义的数据集都应该继承它,并重写两个核心方法:

  • __len__(self): 返回数据集的大小。
  • __getitem__(self, idx): 根据索引 idx 返回一条数据(通常是一个样本和对应的标签)。

这提供了一个统一的接口,让 DataLoader 等工具可以方便地处理各种不同来源的数据。

示例:创建一个自定义 Dataset

假设我们有一些图片路径和对应的标签,我们可以这样创建一个 Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from torch.utils.data import Dataset
from PIL import Image
import os

class MyImageDataset(Dataset):
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform

def __len__(self):
return len(self.image_paths)

def __getitem__(self, idx):
# 读取图片
image_path = self.image_paths[idx]
image = Image.open(image_path).convert("RGB")

# 获取标签
label = self.labels[idx]

# 应用数据变换
if self.transform:
image = self.transform(image)

return image, label

# 使用示例
image_files = ["./data/cat.png", "./data/dog.png"]
image_labels = [0, 1]
custom_dataset = MyImageDataset(image_files, image_labels)
print(f"数据总数: {len(custom_dataset)}")
first_item = custom_dataset[0]
print(f"第一个样本: {first_item}")

记忆:Dataset 就是个“数据地图”,__getitem__ 负责“按图索骥”。


2️⃣ torchvision.transforms:数据预处理与增强

在将数据喂给模型之前,通常需要进行预处理,例如转换为 Tensor、归一化、裁剪、旋转等。transforms 就是用来完成这些操作的工具箱。

  • 常见变换ToTensor(), Resize(), CenterCrop(), RandomHorizontalFlip(), Normalize() 等。
  • transforms.Compose(): 将多个变换操作串联成一个“流水线”。

示例:定义一个变换流水线

1
2
3
4
5
6
7
8
9
10
11
12
from torchvision import transforms

# 定义一个转换流水线
data_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整图片大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(), # 将 PIL Image 或 numpy.ndarray 转换为 tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
])

# 将 transform 应用到 Dataset
custom_dataset = MyImageDataset(image_files, image_labels, transform=data_transform)

为什么要随机翻转?

RandomHorizontalFlip 是一种数据增强 (Data Augmentation) 技术。它的目的是在不改变图片语义(猫翻转了还是猫)的前提下,增加训练数据的多样性。

  • 提高泛化能力:模型会学习到,无论是正着还是反着的猫,都应该被识别为猫,从而对真实世界中更多样的数据表现更好,避免过拟合
  • 扩充数据集:相当于免费让你的数据集大小翻倍。

提示:数据增强(如随机翻转)只在训练时使用,验证和测试时应去掉,以保证结果的一致性。


3️⃣ torch.utils.data.DataLoader:批量加载数据

DataLoader 是一个迭代器,它从 Dataset 中自动拉取数据,并将其打包成批次 (batch)。它还提供了许多有用的功能:

  • batch_size: 每个批次包含的样本数。
  • shuffle=True: 在每个 epoch 开始时打乱数据顺序,有助于提高模型泛化能力。
  • num_workers: 使用多少个子进程来预加载数据。增加此值可以加快数据加载速度,但会消耗更多内存和 CPU。

示例:使用 DataLoader

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torch.utils.data import DataLoader

# 假设 custom_dataset 已经创建好
custom_dataset = MyImageDataset(image_files, image_labels, transform=data_transform)

data_loader = DataLoader(
dataset=custom_dataset,
batch_size=32,
shuffle=True,
num_workers=4
)

# 像遍历列表一样遍历 DataLoader
for epoch in range(num_epochs):
for images, labels in data_loader:
# 在这里执行模型训练
# images 的形状: [32, 3, 224, 224] (假设batch_size=32)
# labels 的形状: [32]
pass


graph TD
    A[原始数据文件] --> B(Dataset);
    B -- 按索引 __getitem__ --> C{单条数据};
    D(Transforms) -- 应用于 --> C;
    C --> E(DataLoader);
    E -- 打包成 batch --> F[模型];


4️⃣ torch.utils.tensorboard:可视化训练过程

TensorBoard 是一个强大的可视化工具,可以帮助你理解、调试和优化你的模型。PyTorch 通过 SummaryWriter 类与其集成。

主要功能: - 记录标量(如 loss、accuracy)的变化曲线。 - 可视化图像、模型图、embedding 等。

示例:使用 SummaryWriter

add_scalar 是最常用的方法之一,它用于记录单个标量值的变化。

  • tag: 图表的标题,例如 'Loss/train' 或 'Accuracy/validation'。使用 / 可以帮助在 TensorBoard 中对图表进行分组。
  • scalar_value: 要记录的数值,通常是 loss 值或准确率。
  • global_step: x 轴坐标,通常是训练的步数(batch index)或轮数(epoch index)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch.utils.tensorboard import SummaryWriter

# 1. 创建一个 writer 实例,日志会保存在 'logs' 目录下
writer = SummaryWriter('logs/experiment_1')

# 2. 在训练循环中记录信息
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(data_loader, 0):
# ... 训练代码 ...
loss = criterion(outputs, labels)
running_loss += loss.item()

# 每 100 个 batch 记录一次 loss
if i % 100 == 99:
writer.add_scalar('training loss', running_loss / 100, epoch * len(data_loader) + i)
running_loss = 0.0

假设在每个 epoch 结束后记录一张图片
writer.add_image('one_image', images[0], epoch)

# 3. 训练结束后关闭 writer
writer.close()

记录完成后,在终端中运行 tensorboard --logdir=logs,然后访问浏览器中显示的地址(通常是 http://localhost:6006)即可查看可视化结果。


📝 小结

  1. Dataset:数据源的标准化封装,核心是 __len____getitem__
  2. Transforms:数据预处理和增强的流水线,使用 Compose 组合。
  3. DataLoader:将 Dataset 包装成可迭代对象,实现批量、随机和并行加载。
  4. TensorBoard:通过 SummaryWriter 记录训练过程,实现可视化分析。

熟练运用这四大组件,你的 PyTorch 项目将变得更加规范、高效和易于调试 🚀