Unverified 提交 2987c9cd 作者: akou 提交者: GitHub

增加允许跨域调用API功能 (#279)

上级 e1c56edb
...@@ -11,12 +11,14 @@ import pydantic ...@@ -11,12 +11,14 @@ import pydantic
import uvicorn import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Annotated from typing_extensions import Annotated
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from chains.local_doc_qa import LocalDocQA from chains.local_doc_qa import LocalDocQA
from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH, from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN) EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
...@@ -310,6 +312,17 @@ def main(): ...@@ -310,6 +312,17 @@ def main():
args = parser.parse_args() args = parser.parse_args()
app = FastAPI() app = FastAPI()
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat) app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
app.post("/chat-docs/chat", response_model=ChatMessage)(chat) app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat) app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
......
...@@ -83,3 +83,7 @@ embedding device: {EMBEDDING_DEVICE} ...@@ -83,3 +83,7 @@ embedding device: {EMBEDDING_DEVICE}
dir: {os.path.dirname(os.path.dirname(__file__))} dir: {os.path.dirname(os.path.dirname(__file__))}
flagging username: {FLAG_USER_NAME} flagging username: {FLAG_USER_NAME}
""") """)
# 是否开启跨域,默认为False,如果需要开启,请设置为True
# is open cross domain
OPEN_CROSS_DOMAIN = False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论