目录

背景

亮点

环境配置

数据

方法

结果

代码获取

参考文献


背景

基于Gumbel-softmax方法EEG通道选择层的PyTorch实现。该层可以放置在任何深度神经网络架构的前面,以共同学习给定任务和网络权重的脑电图通道的最佳子集。这一层由选择神经元组成,每个神经元都使用输入通道上离散分布的连续松弛来学习最佳的单热权重向量来选择输入通道,而不是线性组合它们。

亮点

使用Gumbel-softmax方法对多通道脑电数据进行单通道选择(非多通道线性加权)

使用多尺度滤波卷积网络实现运动想象4分类。

环境配置

PyTorch0.3.1,

CUDA9.1

数据

High-GammaDataset

方法

多尺度滤波卷积网络主要代码:

classMSFBCNN(nn.Module):def__init__(self,input_dim,output_dim,FT=10):super(MSFBCNN,self).__init__()self.T=input_dim[1]self.FT=FTself.D=1self.FS=self.FT*self.Dself.C=input_dim[0]self.output_dim=output_dim#Paralleltemporalconvolutionsself.conv1a=nn.Conv2d(1,self.FT,(1,65),padding=(0,32),bias=False)self.conv1b=nn.Conv2d(1,self.FT,(1,41),padding=(0,20),bias=False)self.conv1c=nn.Conv2d(1,self.FT,(1,27),padding=(0,13),bias=False)self.conv1d=nn.Conv2d(1,self.FT,(1,17),padding=(0,8),bias=False)self.batchnorm1=nn.BatchNorm2d(4*self.FT,False)#Spatialconvolutionself.conv2=nn.Conv2d(4*self.FT,self.FS,(self.C,1),padding=(0,0),groups=1,bias=False)self.batchnorm2=nn.BatchNorm2d(self.FS,False)#Temporalaveragepoolingself.pooling2=nn.AvgPool2d(kernel_size=(1,75),stride=(1,15),padding=(0,0))self.drop=nn.Dropout(0.5)#Classificationself.fc1=nn.Linear(self.FS*math.ceil(1+(self.T-75)/15),self.output_dim)defforward(self,x):#Layer1x1=self.conv1a(x);x2=self.conv1b(x);x3=self.conv1c(x);x4=self.conv1d(x);x=torch.cat([x1,x2,x3,x4],dim=1)x=self.batchnorm1(x)#Layer2x=torch.pow(self.batchnorm2(self.conv2(x)),2)x=self.pooling2(x)x=torch.log(x)x=self.drop(x)#FCLayerx=x.view(-1,self.num_flat_features(x))x=self.fc1(x)returnxdefnum_flat_features(self,x):size=x.size()[1:]#alldimensionsexceptthebatchdimensionnum_features=1forsinsize:num_features*=sreturnnum_features

Gumbel-softmax再参数化主要代码:

classSelectionLayer(nn.Module):def__init__(self,N,M,temperature=1.0):super(SelectionLayer,self).__init__()self.floatTensor=torch.FloatTensorifnottorch.cuda.is_available()elsetorch.cuda.FloatTensorself.N=Nself.M=Mself.qz_loga=Parameter(torch.randn(N,M)/100)self.temperature=self.floatTensor([temperature])self.freeze=Falseself.thresh=3.0defquantile_concrete(self,x):g=-torch.log(-torch.log(x))y=(self.qz_loga+g)/self.temperaturey=torch.softmax(y,dim=1)returnydefregularization(self):eps=1e-10z=torch.clamp(torch.softmax(self.qz_loga,dim=0),eps,1)H=torch.sum(F.relu(torch.norm(z,1,dim=1)-self.thresh))returnHdefget_eps(self,size):eps=self.floatTensor(size).uniform_(epsilon,1-epsilon)returnepsdefsample_z(self,batch_size,training):iftraining:eps=self.get_eps(self.floatTensor(batch_size,self.N,self.M))z=self.quantile_concrete(eps)z=z.view(z.size(0),1,z.size(1),z.size(2))returnzelse:ind=torch.argmax(self.qz_loga,dim=0)one_hot=self.floatTensor(np.zeros((self.N,self.M)))forjinrange(self.M):one_hot[ind[j],j]=1one_hot=one_hot.view(1,1,one_hot.size(0),one_hot.size(1))one_hot=one_hot.expand(batch_size,1,one_hot.size(2),one_hot.size(3))returnone_hotdefforward(self,x):z=self.sample_z(x.size(0),training=(self.trainingandnotself.freeze))z_t=torch.transpose(z,2,3)out=torch.matmul(z_t,x)returnout

结果

实现从64通道脑电信号中提取出N个重要通道脑电信号,增强后续分类任务的性能

代码获取

https://download.csdn.net/download/YINTENAXIONGNAIER/88946872

参考文献

  • Strypsteen,Thomas,andAlexanderBertrand.”End-to-endlearnableEEGchannelselectionfordeepneuralnetworkswithGumbel-softmax.”JournalofNeuralEngineering18.4(2021):0460a9.