chatglm6b和闻达的功能扩展

最近大火的chatgpt,老板说让我看看能不能用自己的数据,回答专业一些,所以做了一些调研,最近用这个倒是成功推理了自己的数据,模型也开源了,之后有机会也训练一下自己的数据。

chatglm6b和闻达的功能扩展

  • 1.本机部署
    • 1.1环境部署
    • 1.2 配置参数
    • 1.3. 推理
  • 2.云服务器部署
  • 3.项目需求
    • 3.1 修改前端的名字
    • 3.2 不同用户用不同的知识库
      • 3.2.1修改生成不同目录的知识库文件
      • 3.2.2 不同用户用不同知识库
      • 3.2.3效果
      • 3.2.4一个txt或pdf自动生成一个独立的知识库
      • 3.2.5返回score值最低的知识库prompt
    • 3.3 ptuning微调
      • 3.3.1chatglm的ptuning
      • 3.3.2闻达的ptuning
    • 3.4做socket接口
    • 3.5用langchain与sql数据库交互
    • 3.6自定义template与sql数据库交互并用flask前端展示
    • 3.7自定义自己的tool和chain

1.本机部署

1.1环境部署

1.1双击打开anconda prompt创建虚拟环境

Conda create –n chatglm python#(创建名叫chatglm的虚拟python环境)Conda activate chatglm#(激活环境)

1.2下载pytorch(这里要根据自己的电脑版本下载)都在虚拟环境里操作

nvidia-smi#(查看自己cuda版本)pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118#(下载符合自己配置的torch,可以在官网https://pytorch.org/查看命令)

图片[1] - chatglm6b和闻达的功能扩展 - MaxSSL
1.3在官网https://download.pytorch.org/whl/torch_stable.html下载对应的cuda版本的torch和torchvision,然后pip install即可
这时gpu版的torch就下载成功:,验证方法如图:
图片[2] - chatglm6b和闻达的功能扩展 - MaxSSL
1.4安装依赖库

cd C:\Users\dz\Desktop\AIGC\wenda\wd-git\wenda\requirements#(进入工具包的simple目录下)pip install –r .\requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simplepip install protobuf flatbuffers termcolor#(根据提示下载需要的包和自己的模型requirements.txt文件)

1.2 配置参数

  1. 配模型:下载对应的模型权重文件,放到model文件夹下面,这里我用的是RWKV:
    图片[3] - chatglm6b和闻达的功能扩展 - MaxSSL
  2. 配数据:自己的文本数据放到txt文件夹下面:
    图片[4] - chatglm6b和闻达的功能扩展 - MaxSSL

3.配环境:在environment里面把环境配成自己刚刚创建的虚拟环境
图片[5] - chatglm6b和闻达的功能扩展 - MaxSSL

在config里面把权重文件的地址和配置改成自己的

图片[6] - chatglm6b和闻达的功能扩展 - MaxSSL

1.3. 推理

  1. 双击step.2本地数据库建库.bat建本地数据库
    图片[7] - chatglm6b和闻达的功能扩展 - MaxSSL
  2. 双击run_rwkv-点击运行.bat运行这个模型,然后浏览器打开http://127.0.0.1:17860/
    首先测试是否检测到本地数据库

问答功能

2.云服务器部署

电脑跑起来不行,所以在云服务器上搞了一个,本来是git源码的,但是源码git下来运行有问题,所以我还是把本地文件放到自己仓库,重新git了一下,云服务器租环境,就租wenda环境,然后

git clone https://github.com/Turing-dz/wenda_zoe_test.git

修改example.config.xml文件里的模型地址,然后就可以推理自己的数据了。

python pluges/gen_data_st.py#运行本地数据库python wenda.py -t glm6b -p 6006#云上规定用6006映射

然后打开链接,打开知识库按钮,就会推理自己的数据文件了。

3.项目需求

3.1 修改前端的名字

修改views/static/string.js里面的常量值就可以。

3.2 不同用户用不同的知识库

这个其实是一个安全问题,但代码修改起来也很简单,分两步,一个是生成不同的知识库,下一步就是调用不同的知识库。

3.2.1修改生成不同目录的知识库文件

1.修改example.config.yml,当用户没有给-u参数时,默认txt下的文件生成到memory的default1文件夹下。

user_Type: default1

图片[8] - chatglm6b和闻达的功能扩展 - MaxSSL
2.修改common.py文件,设置用户输入-u参数,如果没输入就用上面设置的默认default1

parser.add_argument('-u', type=str, dest="user_to_knowledge", help="不同用户的本地知识库")user_Type = str(args.user_to_knowledge) ifuser_Type != 'None':settings.user_Type=user_Type

图片[9] - chatglm6b和闻达的功能扩展 - MaxSSL
图片[10] - chatglm6b和闻达的功能扩展 - MaxSSL
3.修改gen_data_st.py文件,这个文件是生成知识库的,所以要修改生成地址

add_knowledge='memory/'+settings.user_Typetry:vectorstore_old = FAISS.load_local(add_knowledge, embeddings=embeddings)success_print("合并至已有索引。如不需合并请删除 add_knowledge 文件夹")vectorstore_old.merge_from(vectorstore)vectorstore_old.save_local(add_knowledge)

图片[11] - chatglm6b和闻达的功能扩展 - MaxSSL
图片[12] - chatglm6b和闻达的功能扩展 - MaxSSL

3.2.2 不同用户用不同知识库

修改zhishiku_rtst.py文件

def find(s,step = 0,memory_name=settings.user_Type): 

图片[13] - chatglm6b和闻达的功能扩展 - MaxSSL

