Unverified 提交 a0cb14de 作者: zqt996 提交者: GitHub

添加命令行管理脚本 (#355)

* 添加加命令行工具

* 添加加命令行工具

---------

Co-authored-by: zqt <1178747941@qq.com>
Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
上级 1678392c
...@@ -85,6 +85,7 @@ def get_vs_path(local_doc_id: str): ...@@ -85,6 +85,7 @@ def get_vs_path(local_doc_id: str):
def get_file_path(local_doc_id: str, doc_name: str): def get_file_path(local_doc_id: str, doc_name: str):
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name) return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
async def upload_file( async def upload_file(
file: UploadFile = File(description="A single binary file"), file: UploadFile = File(description="A single binary file"),
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
...@@ -112,6 +113,7 @@ async def upload_file( ...@@ -112,6 +113,7 @@ async def upload_file(
file_status = "文件上传失败,请重新上传" file_status = "文件上传失败,请重新上传"
return BaseResponse(code=500, msg=file_status) return BaseResponse(code=500, msg=file_status)
async def upload_files( async def upload_files(
files: Annotated[ files: Annotated[
List[UploadFile], File(description="Multiple files as UploadFile") List[UploadFile], File(description="Multiple files as UploadFile")
...@@ -308,13 +310,9 @@ async def document(): ...@@ -308,13 +310,9 @@ async def document():
return RedirectResponse(url="/docs") return RedirectResponse(url="/docs")
def main(): def api_start(host, port):
global app global app
global local_doc_qa global local_doc_qa
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
args = parser.parse_args()
app = FastAPI() app = FastAPI()
# Add CORS middleware to allow all origins # Add CORS middleware to allow all origins
...@@ -340,7 +338,6 @@ def main(): ...@@ -340,7 +338,6 @@ def main():
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs) app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg( local_doc_qa.init_cfg(
llm_model=LLM_MODEL, llm_model=LLM_MODEL,
...@@ -349,8 +346,12 @@ def main(): ...@@ -349,8 +346,12 @@ def main():
llm_history_len=LLM_HISTORY_LEN, llm_history_len=LLM_HISTORY_LEN,
top_k=VECTOR_SEARCH_TOP_K, top_k=VECTOR_SEARCH_TOP_K,
) )
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=host, port=port)
if __name__ == "__main__": if __name__ == "__main__":
main() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
args = parser.parse_args()
api_start(args.host, args.port)
@echo off
python cli.py %*
import click
from api import api_start as api_start
from configs.model_config import llm_model_dict, embedding_model_dict
@click.group()
@click.version_option(version='1.0.0')
@click.pass_context
def cli(ctx):
pass
@cli.group()
def llm():
pass
@llm.command(name="ls")
def llm_ls():
for k in llm_model_dict.keys():
print(k)
@cli.group()
def embedding():
pass
@embedding.command(name="ls")
def embedding_ls():
for k in embedding_model_dict.keys():
print(k)
@cli.group()
def start():
pass
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
def start_api(ip, port):
api_start(host=ip, port=port)
@start.command(name="cli", context_settings=dict(help_option_names=['-h', '--help']))
def start_cli():
import cli_demo
cli_demo.main()
@start.command(name="webui", context_settings=dict(help_option_names=['-h', '--help']))
def start_webui():
import webui
cli()
#!/bin/bash
python cli.py "$@"
...@@ -8,7 +8,8 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path ...@@ -8,7 +8,8 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
# Show reply with source text from input document # Show reply with source text from input document
REPLY_WITH_SOURCE = True REPLY_WITH_SOURCE = True
if __name__ == "__main__":
def main():
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=LLM_MODEL, local_doc_qa.init_cfg(llm_model=LLM_MODEL,
embedding_model=EMBEDDING_MODEL, embedding_model=EMBEDDING_MODEL,
...@@ -41,3 +42,7 @@ if __name__ == "__main__": ...@@ -41,3 +42,7 @@ if __name__ == "__main__":
for inum, doc in for inum, doc in
enumerate(resp["source_documents"])] enumerate(resp["source_documents"])]
print("\n\n" + "\n\n".join(source_text)) print("\n\n" + "\n\n".join(source_text))
if __name__ == "__main__":
main()
## 命令行工具
windows cli.bat
linux cli.sh
## 命令列表
### llm 管理
llm 支持列表
```shell
cli.bat llm ls
```
### embedding 管理
embedding 支持列表
```shell
cli.bat embedding ls
```
### start 启动管理
查看启动选择
```shell
cli.bat start
```
启动命令行交互
```shell
cli.bat start cli
```
启动Web 交互
```shell
cli.bat start webui
```
启动api服务
```shell
cli.bat start api
```
...@@ -18,4 +18,5 @@ uvicorn ...@@ -18,4 +18,5 @@ uvicorn
peft peft
pypinyin pypinyin
bitsandbytes bitsandbytes
click~=8.1.3
tabulate tabulate
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论