A 2023 Beginner's Guide to Deep Learning (26): Running the Tongyi Qianwen (Qwen) 7B Model on Your Own Computer
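With 4-bit quantization, the Qwen-7B model shrinks to 5.86 GB, small enough to run on consumer GPUs with less than 16 GB of memory, such as the RTX 3060.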
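As a rough sanity check on the 5.86 GB figure, here is a back-of-the-envelope estimate. It assumes Qwen-7B's published config values (hidden_size 4096, 32 layers, each MLP branch 11008 wide, vocabulary 151,936), none of which are given in this article. It also assumes that only the Linear layers inside the Transformer blocks are stored in 4 bits, while the embedding and lm_head stay in 16-bit. Quantization scales, biases and norms are ignored.

```python
# Back-of-the-envelope weight-memory estimate for a 4-bit Qwen-7B.
# ASSUMPTIONS (not from this article): hidden_size=4096, 32 layers,
# MLP branch width 11008 (= intermediate_size 22016 // 2), vocab 151936,
# block Linear layers in 4 bits, wte and lm_head kept in 16 bits.

HIDDEN = 4096
LAYERS = 32
FFN = 11008
VOCAB = 151936


def block_linear_params(hidden: int, ffn: int) -> int:
    """Weight count of the Linear layers in one QWenBlock."""
    c_attn = hidden * 3 * hidden            # fused q/k/v projection
    attn_proj = hidden * hidden             # attention output projection
    mlp = 2 * hidden * ffn + ffn * hidden   # w1, w2 and the MLP c_proj
    return c_attn + attn_proj + mlp


quantized_params = LAYERS * block_linear_params(HIDDEN, FFN)
fp16_params = 2 * VOCAB * HIDDEN            # wte + lm_head (separate matrices)

quantized_bytes = quantized_params * 0.5    # 4 bits = 0.5 byte per weight
fp16_bytes = fp16_params * 2                # 16 bits = 2 bytes per weight
total_gb = (quantized_bytes + fp16_bytes) / 1e9

print(f"4-bit blocks: {quantized_bytes / 1e9:.2f} GB")
print(f"16-bit embeddings + lm_head: {fp16_bytes / 1e9:.2f} GB")
print(f"total: {total_gb:.2f} GB")
```

Under these assumptions the total comes out at roughly 5.7 GB, in the same ballpark as 5.86 GB. The remainder would plausibly be quantization metadata and other small tensors, but that breakdown is not verified here.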
Running Qwen-7B with quantization
Qwen-7B ships a ready-made 4-bit quantized model, Qwen/Qwen-7B-Chat-Int4, so we can simply load it.
First, install the dependencies:
```shell
pip install transformers==4.32.0
pip install accelerate
pip install tiktoken
pip install einops
pip install transformers_stream_generator==0.0.4
pip install scipy
pip install auto-gptq optimum
```
If you are on Linux, you can also install Flash-Attention to speed things up:
```shell
git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
cd flash-attention && pip install .
```
It does not work on Windows yet, and this step is optional anyway.
Now we can write the code that calls Qwen-7B:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat-Int4", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat-Int4",
    device_map="auto",
    trust_remote_code=True,
).eval()

# prompt: "Generate C++ code that reverses a string"
response, history = model.chat(tokenizer, "生成用C++将字符串倒序的代码", history=None)
print(response)
```
The generated output is:
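Here is sample C++ code that reverses a string:

```cpp
#include
#include

int main() {
    std::string str = "Hello, World!";
    std::string reversedStr = str;
    std::reverse(reversedStr.begin(), reversedStr.end());
    std::cout << reversedStr << std::endl;
    return 0;
}
```

First, we define a variable `str` that holds a string. Then we define an empty string variable `reversedStr` to store the reversed string. Next, we use the `std::reverse()` function to reverse the characters in `str`. This function takes an iterator range as its arguments, indicating the sequence of characters to reverse. Here we use `str.begin()` and `str.end()` to get the start and end iterators of the string and pass them to `std::reverse()`. Finally, we print the reversed string.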
I ran this successfully on an RTX 3060 GPU.
Next, let's continue walking through the Qwen-7B source code.
Qwen-7B's feed-forward network
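Qwen-7B's feed-forward layer is not just a plain MLP with a SiLU activation: it is a gated variant, commonly called SwiGLU. The input goes through two parallel linear projections, w1 and w2, each config.intermediate_size // 2 wide. The w2 branch passes through SiLU and multiplies the w1 branch element-wise, and c_proj projects the result back to hidden_size.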
```python
import torch.nn as nn
import torch.nn.functional as F


class QWenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
        self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
        ff_dim_in = config.intermediate_size // 2
        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)

    def forward(self, hidden_states):
        a1 = self.w1(hidden_states)
        a2 = self.w2(hidden_states)
        # gating: the SiLU branch scales the other branch element-wise
        intermediate_parallel = a1 * F.silu(a2)
        output = self.c_proj(intermediate_parallel)
        return output
```
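To make the gating concrete, here is a plain-Python sketch of QWenMLP.forward with tiny hand-picked weights. The sizes and values are made up for illustration, and biases are omitted as if no_bias were true.

```python
import math
from typing import List

Matrix = List[List[float]]


def silu(x: float) -> float:
    # SiLU: x * sigmoid(x), with sigmoid written via tanh for numerical stability
    return x * 0.5 * (1.0 + math.tanh(0.5 * x))


def linear(weight: Matrix, x: List[float]) -> List[float]:
    # weight has shape (out_features, in_features); no bias term
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def gated_mlp(x: List[float], w1: Matrix, w2: Matrix, c_proj: Matrix) -> List[float]:
    # Same data flow as QWenMLP.forward: a1 * silu(a2), then project back
    a1 = linear(w1, x)
    a2 = linear(w2, x)
    intermediate = [p * silu(q) for p, q in zip(a1, a2)]
    return linear(c_proj, intermediate)


# Hypothetical toy sizes: hidden_size = 2, intermediate_size // 2 = 3
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w2 = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
c_proj = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]

out = gated_mlp([1.0, 2.0], w1, w2, c_proj)
print(out)
```

If the w2 branch outputs 0 for some channel, SiLU(0) = 0 shuts that channel off entirely, which is what makes this a gate rather than a plain activation.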
SiLU (Sigmoid Linear Unit) is an activation function used in neural networks. It coincides with Swish (with β = 1), which researchers at Google Brain proposed in 2017. It is a smooth non-linear function.
SiLU is defined as:
f(x) = x * sigmoid(x)
where sigmoid is the logistic function:
sigmoid(x) = 1 / (1 + exp(-x))
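SiLU has the following properties:
- For large positive x it approaches the identity, like ReLU; for moderate positive x it is slightly smaller than x, because sigmoid(x) < 1.
- For large negative x it approaches 0, again like ReLU, but for moderately negative inputs it outputs small negative values instead of exactly 0.
- It is smooth (differentiable everywhere, unlike ReLU at 0), which keeps gradient computation during backpropagation better behaved.
- It is not monotonic: it reaches its global minimum of about −0.28 at x ≈ −1.28. It has been argued that this minimum acts as an implicit regularizer that inhibits learning overly large weights.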
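The properties above are easy to check numerically. This sketch evaluates SiLU at a few points and finds the minimum with a simple grid search; nothing here comes from Qwen's code.

```python
import math


def sigmoid(x: float) -> float:
    # numerically stable logistic function: 1 / (1 + exp(-x))
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def silu(x: float) -> float:
    return x * sigmoid(x)


# Large positive input: close to the identity, like ReLU
print(silu(10.0))
# Large negative input: a tiny negative number, close to 0
print(silu(-10.0))

# Coarse grid search for the global minimum over [-5, 0]
xs = [i / 10000 for i in range(-50000, 1)]
x_min = min(xs, key=silu)
print(round(x_min, 2), round(silu(x_min), 3))  # roughly -1.28 and -0.278
```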
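Transformer block
Next, RMSNorm, QWenAttention and QWenMLP are assembled into QWenBlock, much like the TransformerBlock in LLaMA. It is a pre-norm block: each sub-layer normalizes its input first, and its output is added back through a residual connection.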
```python
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

# RMSNorm, QWenAttention and QWenMLP are defined elsewhere in the model file.


class QWenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.bf16 = config.bf16

        self.ln_1 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.attn = QWenAttention(config)
        self.ln_2 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.mlp = QWenMLP(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
        registered_causal_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        # pre-norm attention sub-layer
        layernorm_output = self.ln_1(hidden_states)

        attn_outputs = self.attn(
            layernorm_output,
            rotary_pos_emb,
            registered_causal_mask=registered_causal_mask,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]

        outputs = attn_outputs[1:]

        # residual connection around attention
        residual = hidden_states
        layernorm_input = attn_output + residual

        # pre-norm MLP sub-layer with its own residual connection
        layernorm_output = self.ln_2(layernorm_input)
        residual = layernorm_input
        mlp_output = self.mlp(layernorm_output)
        hidden_states = residual + mlp_output

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs
```
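This module mostly passes its arguments through to QWenAttention, which we covered in the previous section:
- hidden_states: the output of the previous layer. Despite the Tuple type hint it is a single tensor in practice, with shape (batch_size, sequence_length, hidden_size).
- rotary_pos_emb: a list [cos, sin] produced by RotaryEmbedding. Each tensor has shape (1, sequence_length, 1, rotary_dim), so it broadcasts over the batch and head dimensions.
- registered_causal_mask: the boolean lower-triangular buffer registered in QWenModel, with shape (1, 1, max_positions, max_positions). It stops each position from attending to later positions, and it is None when the flash-attention path is used.
- layer_past: this layer's cached (key, value) pair from earlier generation steps, used by the KV cache to speed up generation. In this version of the code each tensor has shape (batch_size, past_sequence_length, num_heads, head_dim).
- attention_mask: an optional float mask that hides invalid or padding positions. QWenModel turns the (batch_size, sequence_length) input into an additive mask of shape (batch_size, 1, 1, sequence_length): 0 for kept positions and a very large negative value for masked ones.
- head_mask: an optional mask that nullifies selected attention heads (the standard Hugging Face argument, used e.g. in head-pruning experiments); it does not randomly drop heads. At the model level it has shape (num_heads,) or (num_hidden_layers, num_heads), and each block receives its own layer's slice.
- encoder_hidden_states: part of the standard encoder-decoder interface, with shape (batch_size, encoder_sequence_length, hidden_size). Qwen is decoder-only, and QWenBlock accepts this argument without passing it on to the attention.
- encoder_attention_mask: likewise kept only for interface compatibility and unused in this block.
- use_cache: whether to return the updated key/value cache.
- output_attentions: whether to also return the attention weights.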
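Without caching, masks and tensor shapes, the data flow of QWenBlock.forward is the usual pre-norm residual pattern. The plain-Python sketch below treats each sub-layer as a function on a vector; all names are illustrative, not Qwen's.

```python
from typing import Callable, List

Vector = List[float]
SubLayer = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def pre_norm_block(x: Vector, ln_1: SubLayer, attn: SubLayer,
                   ln_2: SubLayer, mlp: SubLayer) -> Vector:
    # normalize -> attention -> residual add
    h = add(x, attn(ln_1(x)))
    # normalize -> MLP -> residual add
    return add(h, mlp(ln_2(h)))


def identity(v: Vector) -> Vector:
    return v


def zero(v: Vector) -> Vector:
    return [0.0] * len(v)


# With sub-layers that output zeros, the residual path passes x through unchanged
print(pre_norm_block([1.0, 2.0], identity, zero, identity, zero))  # [1.0, 2.0]
```

Because the residual path never goes through a normalization, the input can flow straight to the output; this is one reason pre-norm stacks tend to train stably at depth.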
RMSNorm
We have covered RMSNorm several times already, so we won't go into it again here:
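```python
import torch
import torch.nn as nn

# rms_norm is a module-level global: flash-attn's fused kernel when it can be
# imported, otherwise None, in which case the PyTorch path below is used
rms_norm = None


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        if rms_norm is not None and x.is_cuda:
            return rms_norm(x, self.weight, self.eps)
        else:
            # normalize in float32 for stability, then cast back
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
```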
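For readers who skipped the earlier installments, here is the non-CUDA path of RMSNorm applied to a single vector in plain Python. Unlike LayerNorm it does not subtract the mean; it only divides by the root mean square and then scales by the learned weight.

```python
import math
from typing import List, Optional


def rms_norm(x: List[float], weight: Optional[List[float]] = None,
             eps: float = 1e-6) -> List[float]:
    # Same formula as RMSNorm._norm: x * rsqrt(mean(x^2) + eps), then * weight
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    if weight is None:
        weight = [1.0] * len(x)  # matches torch.ones(dim) at initialization
    return [v * inv_rms * w for v, w in zip(x, weight)]


y = rms_norm([3.0, 4.0])
print(y)  # approx [0.8485, 1.1314]
```

With the initial all-ones weight, the output always has a root mean square of (almost exactly) 1.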
Position encoding
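Remember the einsum we met in the Baichuan model code? Qwen's code uses this Einstein-style notation again, this time through the einops library.
With einops, dimension manipulations become much more readable: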
```python
from einops import rearrange

emb = rearrange(emb, "n d -> 1 n 1 d")
```
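rearrange reorders tensor dimensions according to a string pattern.
Here, "n d -> 1 n 1 d" means:
- start from shape (n, d)
- rearrange it into shape (1, n, 1, d)
In other words, it inserts one size-1 axis before the n axis and another between n and d, the same as emb[None, :, None, :]. The result broadcasts over the batch and head dimensions of tensors laid out as (batch, seq_len, num_heads, head_dim).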
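Otherwise it follows the usual approach of precomputing and caching the cos and sin tables. update_rotary_pos_emb_cache rebuilds the cache only when the requested length exceeds the cached length or ntk_alpha changes. With NTK scaling the base becomes base * ntk_alpha ** (dim / (dim - 2)), and the cache covers max(2 * seqlen, 16) positions: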
```python
import importlib.util

import torch


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")

        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
        seqlen = max_seq_len + offset
        # rebuild only when the cache is too short or the NTK factor changed
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            # NTK-aware scaling of the rotary base
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
            # duplicate the frequencies so the last dim equals dim (rotate-half layout)
            emb = torch.cat((freqs, freqs), dim=-1)
            from einops import rearrange

            emb = rearrange(emb, "n d -> 1 n 1 d")

            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]
```
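The NTK rule in update_rotary_pos_emb_cache rescales the base rather than the positions. A quick check in plain Python, with dim = 128 as an assumed rotary dimension, shows what the exponent dim / (dim - 2) buys: the highest frequency stays at 1, while the lowest frequency is divided by exactly ntk_alpha, so the longest wavelength stretches by that factor.

```python
from typing import List


def inv_freq(dim: int, base: float = 10000.0) -> List[float]:
    # Same values as 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def ntk_scaled_base(base: float, dim: int, ntk_alpha: float) -> float:
    # Same rule as update_rotary_pos_emb_cache
    return base * ntk_alpha ** (dim / (dim - 2))


dim = 128  # assumed rotary dimension, for illustration only
freqs = inv_freq(dim)
scaled = inv_freq(dim, ntk_scaled_base(10000.0, dim, ntk_alpha=3.0))

print(len(freqs))             # one frequency per pair of channels
print(freqs[0], scaled[0])    # highest frequency: unchanged
print(freqs[-1], scaled[-1])  # lowest frequency: divided by ntk_alpha
```

Algebraically, the lowest frequency uses exponent (dim - 2) / dim, so the scaled base contributes exactly one factor of ntk_alpha there, while the exponent 0 at the highest frequency leaves it untouched.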
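Qwen-7B's rotation helper is also implemented with einops.rearrange. It splits the last dimension into two halves, x1 and x2, and returns cat(-x2, x1):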
```python
import torch


def _rotate_half(x):
    from einops import rearrange

    # split the last dim into two halves: (..., 2*d) -> (..., 2, d)
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
```
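On a single vector, _rotate_half boils down to the list operation below. The pattern "(j d)" with j=2 puts j on the outside, so the split is into a first half and a second half, not into interleaved even and odd elements.

```python
from typing import List


def rotate_half(x: List[float]) -> List[float]:
    # x1 = first half, x2 = second half; result is cat(-x2, x1)
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1


print(rotate_half([1.0, 2.0, 3.0, 4.0]))  # [-3.0, -4.0, 1.0, 2.0]
```

Applying it twice negates the vector, just as a 90-degree rotation applied twice is a 180-degree rotation.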
Finally, apply_rotary_pos_emb applies the rotary position embedding to an input tensor t:
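```python
import torch

# apply_rotary_emb_func is a module-level global: flash-attn's fused rotary
# kernel when it can be imported, otherwise None


def apply_rotary_pos_emb(t, freqs):
    cos, sin = freqs
    if apply_rotary_emb_func is not None and t.is_cuda:
        t_ = t.float()
        # (1, seq_len, 1, rot_dim) -> (seq_len, rot_dim // 2) for the fused kernel
        cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
        sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
        return output
    else:
        rot_dim = freqs[0].shape[-1]
        cos, sin = freqs
        # only the first rot_dim channels are rotated; the rest pass through
        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
        t_ = t_.float()
        t_pass_ = t_pass_.float()
        t_ = (t_ * cos) + (_rotate_half(t_) * sin)
        return torch.cat((t_, t_pass_), dim=-1).type_as(t)
```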
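The main steps of apply_rotary_pos_emb:
- Unpack the cos and sin tables from freqs.
- If t is on CUDA and a fused apply_rotary_emb_func (from flash-attn) is available, call it directly for an optimized rotation.
- Otherwise, apply the rotation manually:
  - Split t into the part to rotate, t_ (the first rot_dim channels), and the pass-through part t_pass_.
  - Compute the rotated t_ = t_ * cos + rotate_half(t_) * sin.
  - Concatenate the rotated t_ with the untouched t_pass_.
  - Return the concatenated result, cast back to t's dtype.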
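Putting RotaryEmbedding, _rotate_half and the fallback branch together, the plain-Python sketch below rotates a single vector. There is no batching, and the whole vector is rotated (rot_dim equals its length). It checks the property that motivates RoPE: the query-key dot product depends only on the offset between the two positions.

```python
import math
from typing import List


def rotate_half(x: List[float]) -> List[float]:
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    # Fallback branch: t * cos + rotate_half(t) * sin, with the angle table
    # built like RotaryEmbedding: emb = cat(freqs, freqs)
    dim = len(x)
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    angles = [pos * f for f in inv_freq] * 2  # list repetition = concatenation
    rotated = rotate_half(x)
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x, rotated, angles)]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


q = [0.1, 0.2, 0.3, 0.4]
k = [0.5, -0.1, 0.2, 0.3]
# Both pairs have the same offset (3), so the scores match
s1 = dot(rope(q, 5), rope(k, 2))
s2 = dot(rope(q, 13), rope(k, 10))
print(abs(s1 - s2) < 1e-9)  # True
```

Each pair of channels (i, i + dim/2) is rotated by its own angle, so the vector's length is preserved and position 0 leaves it unchanged.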
So the optimized kernel is used when it is available, and a pure PyTorch implementation of the rotary embedding otherwise.
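Rotary position embedding (RoPE) encodes position by rotating the query and key vectors through position-dependent angles, so the dot product between a query at position m and a key at position n depends only on the offset m − n. Attention therefore sees relative positions even though each token is rotated by its absolute position, and the RoPE paper also argues that the attention between tokens tends to decay as their distance grows. Combined with the NTK-style base scaling above, this is what lets Qwen stretch its context length at inference time.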
Qwen's Transformer model
```python
class QWenModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size

        self.gradient_checkpointing = False
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length

        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

        self.drop = nn.Dropout(config.emb_dropout_prob)

        # rotary_pct < 1 means only part of each head's channels are rotated
        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
        dim = (
            self.rotary_ndims
            if self.rotary_ndims is not None
            else config.kv_channels
        )
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

        self.use_flash_attn = config.use_flash_attn
        self.is_fp32 = not (config.bf16 or config.fp16)
        if (
            self.use_flash_attn
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
        ):
            # the flash-attention path does not use this explicit mask buffer
            self.registered_causal_mask = None
        else:
            max_positions = config.max_position_embeddings
            # lower-triangular boolean mask, shape (1, 1, max_positions, max_positions)
            self.register_buffer(
                "registered_causal_mask",
                torch.tril(
                    torch.ones((max_positions, max_positions), dtype=torch.bool)
                ).view(1, 1, max_positions, max_positions),
                persistent=False,
            )

        self.h = nn.ModuleList(
            [
                QWenBlock(config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.post_init()
```
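The initialization again mostly assembles the modules introduced earlier: the token embedding wte, dropout, the rotary embedding, the causal-mask buffer, a stack of num_hidden_layers QWenBlocks, and the final RMSNorm ln_f.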
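The registered_causal_mask buffer is simply a lower-triangular matrix of booleans. A plain-Python sketch of the same pattern for a small size:

```python
from typing import List


def causal_mask(n: int) -> List[List[bool]]:
    # Same pattern as torch.tril(torch.ones((n, n), dtype=torch.bool)):
    # query position i may attend to key positions j <= i
    return [[j <= i for j in range(n)] for i in range(n)]


for row in causal_mask(4):
    print("".join("1" if allowed else "0" for allowed in row))
# 1000
# 1100
# 1110
# 1111
```

In QWenModel it is built once for max_position_embeddings and viewed as (1, 1, max_positions, max_positions). When flash-attention is used with fp16/bf16 the buffer is skipped, presumably because that path applies causal masking itself.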
The forward below is long, but it is mostly routine bookkeeping and argument checking:
```python
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # cached keys are (batch, seq_len, num_heads, head_dim),
            # so the cached sequence length is dimension 1
            past_length = past_key_values[0][0].size(1)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            # 1 -> 0 (keep), 0 -> most negative value (mask out)
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        encoder_attention_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        kv_seq_len = hidden_states.size()[1]
        if past_key_values[0] is not None:
            # past key values[0][0] shape: bs * seq_len * head_num * dim
            kv_seq_len += past_key_values[0][0].shape[1]
        # dynamic NTK: choose ntk_alpha from the prompt length on the first (uncached) step
        if (
            self.use_dynamic_ntk
            and kv_seq_len == hidden_states.size()[1]
            and not self.training
        ):
            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
            ntk_alpha = 2 ** math.ceil(context_value) - 1
            ntk_alpha = max(ntk_alpha, 1)
        else:
            ntk_alpha = self.rotary_emb._ntk_alpha_cached

        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
        for idx in range(len(rotary_pos_emb)):
            rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    rotary_pos_emb,
                    self.registered_causal_mask,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    rotary_pos_emb=rotary_pos_emb,
                    registered_causal_mask=self.registered_causal_mask,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, presents, all_hidden_states] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
```
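This is a standard decoder-only (causal) Transformer stack with three main parts: input processing, a loop over the QWenBlock layers, and output post-processing with the final RMSNorm ln_f. It uses RMSNorm, multi-head self-attention with rotary embeddings, and residual connections, and it supports KV caching, gradient checkpointing, attention masks and dynamic NTK scaling.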
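The dynamic NTK branch in forward picks ntk_alpha from the prompt length on the first inference step, when there is no cache yet. Here is the same arithmetic in plain Python, assuming seq_length = 2048 purely for illustration:

```python
import math


def dynamic_ntk_alpha(kv_seq_len: int, seq_length: int) -> int:
    # Same arithmetic as QWenModel.forward when use_dynamic_ntk is on
    context_value = math.log(kv_seq_len / seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)


for n in (1000, 2048, 3000, 6000):
    print(n, dynamic_ntk_alpha(n, 2048))
```

ntk_alpha stays at 1 up to seq_length, becomes 3 for prompts up to twice that length, 7 up to four times, and so on. Lengths that are exact power-of-two multiples of seq_length sit right on a boundary, where floating-point rounding in the logarithm decides the result.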
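Pretrained model base class
Next is QWenModel's base class, QWenPreTrainedModel. It inherits from PreTrainedModel and configures weight initialization, gradient checkpointing and how the model may be split across devices: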
```python
class QWenPreTrainedModel(PreTrainedModel):
    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    # keep each QWenBlock on a single device when splitting with device_map
    _no_split_modules = ["QWenBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

        # GPT-2 style: shrink the output projections that feed the residual stream
        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                p.data.normal_(
                    mean=0.0,
                    std=(
                        self.config.initializer_range
                        / math.sqrt(2 * self.config.num_hidden_layers)
                    ),
                )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value
```
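The special case for c_proj.weight (the output projections of both the attention and the MLP) shrinks their initial standard deviation by sqrt(2 * num_hidden_layers), since each block adds two branches to the residual stream. Assuming initializer_range = 0.02 and 32 layers, values that are typical but not given in this article, the arithmetic is:

```python
import math


def c_proj_init_std(initializer_range: float, num_hidden_layers: int) -> float:
    # Same rule as _init_weights for parameters named "c_proj.weight"
    return initializer_range / math.sqrt(2 * num_hidden_layers)


print(c_proj_init_std(0.02, 32))  # about 0.0025, i.e. 0.02 / 8
```

The other Linear layers keep std = initializer_range, so these projections start out 8 times smaller in this configuration.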
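Language-model wrapper
QWenModel returns a BaseModelOutputWithPast. To turn it into a language model, QWenLMHeadModel adds an lm_head projection onto the vocabulary, computes the next-token loss when labels are given, and wraps the result in a CausalLMOutputWithPast.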
```python
class QWenLMHeadModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

        # if no precision was requested, pick the fastest one the device supports
        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

        if autoset_precision:
            if SUPPORT_BF16:
                logger.warn(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warn(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

        if config.use_flash_attn == "auto":
            if config.bf16 or config.fp16:
                logger.warn("Try importing flash-attention for faster inference...")
                config.use_flash_attn = True
            else:
                config.use_flash_attn = False
        if config.use_flash_attn and config.fp32:
            logger.warn("Flash attention will be disabled because it does NOT support fp32.")

        if config.use_flash_attn:
            _import_flash_attn()

        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        token_type_ids = kwargs.get("token_type_ids", None)
        # with a KV cache, only the newest token has to be fed in
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # positions count only non-padding tokens
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # next-token prediction: logits at position t are scored against label t+1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
```
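Two small pieces of this class are easier to follow on plain lists: the shifted next-token loss in forward, and the position ids that prepare_inputs_for_generation derives from a left-padded attention mask. The sketch below ignores CrossEntropyLoss's ignore_index (-100) handling and batching. In the QWenModel.forward shown earlier, position_ids is only reshaped and not passed on to the blocks, since the rotary tables are derived from kv_seq_len.

```python
import math
from typing import List


def causal_lm_loss(logits: List[List[float]], labels: List[int]) -> float:
    # Position t predicts token t+1: drop the last logit row and the first label,
    # then average the cross-entropy over the remaining positions
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    total = 0.0
    for row, target in zip(shift_logits, shift_labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(shift_labels)


def position_ids_from_mask(mask: List[int]) -> List[int]:
    # cumsum(-1) - 1, then masked_fill_(attention_mask == 0, 1)
    out, running = [], 0
    for m in mask:
        running += m
        out.append(running - 1 if m != 0 else 1)
    return out


print(position_ids_from_mask([0, 0, 1, 1, 1]))  # [1, 1, 0, 1, 2]
```

With uniform logits over a vocabulary of size V the loss is exactly log V, and the real tokens of a left-padded sequence get positions 0, 1, 2, ... regardless of how much padding precedes them.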
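Besides forward, the language model also wraps a generate method. It mainly handles configuration, turning stop_words_ids into a StopWordsLogitsProcessor, and then calls the parent class's generate: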
```python
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        generation_config = generation_config if generation_config is not None else self.generation_config

        # Process stop_words_ids.
        stop_words_ids = kwargs.pop("stop_words_ids", None)
        if stop_words_ids is None and generation_config is not None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
        if stop_words_ids is None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)

        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)

        return super().generate(
            inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            **kwargs,
        )
```
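StopWordsLogitsProcessor is defined elsewhere in Qwen's code and is not shown here. Judging by its name and arguments, its job is to end generation once the output ends with one of the stop-word token sequences. The following is a hypothetical, simplified processor over plain Python lists that forces EOS in that case; it is not the actual class, which works on batched tensors and may adjust the scores differently.

```python
from typing import List, Sequence


def ends_with_stop_word(generated: Sequence[int], stop_words_ids: List[List[int]]) -> bool:
    # True when the generated ids end with any of the stop-word id sequences
    for stop in stop_words_ids:
        if stop and len(generated) >= len(stop) and list(generated[-len(stop):]) == stop:
            return True
    return False


def force_eos(scores: List[float], eos_token_id: int) -> List[float]:
    # Leave only EOS selectable so the next step must emit it
    return [0.0 if i == eos_token_id else float("-inf") for i in range(len(scores))]


def process(generated: Sequence[int], scores: List[float],
            stop_words_ids: List[List[int]], eos_token_id: int) -> List[float]:
    # Hypothetical simplified logits processor, NOT Qwen's StopWordsLogitsProcessor
    if ends_with_stop_word(generated, stop_words_ids):
        return force_eos(scores, eos_token_id)
    return scores
```

Plugging something like this into a LogitsProcessorList, as generate does above, means the stop check runs at every decoding step without changing the parent class's generation loop.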
Chat wrapper
```python
    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        append_history: bool = True,
        stream: Optional[bool] = _SENTINEL,
        stop_words_ids: Optional[List[List[int]]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        generation_config = generation_config if generation_config is not None else self.generation_config

        # passing stream to chat() is rejected; chat_stream() exists for streaming
        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            generation_config=generation_config,
            **kwargs,
        )

        response = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=generation_config.chat_format,
            verbose=False,
            errors='replace'
        )

        if append_history:
            history.append((query, response))

        return response, history
```
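Its main flow is:
- Initialize history and stop_words_ids if they are None, and take max_window_size from kwargs or from the generation config.
- Call make_context to build the ChatML prompt from system, history and query, getting both the raw text and its token ids.
- Add the chat format's stop words (from get_stop_words_ids) to stop_words_ids.
- Turn the tokens into input_ids and call generate.
- Decode the output with decode_tokens; the raw_text_len and context_length arguments let it skip the prompt part.
- If append_history is True, append (query, response) to history, then return the response and the history.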
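make_context itself is not shown in this article. As a hedged sketch of the prompt layout it builds for chat_format == 'chatml', assuming the standard ChatML markers <|im_start|> and <|im_end|> joined by newlines, the hypothetical helper below only assembles the text. The real function also tokenizes the prompt and, judging by its max_window_size parameter, limits how much history is kept.

```python
from typing import List, Tuple

IM_START, IM_END = "<|im_start|>", "<|im_end|>"


def build_chatml_text(query: str, history: List[Tuple[str, str]],
                      system: str = "You are a helpful assistant.") -> str:
    # Hypothetical helper: prompt text only, no tokenization or truncation
    parts = [f"{IM_START}system\n{system}{IM_END}"]
    for old_query, old_response in history:
        parts.append(f"{IM_START}user\n{old_query}{IM_END}")
        parts.append(f"{IM_START}assistant\n{old_response}{IM_END}")
    parts.append(f"{IM_START}user\n{query}{IM_END}")
    parts.append(f"{IM_START}assistant\n")  # the model continues from here
    return "\n".join(parts)


print(build_chatml_text("Hi", history=[]))
```

This layout also explains the stop words: generation should end when the model emits the end-of-turn marker instead of starting a new turn on its own.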
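Streaming chat wrapper
Finally, a wrapper that returns the output as a stream.
Its main flow:
- As in chat, first process the query and assemble the context with make_context.
- Collect the stop words into stop_words_ids.
- Wrap the stop words in a StopWordsLogitsProcessor.
- Convert the context tokens into input_ids for the model.
- The key step: attach NewGenerationMixin.generate and NewGenerationMixin.sample_stream from transformers_stream_generator to the class as generate_stream and sample_stream, then call generate_stream with a StreamGenerationConfig that has do_stream=True. It generates the sequence one token at a time and yields each token.
- In a for loop inside stream_generator, append each token to outputs and decode the accumulated list with tokenizer.decode.
- Each yield returns the full text decoded so far, not just the newest piece.
- The result is a generator from which the caller can keep reading the model's output.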
```python
    def chat_stream(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stop_words_ids: Optional[List[List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Generator[str, Any, None]:
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)
        input_ids = torch.tensor([context_tokens]).to(self.device)

        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
        # attach the streaming versions of generate/sample to this class
        self.__class__.generate_stream = NewGenerationMixin.generate
        self.__class__.sample_stream = NewGenerationMixin.sample_stream
        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

        def stream_generator():
            outputs = []
            for token in self.generate_stream(
                input_ids,
                return_dict_in_generate=False,
                generation_config=stream_config,
                logits_processor=logits_processor,
                seed=-1,
                **kwargs):
                outputs.append(token.item())
                # yields the whole response decoded so far, not only the new token
                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')

        return stream_generator()
```
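Because stream_generator decodes the whole outputs list every time, each value it yields is the full response so far. A caller that wants to print only the new text can diff consecutive values. In the sketch below, a hard-coded list stands in for model.chat_stream(tokenizer, query, history=None), and deltas is a hypothetical helper, not part of Qwen's code.

```python
from typing import Iterable, Iterator


def deltas(cumulative: Iterable[str]) -> Iterator[str]:
    # Turn "full text so far" snapshots into just the newly added pieces
    printed = ""
    for text in cumulative:
        if text.startswith(printed):
            yield text[len(printed):]
        else:
            # if an earlier part of the text ever changes, re-emit it on a new line
            yield "\n" + text
        printed = text


# Stand-in for: model.chat_stream(tokenizer, query, history=None)
fake_stream = ["Hel", "Hello", "Hello, wor", "Hello, world"]
print("".join(deltas(fake_stream)))  # Hello, world
```

In a terminal one would typically print each delta with end="" and flush=True so the text appears as it is generated.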
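Summary
In this installment we have finally finished walking through the code of the Qwen-7B model. Any source-code walkthrough runs into a lot of details, and not all of them are worth agonizing over, but the original code still conveys what the model actually does more precisely than any paraphrase.
Later we will abstract the model implementation and explain it more systematically so that it is easier for beginners to follow. For those of you already working in the field, these details are exactly what you will be dealing with, so it is worth getting familiar with them now.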