3.2.3效果

python '/root/autodl-fs/wenda_zoe_test/plugins/gen_data_st.py' -u u2python '/root/autodl-fs/wenda_zoe_test/wenda.py' -u u2 -t glm6b -p 6006
python '/root/autodl-fs/wenda_zoe_test/plugins/gen_data_st.py' -u u5python '/root/autodl-fs/wenda_zoe_test/wenda.py' -u u5 -t glm6b -p 6006

3.2.4一个txt或pdf自动生成一个独立的知识库

天哥需要一个文件生成一个知识库。这个就更简单了,修改gen_data_st.py文件,

#add_knowledge='memory/'+settings.user_Type#这个是上次的-u功能,可以先注释#下面两段代码加到for循环里,并把地下的代码都右移一位,加到for循环里面add_knowledge='memory/'+fileadd_knowledge=add_knowledge.split(".")[0]

图片[14] - chatglm6b和闻达的功能扩展 - MaxSSL
但在后面需要返回score最大文章的content时,发现了bug,上面改完之后每次生成下一个文件的知识库时都会把之前的包括了,所以如果数据要独立,还得在all_files的循环开始加上

docs=[]vectorstore = None

最好把下面的合并索引也删掉。所以改完的gen_data_st .py如下:

import argparseimport sentence_transformersfrom langchain.text_splitter import CharacterTextSplitterfrom langchain.vectorstores.faiss import FAISSfrom langchain.embeddings import HuggingFaceEmbeddingsfrom langchain.docstore.document import Documentimport threadingimport pdfplumberimport reimport chardetimport osimport sysimport timeos.chdir(sys.path[0][:-8])from common import success_printfrom common import error_helperfrom common import settingsfrom common import CounterLocksource_folder = 'txt'source_folder_path = os.path.join(os.getcwd(), source_folder)#add_knowledge='memory/'+settings.user_Typeimport logginglogging.basicConfig()logger = logging.getLogger()logger.setLevel(logging.ERROR)root_path_list = source_folder_path.split(os.sep)docs = []vectorstore = Nonemodel_path = settings.librarys.rtst.model_pathtry:embeddings = HuggingFaceEmbeddings(model_name='')embeddings.client = sentence_transformers.SentenceTransformer(model_path, device="cuda")except Exception as e:error_helper("embedding加载失败,请下载相应模型", r"https://github.com/l15y/wenda#st%E6%A8%A1%E5%BC%8F")raise esuccess_print("Embedding 加载完成")embedding_lock=CounterLock()vectorstore_lock=threading.Lock()def clac_embedding(texts, embeddings, metadatas):global vectorstorewith embedding_lock:vectorstore_new = FAISS.from_texts(texts, embeddings, metadatas=metadatas)with vectorstore_lock:if vectorstore is None:vectorstore = vectorstore_newelse:vectorstore.merge_from(vectorstore_new)def make_index():global docsif hasattr(settings.librarys.rtst,"size") and hasattr(settings.librarys.rtst,"overlap"):text_splitter = CharacterTextSplitter(chunk_size=int(settings.librarys.rtst.size), chunk_overlap=int(settings.librarys.rtst.overlap), separator='\n')else:text_splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator='\n')doc_texts = text_splitter.split_documents(docs)docs = []texts = [d.page_content for d in doc_texts]metadatas = [d.metadata for d in doc_texts]thread = threading.Thread(target=clac_embedding, args=(texts, embeddings, metadatas))thread.start()while embedding_lock.get_waiting_threads()>2:time.sleep(0.1)all_files=[]for root, dirs, files in os.walk(source_folder_path):for file in files:all_files.append([root, file])success_print("文件列表生成完成",len(all_files))for i in range(len(all_files)):root, file=all_files[i]length_of_read=0docs=[]vectorstore = Nonedata = ""title = ""try:if file.endswith(".pdf"):file_path = os.path.join(root, file)with pdfplumber.open(file_path) as pdf:data_list = []for page in pdf.pages:print(page.extract_text())data_list.append(page.extract_text())data = "\n".join(data_list)else:# txtfile_path = os.path.join(root, file)with open(file_path, 'rb') as f:b = f.read()result = chardet.detect(b)with open(file_path, 'r', encoding=result['encoding']) as f:data = f.read()add_knowledge='memory/'+fileadd_knowledge=add_knowledge.split(".")[0]except Exception as e:print("文件读取失败,当前文件已被跳过:",file,"。错误信息:",e)data = re.sub(r'!', "!\n", data)data = re.sub(r':', ":\n", data)data = re.sub(r'。', "。\n", data)data = re.sub(r'\r', "\n", data)data = re.sub(r'\n\n', "\n", data)data = re.sub(r"\n\s*\n", "\n", data)length_of_read+=len(data)docs.append(Document(page_content=data, metadata={"source": file}))if length_of_read > 1e5:success_print("处理进度",int(100*i/len(all_files)),f"%\t({i}/{len(all_files)})")make_index()# print(embedding_lock.get_waiting_threads())length_of_read=0if len(all_files) == 0:#error_print("txt 目录没有数据")print("txt 目录没有数据")sys.exit(0)if len(docs) > 0:make_index()while embedding_lock.get_waiting_threads()>0:time.sleep(0.1)with embedding_lock:time.sleep(0.1)with vectorstore_lock:success_print("处理完成")# try:# vectorstore_old = FAISS.load_local(# add_knowledge, embeddings=embeddings)# success_print("合并至已有索引。如不需合并请删除 add_knowledge 文件夹")# vectorstore_old.merge_from(vectorstore)# vectorstore_old.save_local(add_knowledge)# except:# print("新建索引")vectorstore.save_local(add_knowledge)success_print("保存完成")

