【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数

文章目录

  • 前言
  • 一、torch.cat()函数 拼接只存在h,w(高,宽)的图像
  • 二、torch.cat() 拼接存在c,h,w(通道,高,宽)的图像
  • 三、torch.add()使张量对应元素直接相加

前言

本篇作为后期文章“特征融合”的基础。
特征融合分早融合和晚融合,早融合里的重要手段是concat和add

一、torch.cat()函数 拼接只存在h,w(高,宽)的图像

torch.cat()可以将多个张量合并为一个张量,我们接下来从简单到复杂一点点来盘这个函数

我们首先随机生成两个形状一致的张量:

import torchA =torch.rand(3,2)#单通道,高为3.宽为2的张量B=torch.rand(3,3) #单通道,高为2.宽为3的张量print(A)print(B)

图片[1] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL

让这个张量在第0维度进行拼接,也就是在高这个维度进行拼接:

C=torch.cat((A,B),dim=0)print(C)print(C.shape)

图片[2] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
可以看到高变成了3+3,宽不变

让这个张量在第1维度进行拼接,也就是在宽这个维度进行拼接:

C=torch.cat((A,B),dim=1)print(C)print(C.shape)

图片[3] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
可以看到,高不变,宽变成了2+2

在第0维度拼接时,高可以不一样,但是宽需要一致,不然会报错:

import torchA =torch.rand(3,3)#单通道,高为3.宽为2的张量B=torch.rand(4,3) #单通道,高为2.宽为3的张量print(A)print(B)C=torch.cat((A,B),dim=0)print(C)print(C.shape)

不报错:
图片[4] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL

import torchA =torch.rand(3,3)#单通道,高为3.宽为2的张量B=torch.rand(3,5) #单通道,高为2.宽为3的张量print(A)print(B)C=torch.cat((A,B),dim=0)print(C)print(C.shape)

直接报错:
图片[5] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
在第1维度拼接时,高必须一致,宽可以不一样,不然会报错:

import torchA =torch.rand(3,3)#单通道,高为3.宽为2的张量B=torch.rand(3,5) #单通道,高为2.宽为3的张量print(A)print(B)C=torch.cat((A,B),dim=1)print(C)print(C.shape)

不报错:
图片[6] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL

import torchA =torch.rand(3,3)#单通道,高为3.宽为2的张量B=torch.rand(4,3) #单通道,高为2.宽为3的张量print(A)print(B)C=torch.cat((A,B),dim=1)print(C)print(C.shape)

图片[7] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL

二、torch.cat() 拼接存在c,h,w(通道,高,宽)的图像

我们随机生成两个3通道的2X2图像

import torchA =torch.rand(3,2,2)#单通道,高为3.宽为2的张量B=torch.rand(3,2,2) #单通道,高为2.宽为3的张量print(A)print(B)

图片[8] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
图片[9] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL

让他们在第0维度进行拼接(通道维度拼接):
图片[10] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
相当于通道数堆叠了,变成了六个通道

让他们在第1维度进行拼接(高维度拼接):
图片[11] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
让他们在第2维度进行拼接(宽维度拼接):
图片[12] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
这两个堆叠结果就和之前的方法一样了

三、torch.add()使张量对应元素直接相加

import torchA =torch.rand(3,2,2)#单通道,高为3.宽为2的张量B=torch.rand(3,2,2) #单通道,高为2.宽为3的张量print(A)print(B)C=torch.add(A,B)print(C)print(C.shape)

张量A:
图片[13] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
张量B:
图片[14] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
相加后张量:
图片[15] - 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数 - MaxSSL
当然也可以不用add(A,B) 用A+B

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享