PyTorch 10大常用损失函数Loss Function详解

目录

前言

一、损失函数

二、详解

1.回归损失

2.分类损失

三. 总结


前言

损失函数在深度学习中占据着非常重要的作用,选取的正确与否直接关系到模型的好坏。

本文就常用的损失函数做一个通俗易懂的介绍。


一、损失函数

根据深度函数的模型类型,损失函数可分为三类:

1. 回归损失(Regression loss):预测连续的数值,即输出是连续数据:如预测房价、气温等;

2. 分类损失(Classification loss):预测离散的数值,即输出是离散数据:如预测硬币正反、图像分类、语义分割等;

3. 排序损失(Ranking loss):预测输入样本间的相对距离,即输出一般是概率值,如预测两张面部图像是否属于同一个人等;

二、详解

1.回归损失

(1.)L1 Loss 计算实际值与预测值之间的绝对差之和的平均值;

表达式如下:

使用示例:

import torchimport torch.nn as nnimport numpy as np# L1 Lossinput = torch.randn(2, 2, requires_grad=True)target = torch.randn(2, 2)mae_loss = torch.nn.L1Loss()output = mae_loss(input, target)# print("input: ", input)# print("target: ", target)# print("output: ", output)

图片[1] - PyTorch 10大常用损失函数Loss Function详解 - MaxSSL

(2.)L2Loss 计算实际值和预测值之间的平方差的平均值;

L2 由于将误差平方化,因此当误差大于1时,整体偏差会被放大,出现极端偏差值时,L2模型会因为惩罚更大而开始偏离较远,相比之下,L1对异常值的鲁棒性更好。

表达式如下:

使用示例:

# L2 Lossinput = torch.randn(2, 2, requires_grad=True)target = torch.randn(2, 2)mse_loss = torch.nn.MSELoss()output = mse_loss(input, target)# print("input: ", input)# print("target: ", target)# print("output: ", output)

图片[2] - PyTorch 10大常用损失函数Loss Function详解 - MaxSSL

(3.)SmoothL1Loss 计算;(其实可看做L1 和 L2的线性结合)

在实际值与预测值小于1时,选取L2相似计算较稳定,大于1时,L1对异常值的鲁棒性更好,选择了L1的变形计算;

表达式如下:

# Smooth L1 Lossinput = torch.randn(2, 2, requires_grad=True)target = torch.randn(2, 2)smooth_l1_loss = torch.nn.SmoothL1Loss()output = smooth_l1_loss(input, target)print("input: ", input)print("target: ", target)print("output: ", output)

图片[3] - PyTorch 10大常用损失函数Loss Function详解 - MaxSSL

2.分类损失

(1.)NLLLoss(Negative Log-Likelihood) 多分类问题;

# NLL Lossinput = torch.randn(2, 2, requires_grad=True)target = torch.tensor([1, 0])m = nn.LogSoftmax(dim=1)nll_loss = torch.nn.NLLLoss()output = nll_loss(m(input), target)print("input: ", input)print("target: ", target)print("output: ", output)

(5.)Cross-Entropy Loss ,计算实际输出(概率)与期望输出(概率)的距离;(二分类或多分类),可看做nn.LogSoftmax() 和 nn.NLLLoss() 二者的结合;

“NLLLoss的 输入 是一个对数概率向量和一个目标标签,不计算对数概率.适合网络的最后一层是log_softmax.损失函数 nn.CrossEntropyLoss() 与 NLLLoss() 相同, 唯一的不同是nn.CrossEntropyLoss()做 softmax.”

# Cross-Entropyinput = torch.randn(2, 2, requires_grad=True)target = torch.empty(2, dtype=torch.long).random_(2)cross_entropy_loss = torch.nn.CrossEntropyLoss()output = cross_entropy_loss(input, target)print("input: ", input)print("target: ", target)print("output: ", output)

图片[4] - PyTorch 10大常用损失函数Loss Function详解 - MaxSSL

(6.)Hinge Embedding Loss: 判断两个输入是否相似或不同;

# hinge embedding lossinput = torch.randn(3, 3, requires_grad=True)target = torch.randn(3, 3)hinge_loss = torch.nn.HingeEmbeddingLoss()output = hinge_loss(input, target)print("input: ", input)print("target: ", target)print("output: ", output)

图片[5] - PyTorch 10大常用损失函数Loss Function详解 - MaxSSL

(7.)Margin Ranking Loss:预测输入之间的相对距离;

# Margin Ranking Lossinput1 = torch.randn(3, requires_grad=True)input2 = torch.randn(3, requires_grad=True)target = torch.randn(3).sign()ranking_loss = torch.nn.MarginRankingLoss()output = ranking_loss(input1, input2, target)print("input1: ", input1)print("input2: ", input2)print("target: ", target)print("output: ", output)

(8.)Triplet Margin Loss:计算三元组的损失,确定样本之间的相对相似性;

# triplet margin lossanchor = torch.randn(5, 5, requires_grad=True)positive = torch.randn(5, 5, requires_grad=True)negivate = torch.randn(5, 5, requires_grad=True)triplet_margin_loss = torch.nn.TripletMarginLoss()output = triplet_margin_loss(anchor, positive, negivate)

(9.)KL Divergence Loss,计算两个概率分布距离;

KL Divergence :评估概率分布预测与ground truth分布的不同之处;

# KL Divergence Lossinput = torch.randn(3, 3, requires_grad=True)target = torch.randn(3,3)kl_loss = torch.nn.KLDivLoss()output = kl_loss(input, target)print("input: ", input)print("target: ", target)print("output: ", output)

(10.)SoftMarginLoss

# Softmargin_lossinput = torch.randn(3, 3, requires_grad=True)target = torch.randn(3,3)softmargin_loss = torch.nn.SoftMarginLoss()output = softmargin_loss(input, target)print("input: ", input)print("target: ", target)print("output: ", output)


下图为以上损失函数公式的简单形式:


三. 总结

PyTorch还有很多损失函数,大部分基于这几个类型进行变形和优化,掌握基础是关键!

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享