3.2.5返回score值最低的知识库prompt

需要遍历生成的知识库,所以在zhishiku_rtst.py里面加上

source_folder = 'memory'memory_name_list=[]source_folder_path = os.path.join(os.getcwd(), source_folder)for root, dirs, files in os.walk(source_folder_path):for dir in dirs:memory_name_list.append(dir)

然后在find函数里遍历,并计算score值,score越大距离越远,所以要最小的prompt,所以zhishiku_rtst.py文件如下:

from langchain.vectorstores.faiss import FAISSfrom langchain.embeddings import HuggingFaceEmbeddingsimport sentence_transformersimport numpy as npimport re,osfrom plugins.common import settings,allowCROSfrom plugins.common import error_helper from plugins.common import success_print divider='\n'if not os.path.exists('memory'):os.mkdir('memory')cunnrent_setting=settings.librarys.rtst#print(cunnrent_setting.user_to_knowledge)def get_doc_by_id(id,memory_name):return vectorstores[memory_name].docstore.search(vectorstores[memory_name].index_to_docstore_id[id])def process_strings(A, C, B):# find the longest common suffix of A and prefix of Bcommon = ""for i in range(1, min(len(A), len(B)) + 1):if A[-i:] == B[:i]:common = A[-i:]# if there is a common substring, replace one of them with C and concatenateif common:return A[:-len(common)] + C + B# otherwise, just return A + Belse:return A + Bdef get_doc(id,score,step,memory_name):doc = get_doc_by_id(id,memory_name)final_content=doc.page_contentprint("文段分数:",score,[doc.page_content])# print(id,score,step,memory_name,doc)if step > 0:for i in range(1, step+1):try:doc_before=get_doc_by_id(id-i,memory_name)if doc_before.metadata['source']==doc.metadata['source']:final_content=process_strings(doc_before.page_content,divider,final_content)# print("上文分数:",score,doc.page_content)except:passtry:doc_after=get_doc_by_id(id+i,memory_name)if doc_after.metadata['source']==doc.metadata['source']:final_content=process_strings(final_content,divider,doc_after.page_content)except:passif doc.metadata['source'].endswith(".pdf") or doc.metadata['source'].endswith(".txt"):title=f"[{doc.metadata['source']}](/api/read_news/{doc.metadata['source']})"else:title=doc.metadata['source']return {'title': title,'content':re.sub(r'\n+', "\n", final_content),"score":int(score)}source_folder = 'memory'memory_name_list=[]source_folder_path = os.path.join(os.getcwd(), source_folder)for root, dirs, files in os.walk(source_folder_path):for dir in dirs:memory_name_list.append(dir)success_print(memory_name_list)def find(s,step = 0,memory_name="test2"):#"test2",try:scor_min=7000docs_min=[]for memory_name in memory_name_list:docs = []scor=0n=0embedding = get_vectorstore(memory_name).embedding_function(s)scores, indices = vectorstores[memory_name].index.search(np.array([embedding], dtype=np.float32), int(cunnrent_setting.count))#print("scores, indices:",scores, indices)for j, i in enumerate(indices[0]):if i == -1:continueif scores[0][j]>7000:continuedocs.append(get_doc(i,scores[0][j],step,memory_name))scor+=scores[0][j]n+=1if n!=0:scor=scor/nif scor_min>scor:scor_min=scordocs_min=docsdocs=docs_min#print(scor_min)print(docs)return docsexcept Exception as e:print(e)return []try:embeddings = HuggingFaceEmbeddings(model_name='')embeddings.client = sentence_transformers.SentenceTransformer(cunnrent_setting.model_path, device=cunnrent_setting.device)except Exceptionas e:error_helper("embedding加载失败,请下载相应模型",r"https://github.com/l15y/wenda#st%E6%A8%A1%E5%BC%8F")raise evectorstores={}def get_vectorstore(memory_name):try:return vectorstores[memory_name]except Exceptionas e:try:vectorstores[memory_name] = FAISS.load_local('memory/'+memory_name, embeddings=embeddings)return vectorstores[memory_name]except Exceptionas e:success_print("没有读取到RTST记忆区%s,将新建。"%memory_name)return Nonefrom langchain.docstore.document import Documentfrom langchain.text_splitter import CharacterTextSplitterfrom bottle import route, response, request, static_file, hookimport bottle@route('/api/upload_rtst_zhishiku', method=("POST","OPTIONS"))def upload_zhishiku():allowCROS()try:data = request.jsontitle=data.get("title")memory_name=data.get("memory_name")data = re.sub(r'!', "!\n", data.get("txt"))data = re.sub(r'。', "。\n", data)data = re.sub(r'[\n\r]+', "\n", data)docs=[Document(page_content=data, metadata={"source":title })]print(docs)text_splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator='\n')doc_texts = text_splitter.split_documents(docs)texts = [d.page_content for d in doc_texts]metadatas = [d.metadata for d in doc_texts]vectorstore_new = FAISS.from_texts(texts, embeddings, metadatas=metadatas)vectorstore=get_vectorstore(memory_name)if vectorstore is None:vectorstores[memory_name]=vectorstore_newelse:vectorstores[memory_name].merge_from(vectorstore_new)return '成功'except Exception as e:return str(e)@route('/api/save_rtst_zhishiku', method=("POST","OPTIONS"))def save_zhishiku():allowCROS()try:data = request.jsonmemory_name=data.get("memory_name")vectorstores[memory_name].save_local('memory/'+memory_name)#print("保存到了"+'memory/'+memory_name)return "保存成功"except Exception as e:return str(e)import json@route('/api/find_rtst_in_memory', method=("POST","OPTIONS"))def api_find():allowCROS()data = request.jsonprompt = data.get('prompt')step = data.get('step')memory_name=data.get("memory_name")if step is None:step = int(settings.library.step)# for i in rangereturn json.dumps(find(prompt,int(step),memory_name_list))@route('/api/save_news', method=("POST","OPTIONS"))def save_news():allowCROS()try:data = request.jsonif not data:return 'no data'title = data.get('title')txt = data.get('txt')cut_file = f"txt/{title}.txt"with open(cut_file, 'w', encoding='utf-8') as f:f.write(txt)f.close()return 'success'except Exception as e:return(e)@route('/api/read_news/:path', method=("GET","OPTIONS"))def read_news(path=""):allowCROS()return static_file(path, root="txt/")

