本文深入介绍深度学习中最核心的多分类激活函数:Softmax 函数。从数学原理到实际应用,用通俗易懂的方式带你理解这个将 logits 转换为概率分布的强大工具。
激活函数系列(二):Softmax - 从入门到精通
系列导读 : 继 Sigmoid 之后,本文介绍 Softmax 函数。后续我们还将介绍 ReLU、Tanh、GELU 等常用激活函数。
🎯 什么是 Softmax? Softmax 函数是深度学习中最常用的多分类激活函数 。简单来说,它将一个任意实数向量转换为一个概率分布 ,其中:
每个输出都在 (0, 1) 之间
所有输出之和为 1
较大的输入会获得较大的概率
名称由来 : “Softmax” = “Soft”(软化)+ “max”(最大)。它对 argmax(取最大值的索引)进行了”软化”,使得我们可以得到一个可微分的概率分布,而不是硬性的最大值索引。
graph LR
A["输入 logits: [2.0, 1.0, 0.1]"] --> B["Softmax 函数"]
B --> C["输出概率: [0.705, 0.259, 0.036]"]
C --> D["和为 1.0"]
style B fill:#ffeb3b,stroke:#f57c00,stroke-width:3px
为什么需要 Softmax? 在多分类问题中,我们需要网络输出属于每个类别的概率 。原始的网络输出(称为 logits )是任意实数,无法直接解释为概率。Softmax 解决了这个问题:
输入 : 任意实数向量(logits)
Softmax : 转换为概率分布
输出 : 每个类别的概率,便于决策和优化
📊 Softmax 函数:数学定义 1. 函数定义 给定一个 K 维实数向量 x = [x₁, x₂, …, x_K],Softmax 函数定义为:
其中:
$e^{x_i}$ 是对第 i 个元素取指数
$\sum_{j=1}^{K} e^{x_j}$ 是所有元素指数之和(归一化因子)
2. 关键性质 graph TD
A["Softmax 输出"] --> B["性质 1: 范围 (0, 1)"]
A --> C["性质 2: 和为 1"]
A --> D["性质 3: 单调性"]
A --> E["性质 4: 缩放不变性"]
B --> F["可解释为概率"]
C --> G["形成有效概率分布"]
D --> H["保持输入的相对大小关系"]
E --> I["加上常数不改变输出"]
style B fill:#c8e6c9
style C fill:#c8e6c9
style D fill:#fff9c4
style E fill:#fff9c4
性质
数学表达
直观解释
输出有界
0 < softmax(xᵢ) < 1
每个输出都在 0 和 1 之间
和为 1
$\sum_i \text{softmax}(x_i) = 1$
形成有效的概率分布
单调性
xᵢ > xⱼ ⇒ softmax(xᵢ) > softmax(xⱼ)
保持输入的相对大小关系
缩放不变性
softmax(x + c) = softmax(x)
所有输入加同一常数,输出不变
3. 计算示例 假设输入向量 x = [2.0, 1.0, 0.1]:
步骤 1 : 计算指数
步骤 2 : 计算和
步骤 3 : 计算每个概率
最终输出 : [0.659, 0.242, 0.099]
注意到:最大的输入 2.0 对应最大的概率 0.659,且概率之和为 1。
🧮 数学特性深入 1. 缩放不变性(重要!) Softmax 函数具有数值稳定性 相关的缩放不变性:
为什么这很重要?
在实际计算中,如果 xᵢ 很大(例如 1000),e^xᵢ 会溢出(超过浮点数表示范围)。我们可以通过减去一个常数 c 来避免溢出。
常用技巧 : 减去最大值
这样最大的指数是 e⁰ = 1,永远不会溢出!
2. 梯度推导 Softmax 的梯度是其优雅性质的来源。考虑交叉熵损失:
其中:
yᵢ 是真实标签(one-hot 编码)
ŷᵢ = softmax(xᵢ) 是预测概率
梯度计算 :
这意味着什么?
梯度的简单性 使得反向传播非常高效!误差就是预测概率与真实标签的差。
3. 与 Sigmoid 的关系 Softmax 可以看作是 Sigmoid 的多维推广 :
Sigmoid : 将单个实数映射到 (0, 1) —— 二分类
Softmax : 将向量映射到概率分布 —— 多分类
二分类时的等价性 :
✅ Softmax 的优点
优点
说明
应用场景
概率解释
输出可解释为属于每个类别的概率
多分类任务
可微分
处处可导,便于梯度下降优化
深度网络训练
输出有界
输出严格在 (0, 1) 之间,不会爆炸
数值稳定
梯度简洁
梯度形式简单(预测值-真实值)
高效反向传播
归一化
自动归一化为概率分布
直接用于决策
❌ Softmax 的缺点与挑战
缺点
说明
后果
计算复杂
需要计算指数和,计算量大
大类别数时慢
数值不稳定
指数运算可能溢出或下溢
需要特殊处理
不适用于大类别
O(K) 复杂度,K 为类别数
类别数>1000时困难
对异常值敏感
单个大的 logit 会主导输出
概率分布过于极端
数值稳定性问题 问题 1: 上溢
当 xᵢ 很大(例如 1000),e^xᵢ 会溢出为 ∞
问题 2: 下溢
当 xᵢ 很小(例如 -1000),e^xᵢ 会下溢为 0
所有 e^xⱼ 都为 0 时,分母为 0,无法计算
解决方案 : LogSumExp 技巧
💻 代码示例 1. Python 实现(基础版) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import numpy as npdef softmax_basic (x ): """ 基础版 Softmax 实现 注意:数值不稳定,仅用于理解 """ exp_x = np.exp(x) return exp_x / np.sum (exp_x, axis=-1 , keepdims=True ) x = np.array([2.0 , 1.0 , 0.1 ]) prob = softmax_basic(x) print (f"输入: {x} " )print (f"概率: {prob} " )print (f"和: {prob.sum ()} " )
2. 数值稳定版本 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def softmax_stable (x ): """ 数值稳定的 Softmax 实现 使用 max(x) 防止指数溢出 """ exp_x = np.exp(x - np.max (x, axis=-1 , keepdims=True )) return exp_x / np.sum (exp_x, axis=-1 , keepdims=True ) x_large = np.array([1000 , 900 , 800 ]) prob_stable = softmax_stable(x_large) print (f"\n大数测试: {x_large} " )print (f"稳定版本: {prob_stable} " )print (f"和: {prob_stable.sum ()} " )
3. PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 import torchimport torch.nn as nnsoftmax_layer = nn.Softmax(dim=1 ) x = torch.tensor([[2.0 , 1.0 , 0.1 ]]) prob = softmax_layer(x) print (f"PyTorch Softmax: {prob} " )import torch.nn.functional as Fprob_func = F.softmax(x, dim=1 ) print (f"F.softmax: {prob_func} " )log_softmax = F.log_softmax(x, dim=1 ) print (f"LogSoftmax: {log_softmax} " )prob_from_log = torch.exp(log_softmax) print (f"从 LogSoftmax 恢复: {prob_from_log} " )
4. LogSumExp 技巧 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def log_softmax (x ): """ LogSoftmax 实现,使用 LogSumExp 技巧 数值最稳定,常用于 NLLLoss """ x_max = np.max (x, axis=-1 , keepdims=True ) return x - x_max - np.log(np.sum (np.exp(x - x_max), axis=-1 , keepdims=True )) x = np.array([2.0 , 1.0 , 0.1 ]) log_prob = log_softmax(x) prob = np.exp(log_prob) print (f"\nLogSoftmax: {log_prob} " )print (f"恢复的概率: {prob} " )print (f"和: {prob.sum ()} " )
5. 可视化 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 import matplotlib.pyplot as pltimport numpy as npx = np.linspace(-10 , 10 , 100 ) sigmoid = 1 / (1 + np.exp(-x)) softmax_2class = np.exp(x) / (np.exp(x) + np.exp(0 )) plt.figure(figsize=(12 , 5 )) plt.subplot(1 , 2 , 1 ) plt.plot(x, sigmoid, label='Sigmoid' , linewidth=2 ) plt.plot(x, softmax_2class, label='Softmax (2-class)' , linewidth=2 , linestyle='--' ) plt.axhline(y=0.5 , color='r' , linestyle=':' , alpha=0.5 ) plt.axvline(x=0 , color='r' , linestyle=':' , alpha=0.5 ) plt.title('Sigmoid vs Softmax (二分类)' ) plt.xlabel('x' ) plt.ylabel('Probability' ) plt.legend() plt.grid(True , alpha=0.3 ) x_3d = np.array([2.0 , 1.0 , 0.1 ]) prob_3d = np.exp(x_3d) / np.sum (np.exp(x_3d)) plt.subplot(1 , 2 , 2 ) categories = ['Class 1' , 'Class 2' , 'Class 3' ] colors = ['#e74c3c' , '#3498db' , '#2ecc71' ] bars = plt.bar(categories, prob_3d, color=colors) plt.title('Softmax 多分类示例' ) plt.ylabel('Probability' ) plt.ylim([0 , 1 ]) for bar in bars: height = bar.get_height() plt.text(bar.get_x() + bar.get_width()/2. , height, f'{height:.3 f} ' , ha='center' , va='bottom' ) plt.tight_layout() plt.savefig('softmax_visualization.png' , dpi=300 ) plt.show()
🎮 实际应用场景 1. 图像分类 场景 : 将图像分类到 1000 个类别(ImageNet)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 import torchimport torch.nn as nnclass ImageClassifier (nn.Module): def __init__ (self, num_classes=1000 ): super ().__init__() self .features = nn.Sequential( nn.Conv2d(3 , 64 , 3 , padding=1 ), nn.ReLU(), nn.MaxPool2d(2 ), ) self .classifier = nn.Linear(512 , num_classes) def forward (self, x ): features = self .features(x) logits = self .classifier(features) probabilities = F.softmax(logits, dim=1 ) return probabilities model = ImageClassifier(num_classes=10 ) image = torch.randn(1 , 3 , 224 , 224 ) probs = model(image) predicted_class = torch.argmax(probs, dim=1 ) confidence = probs.max () print (f"预测类别: {predicted_class.item()} " )print (f"置信度: {confidence.item():.4 f} " )
2. 自然语言处理 - 序列标注 场景 : 命名实体识别(NER),每个词分配一个标签
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 class BiLSTMTagger (nn.Module): def __init__ (self, vocab_size, tag_size, embedding_dim, hidden_dim ): super ().__init__() self .embedding = nn.Embedding(vocab_size, embedding_dim) self .lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True , batch_first=True ) self .hidden2tag = nn.Linear(hidden_dim * 2 , tag_size) def forward (self, sentence ): embeds = self .embedding(sentence) lstm_out, _ = self .lstm(embeds) tag_space = self .hidden2tag(lstm_out) tag_scores = F.log_softmax(tag_space, dim=2 ) return tag_scores sentence = torch.tensor([[1 , 2 , 3 ]]) tag_probs = model(sentence) print (f"形状: {tag_probs.shape} " )
3. 注意力机制 场景 : Transformer 中的注意力权重
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 def scaled_dot_product_attention (query, key, value, mask=None ): """ 缩放点积注意力 """ scores = torch.matmul(query, key.transpose(-2 , -1 )) scores = scores / np.sqrt(query.size(-1 )) if mask is not None : scores = scores.masked_fill(mask == 0 , -1e9 ) attention_weights = F.softmax(scores, dim=-1 ) output = torch.matmul(attention_weights, value) return output, attention_weights batch_size = 2 seq_len = 5 d_model = 64 query = key = value = torch.randn(batch_size, seq_len, d_model) output, weights = scaled_dot_product_attention(query, key, value) print (f"注意力权重形状: {weights.shape} " ) print (f"权重和(应接近 1): {weights[0 , 0 ].sum ()} " )
4. 强化学习 - 策略网络 场景 : 在 Atari 游戏中选择动作
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 class PolicyNetwork (nn.Module): def __init__ (self, state_dim, action_dim ): super ().__init__() self .fc1 = nn.Linear(state_dim, 128 ) self .fc2 = nn.Linear(128 , action_dim) def forward (self, state ): x = F.relu(self .fc1(state)) logits = self .fc2(x) action_probs = F.softmax(logits, dim=-1 ) return action_probs state = torch.randn(1 , 84 , 84 , 4 ) policy = PolicyNetwork(84 *84 *4 , 4 ) action_probs = policy(state) dist = torch.distributions.Categorical(action_probs) action = dist.sample() log_prob = dist.log_prob(action) print (f"动作概率: {action_probs} " )print (f"采样动作: {action.item()} " )print (f"对数概率: {log_prob.item()} " )
📊 与其他激活函数对比
激活函数
公式
输出范围
主要应用
特点
Sigmoid
1/(1+e⁻ˣ)
(0, 1)
二分类
单一输出,概率解释
Softmax
eˣᵢ/∑eˣⱼ
(0, 1), ∑=1
多分类
概率分布,归一化
Tanh
(eˣ-e⁻ˣ)/(eˣ+e⁻ˣ)
(-1, 1)
隐藏层
零中心,有界
ReLU
max(0, x)
[0, +∞)
隐藏层
计算快,无梯度消失
GELU
x·Φ(x)
(-∞, +∞)
隐藏层
平滑,性能好
graph LR
A["输入 logits"] --> B{任务类型}
B -->|二分类| C["Sigmoid"]
B -->|多分类| D["Softmax"]
B -->|隐藏层| E["ReLU / GELU"]
C --> F["单个概率值"]
D --> G["概率分布"]
E --> H["非线性变换"]
style C fill:#ffcdd2
style D fill:#c8e6c9
style E fill:#fff9c4
⚡ 优化技巧与最佳实践 1. 使用 LogSoftmax 1 2 3 4 5 6 7 probs = F.softmax(logits, dim=1 ) loss = -torch.sum (target * torch.log(probs)) log_probs = F.log_softmax(logits, dim=1 ) loss = F.nll_loss(log_probs, target)
优势 :
数值更稳定
PyTorch 的 NLLLoss 直接使用 LogSoftmax
避免了 log(sum(exp)) 的计算
2. Label Smoothing 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def cross_entropy_with_label_smoothing (logits, target, smoothing=0.1 ): """ 标签平滑:防止过度自信 """ n_classes = logits.size(-1 ) smoothed_target = target * (1 - smoothing) + \ smoothing / (n_classes - 1 ) log_probs = F.log_softmax(logits, dim=-1 ) loss = -torch.sum (smoothed_target * log_probs, dim=-1 ) return loss.mean() logits = model(inputs) loss = cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1 )
优势 :
防止模型过度自信
提高泛化能力
常用 smoothing 值:0.05-0.2
3. Temperature Scaling 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 def softmax_with_temperature (logits, temperature=1.0 ): """ Temperature Softmax 温度越低,分布越尖锐;温度越高,分布越平滑 """ scaled_logits = logits / temperature return F.softmax(scaled_logits, dim=-1 ) logits = model(inputs) probs_train = softmax_with_temperature(logits, temperature=1.0 ) probs_inference = softmax_with_temperature(logits, temperature=0.7 )
应用场景 :
训练: T = 1.0(标准 Softmax)
推理: T < 1.0(更尖锐的分布)
采样: T > 1.0(更平滑的分布,用于多样性)
4. 大类别优化 对于类别数 > 1000 的情况:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class SampledSoftmax (nn.Module): """ 采样 Softmax:适用于大类别 只计算部分类别的概率 """ def __init__ (self, num_classes, num_sampled ): super ().__init__() self .num_classes = num_classes self .num_sampled = num_sampled def forward (self, hidden, labels, weights, biases ): sampled_values = torch.randint(0 , self .num_classes, (self .num_sampled,)) true_logits = torch.matmul(hidden, weights[:, labels]) sampled_logits = torch.matmul(hidden, weights[:, sampled_values]) return loss
应用 : 词表 > 10,000 的语言模型
🔍 数学推导:梯度计算 交叉熵损失 其中 $\hat{y}_i = \text{softmax}(x_i)$
梯度推导 步骤 1 : 对 $\hat{y}_j$ 求导
其中 $\delta_{ij}$ 是 Kronecker delta(i=j 时为 1,否则为 0)
步骤 2 : 对损失求导
步骤 3 : 链式法则
结论 :
直觉理解 :
如果预测过高 ($\hat{y}_i > y_i$),梯度为正,减少 $x_i$
如果预测过低 ($\hat{y}_i < y_i$),梯度为负,增加 $x_i$
完美符合直觉!
🎯 何时使用 Softmax? ✅ 推荐使用场景
场景
说明
示例
多分类问题
需要预测属于哪个类别
图像分类、文本分类
概率输出
需要输出概率分布
风险评估、不确定性量化
注意力机制
需要归一化权重
Transformer、图神经网络
策略网络
RL 中从概率分布采样
强化学习
序列标注
NER、词性标注
NLP 任务
❌ 不推荐使用场景
场景
替代方案
原因
二分类
Sigmoid
Softmax 效率低
回归问题
无激活/线性输出
Softmax 输出范围不对
隐藏层激活
ReLU、GELU
Softmax 梯度可能消失
大类别分类
采样 Softmax、分层分类
计算效率低
多标签分类
Sigmoid
多个独立二分类
📚 练习题 问题 1 为什么 Softmax 的输出和为 1?
答案 : 因为归一化因子是所有 e^xⱼ 的和,每个输出是 e^xᵢ 除以这个和,所以所有输出之和为 1。
问题 2 当输入向量 x = [0, 0, 0] 时,Softmax 的输出是什么?
答案 : [1/3, 1/3, 1/3]。因为 e⁰ = 1,所以每个输出都是 1/(1+1+1) = 1/3。
问题 3 如何避免 Softmax 的数值溢出问题?
答案 : 使用数值稳定版本,减去最大值:softmax(x) = exp(x-max(x)) / sum(exp(x-max(x)))。这样最大的指数是 1,不会溢出。
问题 4 Softmax 和 Sigmoid 在二分类时的关系是什么?
答案 : softmax([x, 0])_1 = sigmoid(x)。Softmax 可以看作是 Sigmoid 的多维推广。
问题 5 为什么强化学习中常用 Softmax 作为策略输出?
答案 : Softmax 输出是概率分布,可以从中采样动作。还可以通过温度参数控制探索-利用的权衡。
📖 总结 核心要点
定义 : Softmax 将 logits 转换为概率分布
关键性质 : 输出在 (0, 1),和为 1,单调,缩放不变
梯度简洁 : ∂L/∂xᵢ = ŷᵢ - yᵢ,易于反向传播
数值稳定 : 使用 LogSoftmax 或减去最大值
应用广泛 : 多分类、注意力、RL、NLP
与 Sigmoid 的对比
特性
Sigmoid
Softmax
输出维度
1
K(类别数)
任务类型
二分类
多分类
输出形式
单一概率
概率分布
归一化
无
自动
梯度
σ·(1-σ)
ŷ - y
记忆口诀
“多分类用 Softmax,概率分布和为 1。数值稳定要小心,LogSumExp 记在心。”
实践建议
使用框架实现 : PyTorch 的 F.softmax 或 nn.Softmax
优先用 LogSoftmax : 配合 NLLLoss 使用
大类别用技巧 : 采样 Softmax、分层分类
标签平滑 : 防止过度自信
温度缩放 : 推理时调整分布尖锐度
🔗 参考资料
下一篇预告 : 我们将介绍 ReLU 激活函数 ,这个革命性的简单函数如何解决了梯度消失问题,并成为现代深度学习的核心组件。
发布日期:2026年2月21日 标签:激活函数、Softmax、深度学习、神经网络、算法