Unverified commit dd938373 authored by zhenkaivip, committed by GitHub

Implement using paddleocr (#342)

* jpg and png ocr

* fix

* write docs to tmp file

* fix

* [BUGFIX] local_doc_qa.py line 172: logging has no end arg. (#323)

* image loader

* fix

* fix

* update api.py

* update api.py

* update api.py

* update README.md

* update api.py

* add pdf_loader

* fix

---------

Co-authored-by: RainGather <3255329+RainGather@users.noreply.github.com>
Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
Parent dcf6e4ff
README.md

@@ -207,6 +207,6 @@ The Web UI supports the following features:
 - [ ] Implement a Web UI demo that calls the API

 ## Project discussion group
-![QR code](img/qr_code_16.jpg)
+![QR code](img/qr_code_17.jpg)

 🎉 langchain-ChatGLM project discussion group. If you are also interested in this project, you are welcome to join the group chat and take part in the discussion.
api.py

@@ -22,6 +22,7 @@ from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVI
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")
     msg: str = pydantic.Field("success", description="HTTP status message")
@@ -87,7 +88,7 @@ def get_vs_path(local_doc_id: str):
 def get_file_path(local_doc_id: str, doc_name: str):
     return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)

-async def single_upload_file(
+async def upload_file(
     file: UploadFile = File(description="A single binary file"),
     knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
 ):

@@ -106,21 +107,15 @@ async def single_upload_file(
         f.write(file_content)

     vs_path = get_vs_path(knowledge_base_id)
-    if os.path.exists(vs_path):
-        added_files = await local_doc_qa.add_files_to_knowledge_vector_store(vs_path, [file_path])
-        if len(added_files) > 0:
-            file_status = f"文件 {file.filename} 已上传并已加载知识库,请开始提问。"
-            return BaseResponse(code=200, msg=file_status)
-    else:
-        vs_path, loaded_files = await local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
-        if len(loaded_files) > 0:
-            file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
-            return BaseResponse(code=200, msg=file_status)
-    file_status = "文件上传失败,请重新上传"
-    return BaseResponse(code=500, msg=file_status)
+    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
+    if len(loaded_files) > 0:
+        file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
+        return BaseResponse(code=200, msg=file_status)
+    else:
+        file_status = "文件上传失败,请重新上传"
+        return BaseResponse(code=500, msg=file_status)

-async def upload_file(
+async def upload_files(
     files: Annotated[
         List[UploadFile], File(description="Multiple files as UploadFile")
     ],
@@ -203,7 +198,7 @@ async def delete_docs(
     return BaseResponse()

-async def chat(
+async def local_doc_chat(
     knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
     history: List[List[str]] = Body(

@@ -238,7 +233,8 @@ async def chat(
         source_documents=source_documents,
     )

-async def no_knowledge_chat(
+async def chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
     history: List[List[str]] = Body(
         [],
@@ -251,12 +247,19 @@ async def no_knowledge_chat(
         ],
     ),
 ):
-    for resp, history in local_doc_qa._call(
-        query=question, chat_history=history, streaming=True
+    for resp, history in local_doc_qa.llm._call(
+        prompt=question, history=history, streaming=True
     ):
         pass
+    return ChatMessage(
+        question=question,
+        response=resp,
+        history=history,
+        source_documents=[],
+    )

 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
     await websocket.accept()
     vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
@@ -323,15 +326,19 @@ def main():
         allow_methods=["*"],
         allow_headers=["*"],
     )

-    app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
-    app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
-    app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
-    app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
-    app.post("/chat-docs/uploadone", response_model=BaseResponse)(single_upload_file)
-    app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
-    app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
+    app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)

     app.get("/", response_model=BaseResponse)(document)

+    app.post("/chat", response_model=ChatMessage)(chat)
+
+    app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
+    app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
+    app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
+    app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
+    app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)

     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(
         llm_model=LLM_MODEL,
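For orientation, a minimal client sketch against the renamed routes (the base URL is an assumption; the actual host and port depend on how main() launches the app, which these hunks do not show):

import requests

BASE_URL = "http://localhost:7861"  # assumed host/port; adjust to your deployment

# Upload one file into a knowledge base via the renamed single-file route.
with open("sample.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/local_doc_qa/upload_file",
        files={"file": f},
        data={"knowledge_base_id": "kb1"},
    )
print(resp.json())  # BaseResponse: {"code": 200, "msg": "..."} on success

# Ask a question against the same knowledge base.
resp = requests.post(
    f"{BASE_URL}/local_doc_qa/local_doc_chat",
    json={"knowledge_base_id": "kb1", "question": "工伤保险是什么?", "history": []},
)
print(resp.json())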
local_doc_qa.py

@@ -10,6 +10,8 @@ import numpy as np
 from utils import torch_gc
 from tqdm import tqdm
 from pypinyin import lazy_pinyin
+from loader import UnstructuredPaddleImageLoader
+from loader import UnstructuredPaddlePDFLoader

 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -21,16 +23,31 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE):
         loader = UnstructuredFileLoader(filepath, mode="elements")
         docs = loader.load()
     elif filepath.lower().endswith(".pdf"):
-        loader = UnstructuredFileLoader(filepath, strategy="fast")
+        loader = UnstructuredPaddlePDFLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
         docs = loader.load_and_split(textsplitter)
+    elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
+        loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
+        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
+        docs = loader.load_and_split(text_splitter=textsplitter)
     else:
         loader = UnstructuredFileLoader(filepath, mode="elements")
         textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
         docs = loader.load_and_split(text_splitter=textsplitter)
+    write_check_file(filepath, docs)
     return docs

+
+def write_check_file(filepath, docs):
+    fout = open('load_file.txt', 'a')
+    fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
+    fout.write('\n')
+    for i in docs:
+        fout.write(str(i))
+        fout.write('\n')
+    fout.close()
+

 def generate_prompt(related_docs: List[str], query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
     context = "\n".join([doc.page_content for doc in related_docs])
@@ -176,7 +193,7 @@ class LocalDocQA:
         if len(failed_files) > 0:
             logger.info("以下文件未能成功加载:")
             for file in failed_files:
-                logger.info(file, end="\n")
+                logger.info(f"{file}\n")
         else:
             docs = []
@@ -212,7 +229,7 @@ class LocalDocQA:
         if not vs_path or not one_title or not one_conent:
             logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
             return None, [one_title]
-        docs = [Document(page_content=one_conent+"\n", metadata={"source": one_title})]
+        docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})]
         if not one_content_segmentation:
             text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
             docs = text_splitter.split_documents(docs)
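With the loaders wired in, load_file now routes .jpg/.png and scanned .pdf inputs through PaddleOCR. A minimal sketch of exercising the new image branch (the import path and sample file name are assumptions for illustration):

# Hypothetical driver; assumes this module is importable as chains.local_doc_qa
# and that a scanned image exists at the given path.
from chains.local_doc_qa import load_file

docs = load_file("docs/scan_sample.png")  # dispatches to UnstructuredPaddleImageLoader
print(len(docs), "chunks")                # write_check_file also logs this run to load_file.txt
for doc in docs[:3]:
    print(doc.page_content[:80])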
loader/__init__.py

from .image_loader import UnstructuredPaddleImageLoader
from .pdf_loader import UnstructuredPaddlePDFLoader
"""Loader that loads image files."""
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
import os
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
    """Loader that uses unstructured to load image files, such as PNGs and JPGs."""

    def _get_elements(self) -> List:
        def image_ocr_txt(filepath, dir_path="tmp_files"):
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            filename = os.path.split(filepath)[-1]
            ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
            result = ocr.ocr(img=filepath)
            ocr_result = [i[1][0] for line in result for i in line]
            txt_file_path = os.path.join(dir_path, "%s.txt" % (filename))
            with open(txt_file_path, 'w', encoding='utf-8') as fout:
                fout.write("\n".join(ocr_result))
            return txt_file_path

        txt_file_path = image_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
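A note on the list comprehension: it assumes the result layout of recent paddleocr releases, where ocr.ocr() returns one list per page and each detected line is a [bounding_box, (text, confidence)] pair, so i[1][0] is the recognized text. The exact shape has shifted between paddleocr versions, so the sketch below mirrors what the comprehension expects rather than a guaranteed contract (the literal values are invented):

# Hypothetical return value for a one-page image with two recognized lines:
result = [
    [
        [[[10, 10], [120, 10], [120, 40], [10, 40]], ("第一行文字", 0.98)],
        [[[10, 50], [150, 50], [150, 80], [10, 80]], ("第二行文字", 0.95)],
    ]
]
texts = [i[1][0] for line in result for i in line]
assert texts == ["第一行文字", "第二行文字"]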
"""Loader that loads image files."""
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
import os
import fitz
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Loader that uses unstructured to load PDF files, running PaddleOCR on embedded images."""

    def _get_elements(self) -> List:
        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            filename = os.path.split(filepath)[-1]
            ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
            doc = fitz.open(filepath)
            txt_file_path = os.path.join(dir_path, "%s.txt" % (filename))
            img_name = './img/.tmp.png'
            with open(txt_file_path, 'w', encoding='utf-8') as fout:
                for i in range(doc.page_count):
                    page = doc[i]
                    text = page.get_text("")
                    fout.write(text)
                    fout.write("\n")

                    img_list = page.get_images()
                    for img in img_list:
                        pix = fitz.Pixmap(doc, img[0])
                        pix.save(img_name)

                        result = ocr.ocr(img_name)
                        ocr_result = [i[1][0] for line in result for i in line]
                        fout.write("\n".join(ocr_result))
            os.remove(img_name)
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
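A few fragile spots in pdf_ocr_txt are worth flagging: fitz.Pixmap(doc, xref) keeps the image's native colorspace and saving as PNG fails for CMYK images; the hardcoded './img/.tmp.png' requires an existing ./img directory; and the unconditional os.remove raises if the PDF contains no images at all. A defensive sketch of the inner loop using the standard PyMuPDF API (the helper name is ours, not part of this commit):

import os
import tempfile

import fitz
from paddleocr import PaddleOCR

def ocr_page_images(doc, page, ocr, fout):
    """OCR every image embedded in one PDF page and append the text to fout."""
    for img in page.get_images():
        pix = fitz.Pixmap(doc, img[0])
        if pix.n - pix.alpha >= 4:  # CMYK or similar: convert to RGB before PNG output
            pix = fitz.Pixmap(fitz.csRGB, pix)
        fd, img_path = tempfile.mkstemp(suffix=".png")  # no pre-existing ./img dir needed
        os.close(fd)
        try:
            pix.save(img_path)
            result = ocr.ocr(img_path)
            fout.write("\n".join(i[1][0] for line in result for i in line))
        finally:
            os.remove(img_path)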
requirements.txt

+pymupdf
+paddlepaddle==2.4.2
+paddleocr
 langchain==0.0.146
 transformers==4.27.1
 unstructured[local-inference]
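The three additions pull in PyMuPDF for PDF parsing and PaddlePaddle plus PaddleOCR for recognition. Note that paddlepaddle is the CPU build, which matches the loaders passing use_gpu=False; a GPU deployment would want paddlepaddle-gpu instead, and the ==2.4.2 pin presumably tracks the paddlepaddle version that paddleocr supported at the time.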
from configs.model_config import *
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
filepath = "./img/test.jpg"
from loader import UnstructuredPaddleImageLoader
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
docs = loader.load()
for doc in docs:
    print(doc)

from configs.model_config import *
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
filepath = "docs/test.pdf"
from loader import UnstructuredPaddlePDFLoader
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
docs = loader.load()
for doc in docs:
    print(doc)