PyTorch中的nn.Dropout

在深度学习中,过拟合是一个常见的问题,而 nn.Dropout 是一种常用的正则化技术,有助于减少过拟合现象。
nn.Dropout 是 PyTorch 中用于应用 Dropout 正则化的模块。

作用

在训练阶段,nn.Dropout 按照指定的概率随机丢弃输入张量中的部分元素,以减少神经网络的复杂性。这有助于提高模型的泛化能力,避免过拟合。

nn.Dropout 的基本用法

torch.nn.Dropout(p=0.5, inplace=False)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.dropout = nn.Dropout(p=0.5) # 添加 Dropout 模块

def forward(self, x):
x = self.fc1(x)
x = self.dropout(x) # 在前向传播函数中应用 Dropout
return x

# 创建模型实例
model = MyModel()

训练和测试阶段的使用
在训练和测试阶段,需要注意 nn.Dropout 的行为不同:

训练阶段:nn.Dropout 会按照指定的概率丢弃输入元素。
测试阶段:nn.Dropout 不会丢弃输入元素,而是将输入内容原样输出。

不需要用if语句来控制是否是训练阶段,只需在训练阶段设置model.train(),在测试阶段设置model.eval()即可。

1
2
3
4
5
6
7
8
9
10
11
12
input_data = torch.randn(1, 10)

# 训练阶段
model.train()
output_train = model(input_data)

# 测试阶段
model.eval()
output_test = model(input_data)

print("Output during training:", output_train)
print("Output during testing:", output_test)

PyTorch中的nn.Dropout
http://jiqingjiang.github.io/p/b928e7f1/
作者
Jiqing
发布于
2024年8月1日
许可协议