提交 fa431b90 作者: imClumsyPanda

Merge remote-tracking branch 'origin/dev' into dev

......@@ -85,6 +85,7 @@ def get_vs_path(local_doc_id: str):
def get_file_path(local_doc_id: str, doc_name: str):
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
async def upload_file(
file: UploadFile = File(description="A single binary file"),
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
......@@ -112,6 +113,7 @@ async def upload_file(
file_status = "文件上传失败,请重新上传"
return BaseResponse(code=500, msg=file_status)
async def upload_files(
files: Annotated[
List[UploadFile], File(description="Multiple files as UploadFile")
......@@ -265,6 +267,10 @@ async def chat(
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
await websocket.accept()
turn = 1
while True:
input_json = await websocket.receive_json()
question, history, knowledge_base_id = input_json[""], input_json["history"], input_json["knowledge_base_id"]
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
if not os.path.exists(vs_path):
......@@ -272,10 +278,6 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
await websocket.close()
return
history = []
turn = 1
while True:
question = await websocket.receive_text()
await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
last_print_len = 0
......@@ -304,18 +306,13 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
)
turn += 1
async def document():
return RedirectResponse(url="/docs")
def main():
def api_start(host, port):
global app
global local_doc_qa
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
args = parser.parse_args()
app = FastAPI()
# Add CORS middleware to allow all origins
......@@ -341,7 +338,6 @@ def main():
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(
llm_model=LLM_MODEL,
......@@ -350,8 +346,12 @@ def main():
llm_history_len=LLM_HISTORY_LEN,
top_k=VECTOR_SEARCH_TOP_K,
)
uvicorn.run(app, host=args.host, port=args.port)
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
args = parser.parse_args()
api_start(args.host, args.port)
@echo off
python cli.py %*
import click
from api import api_start as api_start
from configs.model_config import llm_model_dict, embedding_model_dict
@click.group()
@click.version_option(version='1.0.0')
@click.pass_context
def cli(ctx):
pass
@cli.group()
def llm():
pass
@llm.command(name="ls")
def llm_ls():
for k in llm_model_dict.keys():
print(k)
@cli.group()
def embedding():
pass
@embedding.command(name="ls")
def embedding_ls():
for k in embedding_model_dict.keys():
print(k)
@cli.group()
def start():
pass
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
def start_api(ip, port):
api_start(host=ip, port=port)
@start.command(name="cli", context_settings=dict(help_option_names=['-h', '--help']))
def start_cli():
import cli_demo
cli_demo.main()
@start.command(name="webui", context_settings=dict(help_option_names=['-h', '--help']))
def start_webui():
import webui
cli()
#!/bin/bash
python cli.py "$@"
......@@ -8,7 +8,8 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
# Show reply with source text from input document
REPLY_WITH_SOURCE = True
if __name__ == "__main__":
def main():
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
embedding_model=EMBEDDING_MODEL,
......@@ -41,3 +42,7 @@ if __name__ == "__main__":
for inum, doc in
enumerate(resp["source_documents"])]
print("\n\n" + "\n\n".join(source_text))
if __name__ == "__main__":
main()
## 命令行工具
windows cli.bat
linux cli.sh
## 命令列表
### llm 管理
llm 支持列表
```shell
cli.bat llm ls
```
### embedding 管理
embedding 支持列表
```shell
cli.bat embedding ls
```
### start 启动管理
查看启动选择
```shell
cli.bat start
```
启动命令行交互
```shell
cli.bat start cli
```
启动Web 交互
```shell
cli.bat start webui
```
启动api服务
```shell
cli.bat start api
```
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论