3.3 ptuning微调

3.3.1chatglm的ptuning

这里首先用官方的工具,生成对话的json数据,然后把autodl-tmp/ChatGLM-6B/ptuning/AdvertiseGen/里面的训练和测试的json数据替换成工具生成的自己的数据;修改autodl-tmp/ChatGLM-6B/ptuning/train.sh里面文件的地址,和数据的column,然后bash train.sh
图片[15] - chatglm6b和闻达的功能扩展 - MaxSSL
训练完后可以运行web_demo.py文件测试效果。

3.3.2闻达的ptuning

我这里是将上面train完的autodl-tmp/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000文件复制到wenda的model/ptuning目录下。
在config.yml的里面glm6b下加入了

ptuning: "autodl-fs/wenda_zoe_test/model/ptuning"

plugins/common.py文件加入参数:

ptuning_addr='model/ptuning'pre_seq_len=128prefix_projection=Falseifptuning_addr != 'None':settings.ptuning_addr=ptuning_addrifpre_seq_len != 'None':settings.pre_seq_len=pre_seq_lenifprefix_projection is not True:settings.prefix_projection=prefix_projection

在plugins/llm_glm6b.py里面改掉模型的加载:

#model = AutoModel.from_pretrained(settings.llm.path, local_files_only=True, trust_remote_code=True)config = AutoConfig.from_pretrained(settings.llm.path, trust_remote_code=True)config.pre_seq_len = settings.pre_seq_lenconfig.prefix_projection = settings.prefix_projectiontokenizer = AutoTokenizer.from_pretrained(settings.llm.path, local_files_only=True, trust_remote_code=True)if settings.ptuning_addr is not None:import torchmodel = AutoModel.from_pretrained(settings.llm.path, config=config,trust_remote_code=True)prefix_state_dict = torch.load(os.path.join(settings.ptuning_addr, "pytorch_model.bin"))new_prefix_state_dict = {}for k, v in prefix_state_dict.items():if k.startswith("transformer.prefix_encoder."):new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = vmodel.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)else:model = AutoModel.from_pretrained(settings.llm.path, config=config,trust_remote_code=True)

然后再运行wenda.py测试自己做的数据集,就会看到ptuning效果。
图片[16] - chatglm6b和闻达的功能扩展 - MaxSSL

3.4做socket接口

1.wenda_server.py

import logginglogging.captureWarnings(True)import torchimport threadingimport osimport jsonimport datetimefrom bottle import route, response, request, static_file, hookimport bottlefrom plugins.common import settings from plugins.common import error_helper,error_print,success_printfrom plugins.common import CounterLock,allowCROS#memory_name='test2'def load_LLM():try:from importlib import import_moduleLLM = import_module('plugins.llm_'+settings.llm_type)return LLMexcept Exception as e:print("LLM模型加载失败,请阅读说明:https://github.com/l15y/wenda", e)LLM = load_LLM()logging=settings.loggingsif logging:from plugins.defineSQL import session_maker, 记录if not hasattr(LLM,"Lock") :mutex = CounterLock()else:mutex = LLM.Lock()model = Nonetokenizer = Nonedef load_model():with mutex:LLM.load_model()torch.cuda.empty_cache()success_print("模型加载完成")thread_load_model = threading.Thread(target=load_model)thread_load_model.start()zhishiku = Nonedef load_zsk():try:from importlib import import_moduleglobal zhishikuimport plugins.zhishiku as zskzhishiku= zsksuccess_print("知识库加载完成")except Exception as e:error_helper("知识库加载失败,请阅读说明",r"https://github.com/l15y/wenda#%E7%9F%A5%E8%AF%86%E5%BA%93")raise ethread_load_zsk = threading.Thread(target=load_zsk)thread_load_zsk.start()import refooter = ''from socket import *IP = '127.0.0.1'PORT = 50000BUFLEN = 512listenSocket = socket(AF_INET, SOCK_STREAM)listenSocket.bind((IP, PORT))listenSocket.listen(8)print(f'服务端启动成功,在{PORT}端口等待客户端连接...')dataSocket, addr = listenSocket.accept()print('接受一个客户端连接:', addr)while True:# response.content_type = "text/event-stream"# response.add_header("Connection", "keep-alive")# response.add_header("Cache-Control", "no-cache")max_length = Noneif max_length is None:max_length = 2048top_p = Noneif top_p is None:top_p = 0.2temperature = Noneif temperature is None:temperature = 0.8use_zhishiku = Noneif use_zhishiku is None:use_zhishiku = Falserecved = dataSocket.recv(BUFLEN)if not recved:breakprompt = recved.decode()keyword=Noneif keyword is None:keyword = prompthistory_formatted = Noneresponse_text = ''IP = request.environ.get('HTTP_X_REAL_IP') or request.environ.get('REMOTE_ADDR')error = ""if use_zhishiku:response_d = zhishiku.find(keyword,int(settings.library.step))output_sources = [i['title'] for i in response_d]results = '\n'.join([str(i+1)+". "+re.sub('\n\n', '\n', response_d[i]['content']) for i in range(len(response_d))])prompt = 'system: 请扮演一名专业分析师,根据以下内容回答问题:'+prompt + "\n"+ resultsif settings.library.show_soucre == True:footer = "\n### 来源:\n"+('\n').join(output_sources)with mutex:try:for response in LLM.chat_one(prompt, history_formatted, max_length, top_p, temperature, zhishiku=use_zhishiku):if (response):response= response+footerexcept Exception as e:error = str(e)error_print("错误", error)response = ''# raise etorch.cuda.empty_cache()if response == '':response= "发生错误,正在重新加载模型"+erroros._exit(0)if logging:with session_maker() as session:jl = 记录(时间=datetime.datetime.now(), IP=IP,=prompt,=response)session.add(jl)session.commit()print(response)dataSocket.send(f'服务端返回信息: {response}'.encode())# yield "/././"dataSocket.close()listenSocket.close()# import webbrowser# webbrowser.open_new('http://127.0.0.1:'+str(settings.Port))# import functools# def pathinfo_adjust_wrapper(func):# # A wrapper for _handle() method# @functools.wraps(func)# def _(s,environ):# environ["PATH_INFO"] = environ["PATH_INFO"].encode("utf8").decode("latin1")# return func(s,environ)# return _# bottle.Bottle._handle = pathinfo_adjust_wrapper(bottle.Bottle._handle)#修复bottle在处理utf8 url时的bug# bottle.run(server='paste', host="0.0.0.0", port=settings.port, quiet=True)

