diff --git a/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py index ff5e747e208cdbba62adbb1346b4a0ae1e6a2b42..60982e54f63f639f556d22c6e3eff2f90a2f0ac1 100644 --- a/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py +++ b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py @@ -5,6 +5,7 @@ import argparse import threading import traceback from loguru import logger +from paddle.base import libpaddle from mx_rag.chain import SingleText2TextChain from mx_rag.embedding.local import TextEmbedding from mx_rag.embedding.service import TEIEmbedding @@ -15,7 +16,7 @@ from mx_rag.retrievers import Retriever from mx_rag.storage.document_store import SQLiteDocstore from mx_rag.storage.vectorstore import MindFAISS from mx_rag.utils import ClientParam -from paddle.base import libpaddle + class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): @@ -28,6 +29,10 @@ class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): class ThreadWithResult(threading.Thread): def __init__(self, group=None, target=None, name=None, args=None, kwargs=None, *, daemon=None): + if args is None: + args = () + if kwargs is None: + kwargs = {} def function(): self.result = target(*args, **kwargs)