文章目录
- 一、squeeze函数的用法
- 二、nn.CrossEntropyLoss函数
- 三、isinstance函数
- 四、定义冻结层 freeze_layers
- 五、SummaryWriter 基础用法
- 六、Python 基础语法
- 1.变量嵌入到字符串
- 2. enumerate() 函数
- 3. 进度条库tqdm
- 4. 字典(dict)展开为关键字参数(keyword arguments)
- 5. assert 断言操作
- 6. \_\_class__.__name__获取对象类名
- 7. all() 方法判断字符是不是都非零
- 附录
一、squeeze函数的用法
import torch# 创建一个具有形状(2, 1, 3)的张量x = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])print("Original tensor shape:", x.shape)# 输出: Original tensor shape: torch.Size([2, 1, 3])# 使用squeeze()移除所有大小为1的维度x_squeezed = x.squeeze()print("Squeezed tensor shape:", x_squeezed.shape)print(x_squeezed)# 输出: Squeezed tensor shape: torch.Size([2, 3])# 使用squeeze(dim)仅移除特定维度x_squeezed_dim_1 = x.squeeze(1)print("Squeezed tensor (dim 1) shape:", x_squeezed_dim_1.shape)# 输出: Squeezed tensor (dim 1) shape: torch.Size([2, 3])
在这个例子中,我们创建了一个形状为(2, 1, 3)
的3维张量。我们可以看到,第二个维度(索引为1)的大小为1。使用squeeze()
函数移除所有大小为1的维度后,张量的形状变为(2, 3)
。同时,使用squeeze(dim)
函数仅移除特定维度也可以达到相同效果。
我们注意到给出的张量在第二个维度上不为零,不经让人产生疑问,我用squeeze(2)
会报错吗?不妨动手一试
x_squeezed_dim_2 = x.squeeze(2)print(x_squeezed_dim_2)print(x_squeezed_dim_2.shape)
tensor([[[1, 2, 3]],
[[4, 5, 6]]])
torch.Size([2, 1, 3])
结果是代码正常编译了,并没有产生问题。和原本张量一致,张量不会发生压缩。
二、nn.CrossEntropyLoss函数
nn.CrossEntropyLoss()
是PyTorch中一个非常常用的损失函数,用于多分类任务。这个损失函数同时执行了nn.LogSoftmax()
和nn.NLLLoss()
(负对数似然损失)。
请注意,这个损失函数需要两个输入:预测值(logits,未经softmax层处理的输出)和真实标签。对于预测值,输入张量的形状应该是(batch_size, num_classes, ...)
,其中...
表示任意其他尺寸。对于真实标签,输入张量的形状应该是(batch_size, ...)
,标签值应该是0
到num_classes-1
之间的整数。
下例中批量大小为3,类别数量为4
import torchimport torch.nn as nn# 创建一个批量大小为3,类别数量为4的预测值张量(logits)logits = torch.tensor([[2.5, 1.0, 0.5, 1.5],[0.3, 3.2, 2.1, 1.0],[1.2, 2.3, 3.1, 0.7]])# 创建一个对应的真实标签张量labels = torch.tensor([0, 1, 2]) # 第一个样本的真实类别是0,第二个是1,第三个是2# 初始化损失函数criterion = nn.CrossEntropyLoss()# 计算损失loss = criterion(logits, labels)print("Cross entropy loss:", loss.item())
经过运行,我们得到了如下的结果
Cross entropy loss: 0.4313742220401764
Process finished with exit code 0
如果是更高维度的,预测值张量会是什么形式的呢?
我们以语义分割任务为例,假设我们有一个批量大小为2,类别数量为3,图像高度和宽度分别为 4 × 44\times44×4 的预测值张量。在这种情况下,输入张量的形状应该是例如形状为 (batch_size, num_classes, height, width)
,其中 height
和 width
分别表示图像的高度和宽度。这意味着我们需要一个(batch_size, height, width)
的标签向量。这是一个具体的示例,输入的形状为(2, 3, 4, 4)
:
[[[[0.5, 1.0, 1.2, 0.3],[0.2, 0.9, 1.1, 1.4],[1.5, 0.7, 0.6, 0.8],[0.1, 0.4, 0.2, 0.9]],[[1.0, 0.5, 0.7, 1.2],[1.5, 0.3, 0.8, 0.6],[0.2, 1.0, 1.2, 0.5],[1.1, 0.9, 0.7, 0.6]],[[0.7, 1.1, 0.6, 0.9],[0.6, 1.2, 0.5, 0.3],[1.0, 0.8, 1.1, 1.4],[1.2, 0.5, 0.9, 0.8]]],[[[1.1, 0.6, 0.8, 0.5],[0.9, 1.0, 1.2, 1.1],[0.7, 1.5, 0.3, 0.6],[1.3, 0.2, 0.4, 0.9]],[[0.4, 1.2, 0.9, 1.5],[1.6, 0.1, 0.3, 0.7],[0.9, 0.6, 1.4, 1.0],[0.8, 1.1, 0.5, 0.3]],[[0.6, 1.3, 1.0, 0.2],[0.5, 1.7, 0.8, 0.9],[1.2, 0.3, 1.1, 1.5],[0.7, 0.9, 1.0, 1.2]]]]
这是一个随机输出和随机目标张量的示例:
import torchimport torch.nn as nn# 假设 logits 是我们的模型预测的输出logits = torch.randn(2, 3, 4, 4)# 模拟输入张量# 假设 targets 是我们的真实标签targets = torch.randint(0, 3, (2, 4, 4))# 随机生成一个目标张量# 使用 nn.CrossEntropyLoss 计算损失criterion = nn.CrossEntropyLoss()loss = criterion(logits, targets)print(loss)
备注:在Pycharm中想要格式化代码,可以使用快捷键
Windows/Linux:Ctrl+Alt+L
但是,上面的模拟输入张量和随机目标张量都是随机的,为了更具说服力(其实是为了水文章长度),我们就用上面的 2 × 3 × 4 × 42\times3\times4\times42×3×4×4 维张量来试验一下。我们可以随时调整真实标签的值,来观察loss = criterion(logits, targets)
的值是增大还是减小。
tensor(1.2113)
Process finished with exit code 0
现在我们修改一下模型的预测输出结果 logits
。可以看到输出的 loss
值明显降低,说明预测值更加符合实际标签。
# 假设 logits 是我们的模型预测的输出logits = torch.tensor([[[[0, 0, 0, 0],[0.2, 0.9, 1.1, 1.4],[1.5, 0.7, 0.6, 0.8],[0.1, 0.4, 0.2, 0.9]],[[5, 5, 5, 5],[0, 0, 0, 0],[0.2, 1.0, 1.2, 0.5],[1.1, 0.9, 0.7, 0.6]],[[0.7, 1.1, 0.6, 0.9],[5, 5, 5, 5],[1.0, 0.8, 1.1, 1.4],[1.2, 0.5, 0.9, 0.8]]],[[[1.1, 0.6, 0.8, 0.5],[0.9, 1.0, 1.2, 1.1],[0.7, 1.5, 0.3, 0.6],[1.3, 0.2, 0.4, 0.9]],[[0.4, 1.2, 0.9, 1.5],[1.6, 0.1, 0.3, 0.7],[0.9, 0.6, 1.4, 1.0],[0.8, 1.1, 0.5, 0.3]],[[0.6, 1.3, 1.0, 0.2],[0.5, 1.7, 0.8, 0.9],[1.2, 0.3, 1.1, 1.5],[0.7, 0.9, 1.0, 1.2]]]])# 模拟输入张量# 假设 targets 是我们的真实标签targets = torch.tensor([[[1, 1, 1, 1],[2, 2, 2, 2],[0, 0, 0, 0],[0, 0, 0, 0]],[[1, 1, 1, 1],[2, 2, 2, 2],[0, 0, 0, 0],[0, 0, 0, 0]]])# 使用 nn.CrossEntropyLoss 计算损失criterion = nn.CrossEntropyLoss()loss = criterion(logits, targets)print(loss)
tensor(0.9149)
进程已结束,退出代码0
三、isinstance函数
isinstance()
函数是 Python 的内置函数,用于检查一个对象是否是指定类的实例。该函数具有两个参数:
- 第一个参数是要检查的对象。
- 第二个参数是类或类的元组。
函数的返回值是布尔值,如果对象是给定类的实例(或者是元组中任何类的实例),则返回 True
,否则返回 False
。
下面是一些使用 isinstance()
函数的示例:
# 示例 1: 判断变量是否为整数num = 5print(isinstance(num, int))# 输出: True# 示例 2: 判断变量是否为字符串text = "Hello, World!"print(isinstance(text, str))# 输出: True# 示例 3: 判断变量是否为整数或浮点数num2 = 3.14print(isinstance(num2, (int, float)))# 输出: True# 示例 4: 使用自定义类class MyClass:passclass AnotherClass:passobj = MyClass()print(isinstance(obj, MyClass))# 输出: Trueprint(isinstance(obj, AnotherClass))# 输出: False
上面提到第二个参数可以是类的元组,表示或的关系,下面是一个示例:
class Animal:passclass Dog(Animal):passclass Cat(Animal):passclass Car:pass# 创建一个 Dog 对象dog = Dog()# 使用 isinstance() 函数检查 dog 是否是 Dog 或 Cat 类的实例print(isinstance(dog, (Dog, Cat)))# 输出: True# 使用 isinstance() 函数检查 dog 是否是 Animal 或 Car 类的实例print(isinstance(dog, (Animal, Car)))# 输出: True, 因为 Dog 类是 Animal 类的子类
# Early stoppingif epoch - best_epoch > opt.es_patience > 0:print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))break
es_patience
是 Early Stopping 的一种实现方式,其中’es’是early的缩写,’patience’指的是在停止训练之前允许的性能停滞时间。具体来说,es_patience
是一种在训练过程中使用的技术,它基于可以允许的性能停滞时间,在模型的训练过程中始终监测验证集的性能,以便及早停止训练并避免过拟合。
images = images.cuda()# 将图片数据从 CPU 发送到 GPU 上进行处理labels = labels.cuda()# 将标签数据从 CPU 发送到 GPU 上进行处理
if loss == 0 or not torch.isfinite(loss):continue
这行代码通常用于在训练神经网络时,处理梯度下降过程中产生的非数值(NaN)和无穷大(Inf)的情况。
loss
是一个 tensor 类型(张量),记录了当前模型输出与真实标签之间的损失值。在 PyTorch 中,如果 loss
的值为 000 或者不是有限数(即 NaN 或 Inf),则会出现异常,并且程序会中断。
# 创建一个包含 NaN 和 Inf 的张量data = torch.tensor([float('nan'), float('inf')])# 判断张量的元素是否为有限数if torch.isfinite(data).all():# 如果所有元素都是有限数,则进行其他操作print("All elements are finite.")else:# 如果存在非有限数元素,则跳过此操作print("There are infinite or NaN elements.")
四、定义冻结层 freeze_layers
if opt.freeze_layers is not None:assert isinstance(opt.freeze_layers, list), "Required List string"def freeze_layers(m):classname = m.__class__.__name__ for ntl in opt.freeze_layers:if ntl in classname:#可以理解为 "need to freeze layer"for param in m.parameters():param.require_grad = False model.apply(freeze_layers)#将该函数作用于模型上,以实现对特定层的参数进行冻结print('[Info] freeze layers in ', opt.freeze_layers)
以上代码实现了对模型特定层的权重冻结,具体过程如下:
首先进行一个条件判断,如果
opt.freeze_layers
不为None
,则进入到定义函数freeze_layers
的块中。在
freeze_layers
函数中,通过m.__class__.__name__
获取当前遍历的模块m
的类名,并将其与opt.freeze_layers
中的每个字符串进行比较。若classname
包含ntl
,则说明该模块需要被冻结。如果发现有层需要被冻结,则会遍历该层的参数列表,并将各参数的
require_grad
属性设置为 False,防止其在后续训练中被更新。在对所有层都完成操作之后,通过
model.apply(freeze_layers)
将freeze_layers
这个函数作用于模型中的所有层次上,从而实现对特定层的参数进行冻结最后,程序输出一条信息提示,显示哪些层被冻结了。
五、SummaryWriter 基础用法
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter('logs')writer.close()
我新建一个文件,尝试运行上述代码结果发生了下面的报错:
TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
我在命令行中以管理员身份运行下述代码,修改了 protobuf 的版本,成功解决了这个报错。(注意安装新版本之前一定要卸载干净旧的版本,否则会有未知的错误)
pip install protobuf==3.19.0
但是出现了新的报错如下:
AttributeError: module ‘tensorflow’ has no attribute ‘io’
根据提示,我们打开 event_file_writer.py
文件,修改代码为
from tensorboard.compat import tensorflow_stub as tf
回到最开始的文件,此时编译就可以正常通过了,结果如下
Connected to pydev debugger (build 222.3345.131)
Process finished with exit code 0
此时 logs
文件夹已经生成,但是如果我们想要看到 tensorflow 可视化工具还会出现一些问题。我们在 PowerShell 界面尝试输入下面的代码,但是会报错。出现的错误原因是
tensorboard --logdir=logs
tensorboard ValueError: Duplicate plugins for name projector
这是因为我们曾经安装过 tensorflow,但是因为 Python 的版本控制问题,他会有一些安装包没有卸载干净。我们必须要删除掉一些遗留的文件夹才能解决掉这个问题。后来我安装了 anaconda 解决了 environment 这一困难,当然这是后话了。当然还有一个小 bug 就是可视化界面必须要在 chrome 内核的浏览器中才能打开。当然,这也是后话了。
六、Python 基础语法
1.变量嵌入到字符串
当需要将变量嵌入到字符串中时,可以使用字符串格式化方法。在Python中,有多种实现这种方法的方式,下面是一些例子:
- 使用百分号:
name = "Alice"age = 25message = "My name is %s, and I'm %d years old." % (name, age)print(message)# 输出:"My name is Alice, and I'm 25 years old."
- 使用format()函数:
name = "Bob"weight = 68.5height = 1.75message = "Hello, my name is {} and my weight is {:.1f} kg. My height is {:.2f} m.".format(name, weight, height)`在这里插入代码片`print(message)# 输出:"Hello, my name is Bob and my weight is 68.5 kg. My height is 1.75 m."
- 使用f-string:
x = 3y = 4result = f'{x} + {y} = {x+y}'print(result)# 输出: "3 + 4 = 7"
2. enumerate() 函数
- 打印列表中的元素及其对应下标
fruits = ['banana', 'apple','mango']for index, fruit in enumerate(fruits):print(index, fruit)
0 banana
1 apple
2 mango
Process finished with exit code 0
- 将列表转化为字典,其中字典的
key
是列表元素的下标,value
是列表元素本身
fruits = ['banana', 'apple','mango']d = {index: fruit for index, fruit in enumerate(fruits)}print(d)
{0: ‘banana’, 1: ‘apple’, 2: ‘mango’}
Process finished with exit code 0
- 枚举字符串中的字符
word = 'hello'for i, char in enumerate(word): print(i, char)
0 h
1 e
2 l
3 l
4 o
3. 进度条库tqdm
tqdm 是 Python 中的一个进度条库,它可以让我们在循环体内添加一个进度条,以便在程序运行时实时显示循环进度,并可随时停止、暂停、恢复进度条等操作。
from tqdm import tqdmimport time# 定义一个包含 10000 个元素的列表l = list(range(10000))# 使用 tqdm 显示循环进度for i in tqdm(l):# 模拟耗时操作time.sleep(0.001)
tqdm 源自阿拉伯语 taqaddum (تقدّم) ,意思是进程 (“progress”)
4. 字典(dict)展开为关键字参数(keyword arguments)
在 Python 中,使用两个星号 **
可以将一个字典(dict)展开为关键字参数(keyword arguments)。这意味着,如果我们有一个包含若干个关键字参数的字典 params,我们可以通过在函数调用时使用双星号来将这些参数传递到函数中。例如:
def some_function(a, b, c):print(f"a={a}, b={b}, c={c}")params = {"a": 1, "b": 2, "c": 3}some_function(**params)# 等价于 some_function(a=1, b=2, c=3)
a=1, b=2, c=3
将一个字典展开为关键字参数时,字典中的键(key)必须和定义函数时的关键词参数名一致。只有这样,Python 才能正确地将字典中的值(value)分配给相应的关键词参数。
值得注意的是,如果在字典中缺少任何一个关键词参数,或者字典中存在多余的关键词参数,则会引发 TypeError 异常。我们将代码进行如下修改:
def some_function(a, b, c):print(f"a={a}, b={b}, c={c}")params = {"a": 1, "b": 2, "c": 3,"d":5}
TypeError: some_function() got an unexpected keyword argument ‘d’
5. assert 断言操作
def add_numbers(x, y):assert isinstance(x, int) and isinstance(y, int), "x and y must be integers."return x + yprint(add_numbers(2, 3))# Output: 5print(add_numbers('Hello', 3))# AssertionError: x and y must be integers.
AssertionError: x and y must be integers.
5
在以上示例中,第一行计算了 2 和 3 的和,输出结果 5,符合预期。而第二行在调用 add_numbers
函数时,将一个字符串 "Hello"
和整数 3 作为函数参数传入,因此此时 assert 语句判断失败,抛出异常并打印出错误信息 "x and y must be integers."
。
6. __class__.__name__获取对象类名
m.__class__.__name__
是 Python 中一种获取对象类型的方式。在示例代码中,它是用来获取当前遍历到的模块 m
的类名。
具体来说,在 Python 中,任何一个对象都有一个类(或类型),可以使用 type()
或者对象的 __class__
属性来获取它们的类型/类。例如,以下代码创建了两个对象并打印它们的类型:
a = 1b = "hello"print(type(a)) # print(b.__class__) #
因为调用 type()
得到结果的标准格式不便于直接作为字符串进行处理,所以常常使用 __class__.__name__
来获取对象类型的名称。__name__
是指该类型名称,而 __class__
则表示该类本身。例如,在上面示例代码中,使用 __class__.__name__
可以将结果转化为字符串类型的对象名称:
a = 1b = "hello"print(a.__class__.__name__) # 'int'print(b.__class__.__name__) # 'str'
int
str
类似地,对于 PyTorch 中的 nn
模块,也可以使用 __class__.__name__
获取模块的类名。比如下面的代码:
import torch.nn as nnlinear_layer = nn.Linear(10, 5)# 创建一层线性变换conv_layer = nn.Conv2d(3, 16, (3,3), padding=1)# 创建一层卷积变换print(linear_layer.__class__.__name__)# 输出:Linearprint(conv_layer.__class__.__name__)# 输出:Conv2d
Linear
Conv2d
在代码中,我们使用 nn.Linear 和 nn.Conv2d 分别创建了两种不同的神经网络层。linear_layer 对象被初始化为 nn.Linear(10, 5)
,因此,linear_layer.__class__
是 nn.Linear 类型,使用 __class__.__name__
获取其类名为 ‘Linear’。
对于 conv_layer,也是类似的过程。因此,conv_layer.__class__.__name__
会返回 ‘Conv2d’ 字符串表示它是一个卷积层。
7. all() 方法判断字符是不是都非零
all
方法在Python中用来判断一个数组是不是都是非零的。下面是例子:
# 定义一个包含零和正数元素的张量x = torch.tensor([1, 2, 0, 4, 5])# 判断张量中的所有元素是否都非零if x.all():print("All elements are nonzero.")else:print("There are zero elements.")
There are zero elements.
附录
import argparseimport datetimeimport osimport tracebackimport numpy as npimport torchfrom torch import nnfrom torch.utils.data import DataLoaderfrom torchvision import transformsfrom tensorboardX import SummaryWriterfrom tqdm import tqdmfrom dataset import USCISIDatasetfrom net import BusterNetfrom utils import CustomDataParalleldef get_args():parser = argparse.ArgumentParser('Buster Net')parser.add_argument('-n', '--num_workers', type=int, default=16, help='num_workers of dataloader')parser.add_argument('-b', '--batch_size', type=int, default=4,help='The number of images per batch among all devices')parser.add_argument('--num_gpus', type=int, default=1,help='The number of gpus') # Multi gpus not spport yet.parser.add_argument('--freeze_layers', nargs='*', default=None,help='freeze layers with strategy')parser.add_argument('--lr', type=float, default=1e-2)parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, ' 'suggest using \'adamw\' or \'adam\' until the' ' very final stage then switch to \'sgd\'')parser.add_argument('--num_epochs', type=int, default=500)parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')parser.add_argument('--es_min_delta', type=float, default=0.0,help='Early stopping\'s parameter: minimum change loss to qualify as an improvement')parser.add_argument('--es_patience', type=int, default=0,help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.')parser.add_argument('--lmdb_dir', type=str, default='./datasets/USCISI-CMFD', help='the root folder of dataset')parser.add_argument('--log_path', type=str, default='./logs/')parser.add_argument('-w', '--load_weights', type=str, default=None,help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')parser.add_argument('--saved_path', type=str, default='logs/')args = parser.parse_args()return argsclass ModelWithLoss(nn.Module):def __init__(self, model, train_simi=True, train_mani=True, train_fusion=True, debug=False):super().__init__()self.ce_criterion = nn.CrossEntropyLoss()self.bce_criterion = nn.BCELoss()self.model = modelself.train_simi = train_simiself.train_mani = train_mani self.train_fusion = train_fusionself.debug = debugdef forward(self, imgs, gts):fusion_preds, mani_preds, simi_preds = self.model(imgs)simi_gts = (1 - gts[:, 2, :, :]).type(torch.float)mani_gts = gts[:, 0, :, :].type(torch.float)_, fusion_gts = gts.max(dim=1)loss = torch.zeros(3)if self.train_fusion:fusion_loss = self.ce_criterion(fusion_preds, fusion_gts)loss[0] = fusion_lossif self.train_mani:mani_preds = mani_preds.squeeze(1)mani_loss = self.bce_criterion(mani_preds, mani_gts)loss[1] = mani_lossif self.train_simi:simi_preds = simi_preds.squeeze(1)simi_loss = self.bce_criterion(simi_preds, simi_gts)#ground truth segmentation 真值分割loss[2] = simi_lossreturn lossdef train(opt):train_file = 'train.keys'val_file = 'valid.keys'# Train similarity network or manipulation network independently or the whole network.train_simi=Truetrain_mani=Truetrain_fusion=True# According to the papers, set input_size default to 256.input_size = 256train_transform = transforms.Compose([transforms.ToPILImage(),transforms.Resize((input_size, input_size)),# transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),])val_transform = transforms.Compose([transforms.ToPILImage(),transforms.Resize((input_size, input_size)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),])target_transform = transforms.Compose([transforms.ToPILImage(),transforms.Resize((input_size, input_size)),transforms.ToTensor(),])train_set = USCISIDataset(opt.lmdb_dir, train_file, train_transform, target_transform)val_set = USCISIDataset(opt.lmdb_dir, val_file, val_transform, target_transform)training_params = {'batch_size': opt.batch_size, 'shuffle': True, 'drop_last': True,#'collate_fn': collater, 'num_workers': opt.num_workers}val_params = {'batch_size': opt.batch_size,'shuffle': False,'drop_last': True,# 'collate_fn': collater,'num_workers': opt.num_workers}training_generator = DataLoader(train_set, **training_params)val_generator = DataLoader(val_set, **val_params)model = BusterNet(image_size=input_size)if opt.load_weights is not None:try:# Load pretrain VGG16 in https://download.pytorch.org/models/vgg16-397923af.pth or continuing trainingif 'vgg16_bn' in opt.load_weights:vgg_backbone = torch.load(opt.load_weights)model.manipulation_net.load_state_dict(vgg_backbone, strict=False)model.similarity_net.load_state_dict(vgg_backbone, strict=False)else:model.load_state_dict(torch.load(opt.load_weights), strict=False)except RuntimeError as e:print(f'[Warning] Ignoring {e}')print(f'[Info] loaded weights: {os.path.basename(opt.load_weights)}')else:print('[Info] initializing weights...')# init_weights(model)if opt.freeze_layers is not None:assert isinstance(opt.freeze_layers, list), "Required List string"def freeze_layers(m):classname = m.__class__.__name__ for ntl in opt.freeze_layers:if ntl in classname:for param in m.parameters():param.require_grad = False model.apply(freeze_layers)print('[Info] freeze layers in ', opt.freeze_layers)# warp the model with loss function, to reduce the memory usage on gpu0 and speedupmodel = ModelWithLoss(model, train_simi=train_simi, train_mani=train_mani, train_fusion=train_fusion)if opt.num_gpus > 1 and opt.batch_size // opt.num_gpus < 4:model.apply(replace_w_sync_bn)use_sync_bn = Trueelse:use_sync_bn = Falseos.makedirs(opt.saved_path, exist_ok=True)writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')if opt.num_gpus > 0:model = model.cuda()if opt.num_gpus > 1:model = CustomDataParallel(model, opt.num_gpus)if use_sync_bn:patch_replication_callback(model)if opt.optim == 'adamw':optimizer = torch.optim.AdamW(model.parameters(), opt.lr)elif opt.optim == 'adam':optimizer = torch.optim.Adam(model.parameters(), opt.lr)else:optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)last_step = 0epoch = 0best_loss = 1e5best_epoch = 0step = max(0, last_step)model.train()num_iter_per_epoch = len(training_generator)try:for epoch in range(opt.num_epochs):epoch_loss = []progress_bar = tqdm(training_generator)for iter, data in enumerate(progress_bar):last_epoch = step // num_iter_per_epochif iter < step - last_epoch * num_iter_per_epoch:progress_bar.update()continuetry:imgs, gts, _ = dataif opt.num_gpus == 1:# if only one gpu, just send it to cuda:0# elif multiple gpus, send it to multiple gpus in CustomDataParallel, not hereimgs = imgs.cuda()gts = gts.cuda()optimizer.zero_grad()fusion_loss, mani_loss, simi_loss = model(imgs, gts)fusion_loss = fusion_loss.mean()simi_loss = simi_loss.mean()mani_loss = mani_loss.mean()loss = fusion_loss + mani_loss + simi_lossif loss == 0 or not torch.isfinite(loss):continueloss.backward()# torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)optimizer.step()epoch_loss.append(float(loss))progress_bar.set_description('Step: {}. Epoch: {}/{}. Iteration: {}/{}. Fusion loss: {:.5f}. Mani loss: {:.5f}. Mini loss: {:.5f} Total loss: {:.5f}'.format(step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, fusion_loss.item(),mani_loss.item(), simi_loss.item(), loss.item()))writer.add_scalar('Loss', loss, step)writer.add_scalar('fusion_loss', fusion_loss, step)writer.add_scalar('simi_loss', simi_loss, step)writer.add_scalar('mani_loss', mani_loss, step)# log learning_ratecurrent_lr = optimizer.param_groups[0]['lr']writer.add_scalar('learning_rate', current_lr, step)step += 1if step % opt.save_interval == 0 and step > 0:save_checkpoint(model, f'model_{epoch}_{step}.pth')print('checkpoint...')except Exception as e:print('[Error]', traceback.format_exc())print(e)continuescheduler.step(np.mean(epoch_loss))if epoch % opt.val_interval == 0:model.eval()loss_fusion_ls = []loss_simi_ls = []loss_mani_ls = []for iter, data in enumerate(val_generator):with torch.no_grad():imgs, gts, _ = dataif opt.num_gpus == 1:imgs = imgs.cuda()gts = gts.cuda()fusion_loss, mani_loss, simi_loss = model(imgs, gts)fusion_loss = fusion_loss.mean()simi_loss = simi_loss.mean()mani_loss = mani_loss.mean()loss = fusion_loss + mani_loss + simi_lossif loss == 0 or not torch.isfinite(loss):continueloss_fusion_ls.append(fusion_loss.item())loss_simi_ls.append(simi_loss.item())loss_mani_ls.append(mani_loss.item())fusion_loss = np.mean(loss_fusion_ls)simi_loss = np.mean(loss_simi_ls)mani_loss = np.mean(loss_mani_ls)loss = fusion_loss + simi_loss + mani_lossprint('Val. Epoch: {}/{}. Fusion loss: {:1.5f}. Simi loss: {:1.5f}. Mani loss: {:1.5f}. Total loss: {:1.5f}'.format(epoch, opt.num_epochs, fusion_loss, simi_loss, mani_loss, loss))writer.add_scalar('Val_Loss', loss, step)writer.add_scalar('Val_Fusion_loss', fusion_loss, step)writer.add_scalar('Val_Simi_loss', simi_loss, step)writer.add_scalar('Val_Mani_loss', mani_loss, step)if loss + opt.es_min_delta < best_loss:best_loss = lossbest_epoch = epochsave_checkpoint(model, f'model_{epoch}_{step}.pth')model.train()# Early stoppingif epoch - best_epoch > opt.es_patience > 0:print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))breakexcept KeyboardInterrupt:save_checkpoint(model, f'model_{epoch}_{step}.pth')writer.close()writer.close()def save_checkpoint(model, name):if isinstance(model, CustomDataParallel):torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))else:torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))if __name__ == '__main__':opt = get_args()train(opt)