2.client.py

from socket import *IP = '127.0.0.1'SERVER_PORT = 50000BUFLEN = 1024# 实例化一个socket对象,指明协议dataSocket = socket(AF_INET, SOCK_STREAM)# 连接服务端socketdataSocket.connect((IP, SERVER_PORT))while True:# 从终端读入用户输入的字符串toSend = input('>>> ')iftoSend =='exit':break# 发送消息,也要编码为 bytesdataSocket.send(toSend.encode())# 等待接收服务端的消息recved = dataSocket.recv(BUFLEN)# 如果返回空bytes,表示对方关闭了连接if not recved:break# 打印读取的信息print(recved.decode())dataSocket.close()

3.5用langchain与sql数据库交互

用本地glm6b想解决数据库问题,因此就结合langchain来做,因为langchain和glm6b的适配问题,因此对langchain做了一点点处理,如下:
首先运行文件如下,将glm6b引入llm

from langchain.llms.base import LLMfrom transformers import AutoTokenizer, AutoModel, AutoConfigimport sysfrom typing import List,Optionalclass ChatGLM(LLM):max_token: int = 2048temperature: float = 0.8top_p = 0.1tokenizer: object = Nonemodel: object = Nonehistory_len: int = 1024def __init__(self):super().__init__()@propertydef _llm_type(self) -> str:return "GLM"def load_model(self, llm_device="gpu", model_name_or_path=None):model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True).half().cuda()def _call(self, prompt: str, history: List[str] = [], stop: Optional[List[str]] = None):response, _ = self.model.chat(self.tokenizer, prompt,# history=history[-self.history_len:] if self.history_len > 0 else [],max_length=self.max_token, temperature=self.temperature,top_p=self.top_p)return responsemodelpath = r"C:\xxx\Desktop\wenda-main\wenda-main\model\chatglm2-6b"sys.path.append(modelpath)print(modelpath)llm = ChatGLM()llm.load_model(model_name_or_path=modelpath)from langchain import SQLDatabase, SQLDatabaseChaindb = SQLDatabase.from_uri("mysql+pymysql://root:xxx@localhost/xunlian")db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)db_chain.run("司明苏的单个人员成绩是多少?")

因为适配问题,所以需要对langchain/chain/sql_database/base.py文件里面的两次输出进行处理

class SQLDatabaseChain(Chain):def _call:if sql_cmd:sql_cmd = sql_cmd.split(";")sql_cmd = sql_cmd[0]+";"if chain_result['result']:chain_result['result']=chain_result['result'].split("\n")[-1]

处理完后的base文件如下:

"""Chain for interacting with SQL Database."""from __future__ import annotationsimport warningsfrom typing import Any, Dict, List, Optionalfrom pydantic import Extra, Field, root_validatorfrom langchain.callbacks.manager import CallbackManagerForChainRunfrom langchain.chains.base import Chainfrom langchain.chains.llm import LLMChainfrom langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTSfrom langchain.prompts.prompt import PromptTemplatefrom langchain.schema import BasePromptTemplatefrom langchain.schema.language_model import BaseLanguageModelfrom langchain.sql_database import SQLDatabasefrom langchain.tools.sql_database.prompt import QUERY_CHECKERINTERMEDIATE_STEPS_KEY = "intermediate_steps"class SQLDatabaseChain(Chain):"""Chain for interacting with SQL Database.Example:.. code-block:: pythonfrom langchain import SQLDatabaseChain, OpenAI, SQLDatabasedb = SQLDatabase(...)db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)"""llm_chain: LLMChainllm: Optional[BaseLanguageModel] = None"""[Deprecated] LLM wrapper to use."""database: SQLDatabase = Field(exclude=True)"""SQL Database to connect to."""prompt: Optional[BasePromptTemplate] = None"""[Deprecated] Prompt to use to translate natural language to SQL."""top_k: int = 5"""Number of results to return from the query"""input_key: str = "query"#: :meta private:output_key: str = "result"#: :meta private:return_intermediate_steps: bool = False"""Whether or not to return the intermediate steps along with the final answer."""return_direct: bool = False"""Whether or not to return the result of querying the SQL table directly."""use_query_checker: bool = False"""Whether or not the query checker tool should be used to attempt to fix the initial SQL from the LLM."""query_checker_prompt: Optional[BasePromptTemplate] = None"""The prompt template that should be used by the query checker"""class Config:"""Configuration for this pydantic object."""extra = Extra.forbidarbitrary_types_allowed = True@root_validator(pre=True)def raise_deprecation(cls, values: Dict) -> Dict:if "llm" in values:warnings.warn("Directly instantiating an SQLDatabaseChain with an llm is deprecated. ""Please instantiate with llm_chain argument or using the from_llm ""class method.")if "llm_chain" not in values and values["llm"] is not None:database = values["database"]prompt = values.get("prompt") or SQL_PROMPTS.get(database.dialect, PROMPT)values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)return values@propertydef input_keys(self) -> List[str]:"""Return the singular input key.:meta private:"""return [self.input_key]@propertydef output_keys(self) -> List[str]:"""Return the singular output key.:meta private:"""if not self.return_intermediate_steps:return [self.output_key]else:return [self.output_key, INTERMEDIATE_STEPS_KEY]def _call(self,inputs: Dict[str, Any],run_manager: Optional[CallbackManagerForChainRun] = None,) -> Dict[str, Any]:_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()input_text = f"{inputs[self.input_key]}\nSQLQuery:"_run_manager.on_text(input_text, verbose=self.verbose)# If not present, then defaults to None which is all tables.table_names_to_use = inputs.get("table_names_to_use")table_info = self.database.get_table_info(table_names=table_names_to_use)llm_inputs = {"input": input_text,"top_k": str(self.top_k),"dialect": self.database.dialect,"table_info": table_info,"stop": ["\nSQLResult:"],}intermediate_steps: List = []try:intermediate_steps.append(llm_inputs)# input: sql generationsql_cmd = self.llm_chain.predict(callbacks=_run_manager.get_child(),**llm_inputs,).strip()if not self.use_query_checker:if sql_cmd:sql_cmd = sql_cmd.split(";")sql_cmd = sql_cmd[0]+";"_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)intermediate_steps.append(sql_cmd)# output: sql generation (no checker)intermediate_steps.append({"sql_cmd": sql_cmd})# input: sql execresult = self.database.run(sql_cmd)if result:my_sec_prompt="Question: "+input_text+sql_cmd+"\n"+"SQLResult:"+result+"\n"+"Answer:"intermediate_steps.append(str(result))# output: sql execselse:query_checker_prompt = self.query_checker_prompt or PromptTemplate(template=QUERY_CHECKER, input_variables=["query", "dialect"])query_checker_chain = LLMChain(llm=self.llm_chain.llm, prompt=query_checker_prompt)query_checker_inputs = {"query": sql_cmd,"dialect": self.database.dialect,}checked_sql_command: str = query_checker_chain.predict(callbacks=_run_manager.get_child(), **query_checker_inputs).strip()intermediate_steps.append(checked_sql_command)# output: sql generation (checker)_run_manager.on_text(checked_sql_command, color="green", verbose=self.verbose)intermediate_steps.append({"sql_cmd": checked_sql_command})# input: sql execresult = self.database.run(checked_sql_command)intermediate_steps.append(str(result))# output: sql execsql_cmd = checked_sql_command_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)_run_manager.on_text(result, color="yellow", verbose=self.verbose)if self.return_direct:final_result = resultelse:_run_manager.on_text("\nAnswer:", verbose=self.verbose)input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"llm_inputs["input"] = input_textintermediate_steps.append(llm_inputs)# input: final answerfinal_result = self.llm_chain.predict(callbacks=_run_manager.get_child(),**llm_inputs,).strip()final_result=final_result.split("\n")[-1]intermediate_steps.append(final_result)# output: final answer_run_manager.on_text(final_result, color="green", verbose=self.verbose)chain_result: Dict[str, Any] = {self.output_key: final_result}if self.return_intermediate_steps:chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps# chain_result=chain_result.split(";")if chain_result['result']:chain_result['result']=chain_result['result'].split("\n")[-1]return chain_resultexcept Exception as exc:# Append intermediate steps to exception, to aid in logging and later# improvement of few shot prompt seedsexc.intermediate_steps = intermediate_steps# type: ignoreraise exc@propertydef _chain_type(self) -> str:return "sql_database_chain"@classmethoddef from_llm(cls,llm: BaseLanguageModel,db: SQLDatabase,prompt: Optional[BasePromptTemplate] = None,**kwargs: Any,) -> SQLDatabaseChain:prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)llm_chain = LLMChain(llm=llm, prompt=prompt)return cls(llm_chain=llm_chain, database=db, **kwargs)class SQLDatabaseSequentialChain(Chain):"""Chain for querying SQL database that is a sequential chain.The chain is as follows:1. Based on the query, determine which tables to use.2. Based on those tables, call the normal SQL database chain.This is useful in cases where the number of tables in the database is large."""decider_chain: LLMChainsql_chain: SQLDatabaseChaininput_key: str = "query"#: :meta private:output_key: str = "result"#: :meta private:return_intermediate_steps: bool = False@classmethoddef from_llm(cls,llm: BaseLanguageModel,database: SQLDatabase,query_prompt: BasePromptTemplate = PROMPT,decider_prompt: BasePromptTemplate = DECIDER_PROMPT,**kwargs: Any,) -> SQLDatabaseSequentialChain:"""Load the necessary chains."""sql_chain = SQLDatabaseChain.from_llm(llm, database, prompt=query_prompt, **kwargs)decider_chain = LLMChain(llm=llm, prompt=decider_prompt, output_key="table_names")return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)@propertydef input_keys(self) -> List[str]:"""Return the singular input key.:meta private:"""return [self.input_key]@propertydef output_keys(self) -> List[str]:"""Return the singular output key.:meta private:"""if not self.return_intermediate_steps:return [self.output_key]else:return [self.output_key, INTERMEDIATE_STEPS_KEY]def _call(self,inputs: Dict[str, Any],run_manager: Optional[CallbackManagerForChainRun] = None,) -> Dict[str, Any]:_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()_table_names = self.sql_chain.database.get_usable_table_names()table_names = ", ".join(_table_names)llm_inputs = {"query": inputs[self.input_key],"table_names": table_names,}_lowercased_table_names = [name.lower() for name in _table_names]table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)table_names_to_use = [namefor name in table_names_from_chainif name.lower() in _lowercased_table_names]_run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)_run_manager.on_text(str(table_names_to_use), color="yellow", verbose=self.verbose)new_inputs = {self.sql_chain.input_key: inputs[self.input_key],"table_names_to_use": table_names_to_use,}return self.sql_chain(new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True)@propertydef _chain_type(self) -> str:return "sql_database_sequential_chain"

效果如下:
图片[17] - chatglm6b和闻达的功能扩展 - MaxSSL

3.6自定义template与sql数据库交互并用flask前端展示

#1.自定义一个生成sql语句的template,这里需要传入table——info和questionfrom langchain.prompts import PromptTemplatefrom langchain import SQLDatabase, SQLDatabaseChaindb = SQLDatabase.from_uri("mysql+pymysql://root:xxx@localhost/xunlian")db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)table_info=db_chain.database.get_table_info(table_names=None)prompt1 = PromptTemplate.from_template("你是一个SQL专家,现在给你提供一个数据库表单的提示信息:{table_info}\n请根据上述#数据库表单的提示信息,""针对{qusetion},创建一个语法正确的MySQL查询语句,使用LIMIT子句查询最多3个结果,必须将查询语句中的字段使用反引号(`)包括起来,""必须使用数据库表单的提示信息中可见的列名创建MySQL查询语句,查询语句不能出现不存在的列名。请按以下输出示例进行输出:")#2.sql查询#3.将问题,sql语句,查询结果给模型,规定它输出人话prompt2=PromptTemplate.from_template("数据库查询语句是:{first_format}\n数据库查询结果是:{second_format}\n请根据上述查询过程,,回答的内容必须简单明了,必须在30个字以内:{question}")import pymysqlconn = pymysql.connect(host='localhost', user='root', password='xxx', database='xunlian')cursor = conn.cursor()#4.将上面的步骤封装到一个函数conn = pymysql.connect(host='localhost', user='root', password='xxx', database='xunlian')cursor = conn.cursor()table_info=db_chain.database.get_table_info(table_names=None)prompt1 = PromptTemplate.from_template("你是一个SQL专家,现在给你提供一个数据库表单的提示信息:{table_info}\n请根据上述#数据库表单的提示信息,""针对{qusetion},创建一个语法正确的MySQL查询语句,使用LIMIT子句查询最多3个结果,必须将查询语句中的字段使用反引号(`)包括起来,""必须使用数据库表单的提示信息中可见的列名创建MySQL查询语句,查询语句不能出现不存在的列名。请按以下输出示例进行输出:")# print(prompt1.format(table_info=table_info,qusetion="司明苏的单个人员成绩是多少?"))prompt2=PromptTemplate.from_template("数据库查询语句是:{first_format}\n数据库查询结果是:{second_format}\n请根据上述查询过程,,回答的内容必须简单明了,必须在30个字以内:{question}")def my_out(question):first=llm.predict(prompt1.format(table_info=table_info,qusetion=question))first_format=first.split("```")[1][len('sql'):].lstrip()cursor.execute(first_format)second_format= cursor.fetchall()third_format=llm.predict(prompt2.format(first_format=first_format, second_format=second_format,question=question))return third_format# print(my_out("单个人员成绩在70分以上的姓名有谁?"))while True:question = input("请输入一个名词:\n")if question == "结束":breakelse:print(my_out(question))continue

用flask进行交互

#1.index.html<!DOCTYPE html><html><head><title>一问一答</title><style>body {font-family: Arial, sans-serif;text-align: center;}h1 {color: #0080FF;}.container {margin: 50px auto;max-width: 400px;padding: 20px;border: 1px solid #ccc;border-radius: 10px;}.question-input {width: 100%;padding: 10px;margin-bottom: 10px;border: 1px solid #ccc;border-radius: 5px;}.submit-btn {background-color: #0080FF;color: #fff;border: none;padding: 10px 20px;border-radius: 5px;cursor: pointer;}.submit-btn:hover {background-color: #005eff;}.answer {margin-top: 20px;font-weight: bold;}</style></head><body><div class="container"><h1>一问一答</h1><p>请输入您的问题:</p><input type="text" id="question" class="question-input"><button onclick="submitQuestion()" class="submit-btn">提交</button><p class="answer">回答:</p><p id="answer" class="answer"></p></div><script>function submitQuestion() {// 获取用户输入的问题var question = document.getElementById('question').value;// 创建一个FormData对象,用于将数据添加到POST请求中var formData = new FormData();formData.append('question', question);// 发送POST请求fetch('/get_answer', {method: 'POST',body: formData}).then(response => response.text()).then(answer => {// 显示回答document.getElementById('answer').innerText = answer;}).catch(error => {console.error('Error:', error);});}</script></body></html>#2.app.pyfrom flask import Flask, render_template, requestfrom main_testdb import my_outapp = Flask(__name__)@app.route('/')def index():return render_template('index.html')@app.route('/get_answer', methods=['POST'])def get_answer():# 获取前端传递的问题question = request.form['question']# 在这里你可以处理问题并返回相应的答案# 假设你的问题回答逻辑与之前的JavaScript示例相同if question == '你叫什么名字?':answer = '我叫xxx,很高兴为您服务!'elif question == '你会说中文吗?':answer = '是的,我会说中文,还会说多种其他语言。'else:answer = my_out(question)# 返回答案到前端return answerif __name__ == '__main__':app.run(debug=True)

图片[18] - chatglm6b和闻达的功能扩展 - MaxSSL

3.7自定义自己的tool和chain

#1.模型from langchain.llms.base import LLMfrom transformers import AutoTokenizer, AutoModel, AutoConfigimport sysfrom typing import List,Optionalclass ChatGLM(LLM):max_token: int = 8192temperature: float = 0.1top_p = 0.1tokenizer: object = Nonemodel: object = Nonehistory_len: int = 0def __init__(self):super().__init__()@propertydef _llm_type(self) -> str:return "GLM"def load_model(self, llm_device="gpu", model_name_or_path=None):model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True).half().cuda()def _call(self, prompt: str, history: List[str] = [], stop: Optional[List[str]] = None):response, _ = self.model.chat(self.tokenizer, prompt,max_length=self.max_token, temperature=self.temperature,top_p=self.top_p)return responsemodelpath = r"C:\Users\robot\Desktop\wenda-main\wenda-main\model\chatglm2-6b"sys.path.append(modelpath)# print(modelpath)llm = ChatGLM()llm.load_model(model_name_or_path=modelpath)#2.toolsimport cv2def catch_video(name='my_video', video_index=0):cap = cv2.VideoCapture(video_index) # 创建摄像头识别类if not cap.isOpened():raise Exception('Check if the camera is on.')while cap.isOpened():catch, frame = cap.read()# 读取每一帧图片cv2.imshow(name, frame) # 在window上显示图片key = cv2.waitKey(10)if key & 0xFF == ord('q'):# 按q退出breakif cv2.getWindowProperty(name, cv2.WND_PROP_AUTOSIZE) < 1:# 点x退出break# 释放摄像头cap.release()cv2.destroyAllWindows()from langchain import SQLDatabase, SQLDatabaseChaindb = SQLDatabase.from_uri("mysql+pymysql://root:w19891207@localhost/xunlian")db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)import pymysqlconn = pymysql.connect(host='localhost', user='root', password='w19891207', database='xunlian')cursor = conn.cursor()from langchain.prompts import PromptTemplatetable_info="{"+db_chain.database.get_table_info(table_names=None)+"\n}"table_info = table_info.split('\n')table_info = ['#' + line for line in table_info]table_info= '\n'.join(table_info)# print(table_info)prompt_sql = PromptTemplate.from_template("""你是一个SQL专家,现在给你提供一个数据库表单的提示信息。\n#数据库表单的提示信息包括:{table_info}请根据上述#数据库表单的提示信息,针对"用户问题",创建一个语法正确的MySQL查询语句,使用LIMIT子句查询最多3个结果,必须将查询语句中的字段使用反引号(`)包括起来,必须使用"数据库表单的提示信息"中可见的列名创建MySQL查询语句,查询语句不能出现不存在的列名。请按以下#输出示例进行输出:#输出示例#用户问题:张五的岗位是什么?#MySQL查询语句:SELECT 岗位 FROM 总表 WHERE 姓名 = `张五`;现在我们开始:用户问题:{question}MySQL查询语句:""")prompt2=PromptTemplate.from_template("""数据库查询语句是:{first_format}\n数据库查询结果是:{second_format}\n请根据上述查询过程进行回答,回答的内容必须简单明了,必须在30个字以内:{question}""")def my_out(question):# print(table_info)first=llm.predict(prompt_sql.format(table_info=table_info, question=question))cursor.execute(first)second_format= cursor.fetchall()third_format=llm.predict(prompt2.format(first_format=first, second_format=second_format,question=question))return third_formatfrom langchain.prompts import PromptTemplateprompt1=PromptTemplate.from_template("""Chatbot单选题,以下哪个工具可以完成问题或任务#{question}()#A:《数据库》工具,该工具用于访问训练数据库,对人员姓名、年龄训练成绩进行查询。B:《作图》工具,绘图。C:《监控视频》工具,该工具用于调用指定地点的摄像头、监控、视频、态势进行查看。D:《备选》工具,该工具用于对措施、计划、管理、评估等定性问题进行回答,或其他不适合任何工具的情况完成任务或问题。""")def use_mytools(question):choice = llm.predict(prompt1.format(question=question))print(choice)if "C" in choice:catch_video(question)elif "A" in choice:print(my_out(question))else:print(llm.predict(question))# use_mytools("张国立的年龄是多少" />use_mytools("调用摄像头")
© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享