常用类
这里总结一些频繁用到的支持类。
from dataclasses import dataclassfrom ..utils import BaseOutputfrom collections import OrderedDictclass BaseOutput(OrderedDict):...@dataclassclass Unet2DOutput(BaseOutput):"""The output of [`Unet2DModel`].Args:sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):The hidden states output from the last layer of the model."""sample: torch.FloatTensor
BaseOutput
继承自OrderedDict
,可以记住数据插入的顺序。BaseOutput
这个类是所有模型输出的基类。models\unet_2d.py
中就定义了Unet2DOutput
做为该模型的输出类。且还用了dataclass
装饰符,表明这个类只承载数据输出的作用。
from .modeling_utils import ModelMixinfrom ..configuration_utils import ConfigMixin, register_to_configclass Unet2DModel(ModelMixin, ConfigMixin):"""A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped otuput."""...
unet
Unet2DModel
主体由down_blocks, mid_blocks, up_blocks三块组成。输入除了sample,还有time_embedding和label_embedding。