Commit 07ff81a1   Author: imClumsyPanda

update torch_gc

Parent: b03634fb
@@ -16,16 +16,10 @@ from typing_extensions import Annotated
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL)
+                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
+                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")
@@ -10,11 +10,6 @@ from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
 
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -109,7 +104,7 @@ def similarity_search_with_score_by_vector(
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
-    torch_gc(DEVICE)
+    torch_gc()
     return docs
@@ -181,13 +176,13 @@ class LocalDocQA:
             if vs_path and os.path.isdir(vs_path):
                 vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
-                torch_gc(DEVICE)
+                torch_gc()
             else:
                 if not vs_path:
                     vs_path = os.path.join(VS_ROOT_PATH,
                                            f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
                 vector_store = FAISS.from_documents(docs, self.embeddings)
-                torch_gc(DEVICE)
+                torch_gc()
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -206,6 +201,7 @@ class LocalDocQA:
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
+        torch_gc()
         prompt = generate_prompt(related_docs, query)
 
         # if streaming:
@@ -220,11 +216,13 @@ class LocalDocQA:
         for result, history in self.llm._call(prompt=prompt,
                                               history=chat_history,
                                               streaming=streaming):
+            torch_gc()
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
             yield response, history
+            torch_gc()
 
 
 if __name__ == "__main__":
@@ -244,9 +242,4 @@ if __name__ == "__main__":
                                   for inum, doc in
                                   enumerate(resp["source_documents"])]
             print("\n\n" + "\n\n".join(source_text))
-    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-    #                                                              vs_path=vs_path,
-    #                                                              chat_history=[],
-    #                                                              streaming=False):
-    #     print(resp["result"])
     pass
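The commented-out block removed above was the only hint at how the generator API is driven outside the streaming demo; for reference, a minimal usage sketch of the same call, assuming `local_doc_qa`, `query` and `vs_path` are set up as in the surrounding `__main__` block:

```python
# Minimal sketch, not part of the diff: drive get_knowledge_based_answer as a
# generator. Assumes local_doc_qa, query and vs_path exist as in __main__.
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                             vs_path=vs_path,
                                                             chat_history=[],
                                                             streaming=False):
    print(resp["result"])  # each yielded resp also carries resp["source_documents"]
```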
@@ -3,13 +3,7 @@ from chains.local_doc_qa import LocalDocQA
 import os
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 # Show reply with source text from input document
 REPLY_WITH_SOURCE = True
@@ -49,4 +49,12 @@ PROMPT_TEMPLATE = """已知信息:
 根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 # 匹配后单段上下文长度
-CHUNK_SIZE = 500
+CHUNK_SIZE = 250
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 5
+
+NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "nltk_data")
\ No newline at end of file
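With these constants now defined once in the config module, the per-script copies removed in the earlier hunks become plain imports. A minimal sketch of the consuming side, mirroring the first hunk above (the final print line is added here only for illustration):

```python
# Sketch of a consumer module after this change; the exact import list per script may vary.
import nltk

from configs.model_config import (LLM_HISTORY_LEN, NLTK_DATA_PATH,
                                  VECTOR_SEARCH_TOP_K)

# Point nltk at the bundled nltk_data directory that model_config now locates centrally.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

print(f"top_k={VECTOR_SEARCH_TOP_K}, history_len={LLM_HISTORY_LEN}")
```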
@@ -69,12 +69,13 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
-                torch_gc(DEVICE)
+                torch_gc()
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
+                torch_gc()
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
@@ -83,9 +84,10 @@ class ChatGLM(LLM):
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
-            torch_gc(DEVICE)
+            torch_gc()
             history += [[prompt, response]]
             yield response, history
+            torch_gc()
 
         # def chat(self,
         #          prompt: str) -> str:
 import torch
 
 
-def torch_gc(DEVICE):
+def torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        # with torch.cuda.device(DEVICE):
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
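For reference, a sketch of the whole updated helper: the CUDA branch follows the diff above, while the MPS branch (truncated above) is an assumption that it simply clears the MPS cache on PyTorch builds that ship `torch.mps`:

```python
import torch


def torch_gc():
    """Best-effort release of cached accelerator memory."""
    if torch.cuda.is_available():
        # with torch.cuda.device(DEVICE):  # device argument dropped in this commit
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        # Assumed MPS branch: torch.mps.empty_cache() ships with PyTorch >= 2.0.
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception:
            pass  # older builds without torch.mps: nothing to clear explicitly
```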
@@ -5,13 +5,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 def get_vs_list():