From 728fed7ec3d84c724e05c231544a0ab405981595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A9=BA=E4=B8=AD=E6=99=BA=E5=9B=8A?= Date: Wed, 4 Sep 2024 18:25:26 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A0ES=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../i_search_dataset_step.py | 1 + .../serializers/application_serializers.py | 11 +- apps/common/config/embedding_config.py | 28 +- apps/common/event/listener_manage.py | 1 - apps/embedding/models/embedding.py | 14 + apps/embedding/vector/es_vector.py | 330 +++++++++++++++++ .../models_provider/base_model_provider.py | 2 - apps/setting/models_provider/impl/base_stt.py | 14 - apps/setting/models_provider/impl/base_tts.py | 14 - .../local_model_provider/model/embedding.py | 1 + .../openai_model_provider/credential/stt.py | 42 --- .../impl/openai_model_provider/model/stt.py | 53 --- .../impl/openai_model_provider/model/tts.py | 58 --- .../openai_model_provider.py | 12 +- .../credential/stt.py | 45 --- .../credential/tts.py | 45 --- .../model/stt.py | 345 ------------------ .../model/tts.py | 170 --------- .../volcanic_engine_model_provider.py | 18 +- .../impl/xf_model_provider/credential/stt.py | 46 --- .../impl/xf_model_provider/credential/tts.py | 46 --- .../impl/xf_model_provider/model/stt.py | 169 --------- .../impl/xf_model_provider/model/tts.py | 142 ------- .../xf_model_provider/xf_model_provider.py | 17 +- .../serializers/model_apply_serializers.py | 10 +- apps/smartdoc/const.py | 2 +- 26 files changed, 389 insertions(+), 1247 deletions(-) create mode 100644 apps/embedding/vector/es_vector.py delete mode 100644 apps/setting/models_provider/impl/base_stt.py delete mode 100644 apps/setting/models_provider/impl/base_tts.py delete mode 100644 apps/setting/models_provider/impl/openai_model_provider/credential/stt.py delete mode 100644 apps/setting/models_provider/impl/openai_model_provider/model/stt.py delete mode 100644 apps/setting/models_provider/impl/openai_model_provider/model/tts.py delete mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py delete mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py delete mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py delete mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py delete mode 100644 apps/setting/models_provider/impl/xf_model_provider/credential/stt.py delete mode 100644 apps/setting/models_provider/impl/xf_model_provider/credential/tts.py delete mode 100644 apps/setting/models_provider/impl/xf_model_provider/model/stt.py delete mode 100644 apps/setting/models_provider/impl/xf_model_provider/model/tts.py diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index 97da2964..1c472f29 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -51,6 +51,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): def _run(self, manage: PipelineManage): paragraph_list = self.execute(**self.context['step_args']) manage.context['paragraph_list'] = paragraph_list + print("查询结果:",paragraph_list) self.context['paragraph_list'] = paragraph_list @abstractmethod diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index eecfc6a0..e1c633f7 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -461,10 +461,17 @@ class ApplicationSerializer(serializers.Serializer): self.data.get('similarity'), SearchMode(self.data.get('search_mode')), model) + # 获取段落标识 hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + # 查询分段信息表 p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) - return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), - 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] + + print("p_list---->", p_list) + print("hit_dict---->", hit_dict) + return [{**p, + 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score') + } for p in p_list] class Query(serializers.Serializer): name = serializers.CharField(required=False, error_messages=ErrMessage.char("应用名称")) diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index a6e9ab9a..2fb9062c 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -48,19 +48,31 @@ class ModelManage: class VectorStore: - from embedding.vector.pg_vector import PGVector from embedding.vector.base_vector import BaseVectorStore - instance_map = { - 'pg_vector': PGVector, - } instance = None @staticmethod def get_embedding_vector() -> BaseVectorStore: - from embedding.vector.pg_vector import PGVector + if VectorStore.instance is None: from smartdoc.const import CONFIG - vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"), - PGVector) - VectorStore.instance = vector_store_class() + store_name = CONFIG.get("VECTOR_STORE_NAME") + if store_name == 'es_vector': + from embedding.vector.es_vector import ESVector + from elasticsearch import Elasticsearch + + es_host = CONFIG.get("VECTOR_STORE_ES_ENDPOINT") + es_username = CONFIG.get("VECTOR_STORE_ES_USERNAME") + es_password = CONFIG.get("VECTOR_STORE_ES_PASSWORD") + es_indexname = CONFIG.get("VECTOR_STORE_ES_INDEX") + http_auth = (es_username, es_password) + VectorStore.instance = ESVector(Elasticsearch(hosts=[es_host], + http_auth=http_auth) + , es_indexname) + if not VectorStore.instance.vector_is_create(): + VectorStore.instance.vector_create() + else: + from embedding.vector.pg_vector import PGVector + VectorStore.instance = PGVector() + return VectorStore.instance diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 537677ba..6b8d4ba4 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -127,7 +127,6 @@ class ListenerManagement: def is_save_function(): return QuerySet(Paragraph).filter(id=paragraph_id).exists() - # 批量向量化 VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) except Exception as e: diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py index 5f954e36..c5414423 100644 --- a/apps/embedding/models/embedding.py +++ b/apps/embedding/models/embedding.py @@ -48,5 +48,19 @@ class Embedding(models.Model): meta = models.JSONField(verbose_name="元数据", default=dict) + # 将对象转换成es可用的json对象 + def to_dict(self): + return { + "id": str(self.id), + "document_id": self.document_id, + "paragraph_id": self.paragraph_id, + "dataset_id": self.dataset_id, + "is_active": self.is_active, + "source_id": self.source_id, + "source_type": self.source_type, + "embedding": self.embedding, + "search_vector": self.search_vector, + } + class Meta: db_table = "embedding" diff --git a/apps/embedding/vector/es_vector.py b/apps/embedding/vector/es_vector.py new file mode 100644 index 00000000..25d325c6 --- /dev/null +++ b/apps/embedding/vector/es_vector.py @@ -0,0 +1,330 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:ilxch@163.com + @file: es_vector.py + @date:2024/09/04 15:28 + @desc: ES数据库操作API +""" +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List +from langchain_core.embeddings import Embeddings +from common.util.ts_vecto_util import to_ts_vector, to_query +from embedding.models import Embedding, SourceType, SearchMode +from embedding.vector.base_vector import BaseVectorStore +from elasticsearch import helpers + + +class ESVector(BaseVectorStore): + def __init__(self, es_connection, index_name): + self.es = es_connection + self.index_name = index_name + + def search(self, body: Dict, size: int): + print(body) + return self.es.search(index=self.index_name, body=body, size=size) + + def searchembed(self, query: Dict, vector: List[float], size, similarity): + return self.es.search(index=self.index_name, body={ + "size": size, + "query": query, + "knn": { + "field": "embedding", + "query_vector": vector, + "k": size, + "similarity": similarity + }, + "_source": ["paragraph_id"] # 添加这一行 + }) + + def delete_by_query(self, query): + return self.es.search(index=self.index_name, body=query) + + def bulk(self, documents: List[Dict]): + actions = [ + { + "_index": self.index_name, + "_id": document.get('id'), + "_source": document + } + for document in documents + ] + return helpers.bulk(self.es, actions) + + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + if len(source_ids) == 0: + return + return self.delete_by_query(query={ + "query": { + "terms": { + "_id": source_ids + } + } + }) + + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + return self.bulk(source_ids, instance) + + def vector_is_create(self) -> bool: + return self.es.indices.exists(index=self.index_name) + + def vector_create(self): + return self.es.indices.create(index=self.index_name, ignore=400) + + def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + text_embedding = embedding.embed_query(text) + embedding = Embedding(id=uuid.uuid1(), + dataset_id=dataset_id, + document_id=document_id, + is_active=is_active, + paragraph_id=paragraph_id, + source_id=source_id, + embedding=text_embedding, + source_type=source_type, + search_vector=to_ts_vector(text)) + res = self.es.index(index=self.index_name, body=embedding.to_dict()) + print("res---->" + str(res)) + return True + + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + texts = [row.get('text') for row in text_list] + embeddings = embedding.embed_documents(texts) + embedding_list = [Embedding(id=uuid.uuid1(), + document_id=text_list[index].get('document_id'), + paragraph_id=text_list[index].get('paragraph_id'), + dataset_id=text_list[index].get('dataset_id'), + is_active=text_list[index].get('is_active', True), + source_id=text_list[index].get('source_id'), + source_type=text_list[index].get('source_type'), + embedding=embeddings[index], + search_vector=to_ts_vector(text_list[index]['text'])) for index in + range(0, len(texts))] + if is_save_function(): + embeddings = [embedding.to_dict() for embedding in embedding_list] + self.bulk(embeddings) + return True + + def update_by_source_id(self, source_id: str, instance: Dict): + self.es.update(index=self.index_name, id=source_id, body={"doc": instance}) + + def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): + self.es.update(index=self.index_name, id=paragraph_id, body={"doc": instance}) + + def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict): + for paragraph_id in paragraph_ids: + self.es.update(index=self.index_name, id=paragraph_id, body={"doc": instance}) + + def delete_by_dataset_id(self, dataset_id: str): + self.es.delete_by_query(index=self.index_name, body={"query": {"match": {"dataset_id": dataset_id}}}) + + def delete_by_dataset_id_list(self, dataset_id_list: List[str]): + self.delete_by_query(body={"query": {"terms": {"dataset_id": dataset_id_list}}}) + + def delete_by_document_id(self, document_id: str): + self.delete_by_query(query={ + "query": { + "match": { + "document_id.keyword": document_id + } + } + }) + return True + + def delete_by_document_id_list(self, document_id_list: List[str]): + if len(document_id_list) == 0: + return True + self.delete_by_query(query={ + "query": { + "match": { + "document_id.keyword": document_id_list + } + } + }) + return True + + def delete_by_source_id(self, source_id: str, source_type: str): + self.delete_by_query(query={"query": { + "bool": {"must": [{"match": {"source_id.keyword": source_id}}, {"match": {"source_type": source_type}}]} + }}) + return True + + def delete_by_paragraph_id(self, paragraph_id: str): + self.delete_by_query(query={ + "query": { + "match": { + "paragraph_id.keyword": paragraph_id + } + } + }) + + def delete_by_paragraph_ids(self, paragraph_ids: List[str]): + self.delete_by_query(query={ + "query": { + "terms": { + "paragraph_id.keyword": paragraph_ids + } + } + }) + + def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + embedding_query = embedding.embed_query(query_text) + filter_query = { + "bool": { + "must": [ + {"term": {"is_active": True}}, + {"terms": {"dataset_id.keyword": [str(uuid) for uuid in dataset_id_list]}} + ], + "must_not": [ + {"terms": {"document_id.keyword": [str(uuid) for uuid in exclude_document_id_list]}} + ] + } + } + for search_handle in search_handle_list: + if search_handle.support(search_mode, self): + return search_handle.handle(filter_query, query_text, embedding_query, top_number, similarity, + search_mode) + + def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + filter_query = { + "bool": { + "must": [ + {"term": {"is_active": "true"}}, + {"terms": {"dataset_id.keyword": [str(uuid) for uuid in dataset_id_list]}} + ], + "must_not": [ + {"terms": {"document_id.keyword": [str(uuid) for uuid in exclude_document_id_list]}}, + {"terms": {"paragraph_id.keyword": [str(uuid) for uuid in exclude_paragraph_list]}} + ] + } + } + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(filter_query, query_text, query_embedding, top_n, similarity, search_mode) + + +class ISearch(ABC): + @abstractmethod + def support(self, search_mode: SearchMode, esVector: ESVector): + pass + + @abstractmethod + def handle(self, filter_query: Dict, query_text: str, query_embedding: List[float], top_number: int, + similarity: float, search_mode: SearchMode): + pass + + +# 向量检索 +class EmbeddingSearch(ISearch): + def __init__(self): + self.esVector = None + + def handle(self, + filter_query, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + result = self.esVector.searchembed(query=filter_query, vector=query_embedding, size=top_number, + similarity=similarity) + hits = result['hits']['hits'] + return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), + 'similarity': hit['_score'], + 'comprehensive_score': hit['_source'].get('_score', 0)} + for hit in hits] + + def support(self, search_mode: SearchMode, esVector: ESVector): + self.esVector = esVector + return search_mode.value == SearchMode.embedding.value + + +# 关键词检索 +class KeywordsSearch(ISearch): + def __init__(self): + self.esVector = None + + def handle(self, + filter_query, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + filter_query['bool']['must'].append({"multi_match": { + "query": to_query(query_text), + "fields": ["search_vector"] + }}) + result = self.esVector.search(body={ + "query": filter_query, + "sort": ["_score"] + }, size=top_number) + hits = [hit for hit in result['hits']['hits'] if hit['_score'] > similarity] + return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), + 'similarity': hit['_score'], + 'comprehensive_score': hit['_source'].get('_score', 0)} + for hit in hits] + + def support(self, search_mode: SearchMode, esVector: ESVector): + self.esVector = esVector + return search_mode.value == SearchMode.keywords.value + + +# 混合检索 +class BlendSearch(ISearch): + def __init__(self): + self.esVector = None + + def handle(self, + filter_query, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + filter_query["bool"]["should"] = { + "function_score": { + "query": {'match': {'search_vector': query_text}}, + "functions": [{ + 'script_score': { + 'script': { + 'source': "cosineSimilarity(params.query_vector, 'embedding') + 1.0", + 'params': {'query_vector': []} + } + } + }], + 'boost_mode': 'replace' + } + } + result = self.esVector.search(body={ + "sort": ["_score"], + "query": filter_query, + "_source": ["paragraph_id"] # 添加这一行 + }, size=top_number) + + hits = [hit for hit in result['hits']['hits'] if hit['_score'] > similarity] + + return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), + 'similarity': hit['_score'], + 'comprehensive_score': hit['_source'].get('_score', 0)} + for hit in hits] + + def support(self, search_mode: SearchMode, esVector: ESVector): + self.esVector = esVector + return search_mode.value == SearchMode.blend.value + + +# 混合检索 +search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()] diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 4de2f006..6b331398 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -139,8 +139,6 @@ class BaseModelCredential(ABC): class ModelTypeConst(Enum): LLM = {'code': 'LLM', 'message': '大语言模型'} EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'} - STT = {'code': 'STT', 'message': '语音识别'} - TTS = {'code': 'TTS', 'message': '语音合成'} class ModelInfo: diff --git a/apps/setting/models_provider/impl/base_stt.py b/apps/setting/models_provider/impl/base_stt.py deleted file mode 100644 index aae72a55..00000000 --- a/apps/setting/models_provider/impl/base_stt.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -from abc import abstractmethod - -from pydantic import BaseModel - - -class BaseSpeechToText(BaseModel): - @abstractmethod - def check_auth(self): - pass - - @abstractmethod - def speech_to_text(self, audio_file): - pass diff --git a/apps/setting/models_provider/impl/base_tts.py b/apps/setting/models_provider/impl/base_tts.py deleted file mode 100644 index 6311f268..00000000 --- a/apps/setting/models_provider/impl/base_tts.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -from abc import abstractmethod - -from pydantic import BaseModel - - -class BaseTextToSpeech(BaseModel): - @abstractmethod - def check_auth(self): - pass - - @abstractmethod - def text_to_speech(self, text): - pass diff --git a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py index 38cf1aeb..b2c60f40 100644 --- a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py @@ -33,6 +33,7 @@ class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/embed_query', {'text': text}) result = res.json() + #print("查询内容:",result) if result.get('code', 500) == 200: return result.get('data') raise Exception(result.get('msg')) diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py b/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py deleted file mode 100644 index 59506311..00000000 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py +++ /dev/null @@ -1,42 +0,0 @@ -# coding=utf-8 -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class OpenAISTTModelCredential(BaseForm, BaseModelCredential): - api_base = forms.TextInputField('API 域名', required=True) - api_key = forms.PasswordInputField('API Key', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['api_base', 'api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/stt.py b/apps/setting/models_provider/impl/openai_model_provider/model/stt.py deleted file mode 100644 index 66b2daed..00000000 --- a/apps/setting/models_provider/impl/openai_model_provider/model/stt.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Dict - -from openai import OpenAI - -from common.config.tokenizer_manage_config import TokenizerManage -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_stt import BaseSpeechToText - - -def custom_get_token_ids(text: str): - tokenizer = TokenizerManage.get_tokenizer() - return tokenizer.encode(text) - - -class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): - api_base: str - api_key: str - model: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.api_key = kwargs.get('api_key') - self.api_base = kwargs.get('api_base') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return OpenAISpeechToText( - model=model_name, - api_base=model_credential.get('api_base'), - api_key=model_credential.get('api_key'), - **optional_params, - ) - - def check_auth(self): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - response_list = client.models.with_raw_response.list() - # print(response_list) - - def speech_to_text(self, audio_file): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - res = client.audio.transcriptions.create(model=self.model, language="zh", file=audio_file) - return res.text diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py b/apps/setting/models_provider/impl/openai_model_provider/model/tts.py deleted file mode 100644 index c0975484..00000000 --- a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Dict - -from openai import OpenAI - -from common.config.tokenizer_manage_config import TokenizerManage -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - - -def custom_get_token_ids(text: str): - tokenizer = TokenizerManage.get_tokenizer() - return tokenizer.encode(text) - - -class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - api_base: str - api_key: str - model: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.api_key = kwargs.get('api_key') - self.api_base = kwargs.get('api_base') - self.model = kwargs.get('model') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return OpenAITextToSpeech( - model=model_name, - api_base=model_credential.get('api_base'), - api_key=model_credential.get('api_key'), - **optional_params, - ) - - def check_auth(self): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - response_list = client.models.with_raw_response.list() - # print(response_list) - - def text_to_speech(self, text): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - with client.audio.speech.with_streaming_response.create( - model=self.model, - voice="alloy", - input=text, - ) as response: - return response.read() diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index cbec2d50..fb4c89d7 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -13,15 +13,11 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro ModelTypeConst, ModelInfoManage from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential -from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel -from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText -from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech from smartdoc.conf import PROJECT_DIR openai_llm_model_credential = OpenAILLMModelCredential() -openai_stt_model_credential = OpenAISTTModelCredential() model_info_list = [ ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel @@ -62,13 +58,7 @@ model_info_list = [ OpenAIChatModel), ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', ModelTypeConst.LLM, openai_llm_model_credential, - OpenAIChatModel), - ModelInfo('whisper-1', '', - ModelTypeConst.STT, openai_stt_model_credential, - OpenAISpeechToText), - ModelInfo('tts-1', '', - ModelTypeConst.TTS, openai_stt_model_credential, - OpenAITextToSpeech) + OpenAIChatModel) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() model_info_embedding_list = [ diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py deleted file mode 100644 index 94e65616..00000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ /dev/null @@ -1,45 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): - volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v2/asr') - volcanic_app_id = forms.TextInputField('App ID', required=True) - volcanic_token = forms.PasswordInputField('Token', required=True) - volcanic_cluster = forms.TextInputField('Cluster', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py deleted file mode 100644 index 353740a4..00000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py +++ /dev/null @@ -1,45 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): - volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary') - volcanic_app_id = forms.TextInputField('App ID', required=True) - volcanic_token = forms.PasswordInputField('Token', required=True) - volcanic_cluster = forms.TextInputField('Cluster', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py deleted file mode 100644 index 36aed642..00000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py +++ /dev/null @@ -1,345 +0,0 @@ -# coding=utf-8 - -""" -requires Python 3.6 or later - -pip install asyncio -pip install websockets -""" -import asyncio -import base64 -import gzip -import hmac -import json -import uuid -import wave -from enum import Enum -from hashlib import sha256 -from io import BytesIO -from typing import Dict -from urllib.parse import urlparse -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_stt import BaseSpeechToText - -audio_format = "mp3" # wav 或者 mp3,根据实际音频格式设置 - -PROTOCOL_VERSION = 0b0001 -DEFAULT_HEADER_SIZE = 0b0001 - -PROTOCOL_VERSION_BITS = 4 -HEADER_BITS = 4 -MESSAGE_TYPE_BITS = 4 -MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 -MESSAGE_SERIALIZATION_BITS = 4 -MESSAGE_COMPRESSION_BITS = 4 -RESERVED_BITS = 8 - -# Message Type: -CLIENT_FULL_REQUEST = 0b0001 -CLIENT_AUDIO_ONLY_REQUEST = 0b0010 -SERVER_FULL_RESPONSE = 0b1001 -SERVER_ACK = 0b1011 -SERVER_ERROR_RESPONSE = 0b1111 - -# Message Type Specific Flags -NO_SEQUENCE = 0b0000 # no check sequence -POS_SEQUENCE = 0b0001 -NEG_SEQUENCE = 0b0010 -NEG_SEQUENCE_1 = 0b0011 - -# Message Serialization -NO_SERIALIZATION = 0b0000 -JSON = 0b0001 -THRIFT = 0b0011 -CUSTOM_TYPE = 0b1111 - -# Message Compression -NO_COMPRESSION = 0b0000 -GZIP = 0b0001 -CUSTOM_COMPRESSION = 0b1111 - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -def generate_header( - version=PROTOCOL_VERSION, - message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=NO_SEQUENCE, - serial_method=JSON, - compression_type=GZIP, - reserved_data=0x00, - extension_header=bytes() -): - """ - protocol_version(4 bits), header_size(4 bits), - message_type(4 bits), message_type_specific_flags(4 bits) - serialization_method(4 bits) message_compression(4 bits) - reserved (8bits) 保留字段 - header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) - """ - header = bytearray() - header_size = int(len(extension_header) / 4) + 1 - header.append((version << 4) | header_size) - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(reserved_data) - header.extend(extension_header) - return header - - -def generate_full_default_header(): - return generate_header() - - -def generate_audio_default_header(): - return generate_header( - message_type=CLIENT_AUDIO_ONLY_REQUEST - ) - - -def generate_last_audio_default_header(): - return generate_header( - message_type=CLIENT_AUDIO_ONLY_REQUEST, - message_type_specific_flags=NEG_SEQUENCE - ) - - -def parse_response(res): - """ - protocol_version(4 bits), header_size(4 bits), - message_type(4 bits), message_type_specific_flags(4 bits) - serialization_method(4 bits) message_compression(4 bits) - reserved (8bits) 保留字段 - header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) - payload 类似与http 请求体 - """ - protocol_version = res[0] >> 4 - header_size = res[0] & 0x0f - message_type = res[1] >> 4 - message_type_specific_flags = res[1] & 0x0f - serialization_method = res[2] >> 4 - message_compression = res[2] & 0x0f - reserved = res[3] - header_extensions = res[4:header_size * 4] - payload = res[header_size * 4:] - result = {} - payload_msg = None - payload_size = 0 - if message_type == SERVER_FULL_RESPONSE: - payload_size = int.from_bytes(payload[:4], "big", signed=True) - payload_msg = payload[4:] - elif message_type == SERVER_ACK: - seq = int.from_bytes(payload[:4], "big", signed=True) - result['seq'] = seq - if len(payload) >= 8: - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - payload_msg = payload[8:] - elif message_type == SERVER_ERROR_RESPONSE: - code = int.from_bytes(payload[:4], "big", signed=False) - result['code'] = code - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - payload_msg = payload[8:] - if payload_msg is None: - return result - if message_compression == GZIP: - payload_msg = gzip.decompress(payload_msg) - if serialization_method == JSON: - payload_msg = json.loads(str(payload_msg, "utf-8")) - elif serialization_method != NO_SERIALIZATION: - payload_msg = str(payload_msg, "utf-8") - result['payload_msg'] = payload_msg - result['payload_size'] = payload_size - return result - - -def read_wav_info(data: bytes = None) -> (int, int, int, int, int): - with BytesIO(data) as _f: - wave_fp = wave.open(_f, 'rb') - nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4] - wave_bytes = wave_fp.readframes(nframes) - return nchannels, sampwidth, framerate, nframes, len(wave_bytes) - - -class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): - workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate" - show_language: bool = False - show_utterances: bool = False - result_type: str = "full" - format: str = "mp3" - rate: int = 16000 - language: str = "zh-CN" - bits: int = 16 - channel: int = 1 - codec: str = "raw" - audio_type: int = 1 - secret: str = "access_secret" - auth_method: str = "token" - mp3_seg_size: int = 10000 - success_code: int = 1000 # success code, default is 1000 - seg_duration: int = 15000 - nbest: int = 1 - - volcanic_app_id: str - volcanic_cluster: str - volcanic_api_url: str - volcanic_token: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.volcanic_api_url = kwargs.get('volcanic_api_url') - self.volcanic_token = kwargs.get('volcanic_token') - self.volcanic_app_id = kwargs.get('volcanic_app_id') - self.volcanic_cluster = kwargs.get('volcanic_cluster') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return VolcanicEngineSpeechToText( - volcanic_api_url=model_credential.get('volcanic_api_url'), - volcanic_token=model_credential.get('volcanic_token'), - volcanic_app_id=model_credential.get('volcanic_app_id'), - volcanic_cluster=model_credential.get('volcanic_cluster'), - **optional_params - ) - - def construct_request(self, reqid): - req = { - 'app': { - 'appid': self.volcanic_app_id, - 'cluster': self.volcanic_cluster, - 'token': self.volcanic_token, - }, - 'user': { - 'uid': 'uid' - }, - 'request': { - 'reqid': reqid, - 'nbest': self.nbest, - 'workflow': self.workflow, - 'show_language': self.show_language, - 'show_utterances': self.show_utterances, - 'result_type': self.result_type, - "sequence": 1 - }, - 'audio': { - 'format': self.format, - 'rate': self.rate, - 'language': self.language, - 'bits': self.bits, - 'channel': self.channel, - 'codec': self.codec - } - } - return req - - @staticmethod - def slice_data(data: bytes, chunk_size: int) -> (list, bool): - """ - slice data - :param data: wav data - :param chunk_size: the segment size in one request - :return: segment data, last flag - """ - data_len = len(data) - offset = 0 - while offset + chunk_size < data_len: - yield data[offset: offset + chunk_size], False - offset += chunk_size - else: - yield data[offset: data_len], True - - def _real_processor(self, request_params: dict) -> dict: - pass - - def token_auth(self): - return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} - - def signature_auth(self, data): - header_dicts = { - 'Custom': 'auth_custom', - } - - url_parse = urlparse(self.volcanic_api_url) - input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path) - auth_headers = 'Custom' - for header in auth_headers.split(','): - input_str += '{}\n'.format(header_dicts[header]) - input_data = bytearray(input_str, 'utf-8') - input_data += data - mac = base64.urlsafe_b64encode( - hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest()) - header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token, - str(mac, 'utf-8'), - auth_headers) - return header_dicts - - async def segment_data_processor(self, wav_data: bytes, segment_size: int): - reqid = str(uuid.uuid4()) - # 构建 full client request,并序列化压缩 - request_params = self.construct_request(reqid) - payload_bytes = str.encode(json.dumps(request_params)) - payload_bytes = gzip.compress(payload_bytes) - full_client_request = bytearray(generate_full_default_header()) - full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - full_client_request.extend(payload_bytes) # payload - header = None - if self.auth_method == "token": - header = self.token_auth() - elif self.auth_method == "signature": - header = self.signature_auth(full_client_request) - async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000, - ssl=ssl_context) as ws: - # 发送 full client request - await ws.send(full_client_request) - res = await ws.recv() - result = parse_response(res) - if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: - return result - for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1): - # if no compression, comment this line - payload_bytes = gzip.compress(chunk) - audio_only_request = bytearray(generate_audio_default_header()) - if last: - audio_only_request = bytearray(generate_last_audio_default_header()) - audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - audio_only_request.extend(payload_bytes) # payload - # 发送 audio-only client request - await ws.send(audio_only_request) - res = await ws.recv() - result = parse_response(res) - if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: - return result - return result['payload_msg']['result'][0]['text'] - - def check_auth(self): - header = self.token_auth() - - async def check(): - async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000, - ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def speech_to_text(self, file): - data = file.read() - audio_data = bytes(data) - if self.format == "mp3": - segment_size = self.mp3_seg_size - return asyncio.run(self.segment_data_processor(audio_data, segment_size)) - if self.format != "wav": - raise Exception("format should in wav or mp3") - nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info( - audio_data) - size_per_sec = nchannels * sampwidth * framerate - segment_size = int(size_per_sec * self.seg_duration / 1000) - return asyncio.run(self.segment_data_processor(audio_data, segment_size)) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py deleted file mode 100644 index 3a5e0afb..00000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py +++ /dev/null @@ -1,170 +0,0 @@ -# coding=utf-8 - -''' -requires Python 3.6 or later - -pip install asyncio -pip install websockets - -''' - -import asyncio -import copy -import gzip -import json -import uuid -from typing import Dict -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - -MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"} -MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0", - 2: "last message from server (seq < 0)", 3: "sequence number < 0"} -MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"} -MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"} - -# version: b0001 (4 bits) -# header size: b0001 (4 bits) -# message type: b0001 (Full client request) (4bits) -# message type specific flags: b0000 (none) (4bits) -# message serialization method: b0001 (JSON) (4 bits) -# message compression: b0001 (gzip) (4bits) -# reserved data: 0x00 (1 byte) -default_header = bytearray(b'\x11\x10\x11\x00') - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - volcanic_app_id: str - volcanic_cluster: str - volcanic_api_url: str - volcanic_token: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.volcanic_api_url = kwargs.get('volcanic_api_url') - self.volcanic_token = kwargs.get('volcanic_token') - self.volcanic_app_id = kwargs.get('volcanic_app_id') - self.volcanic_cluster = kwargs.get('volcanic_cluster') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return VolcanicEngineTextToSpeech( - volcanic_api_url=model_credential.get('volcanic_api_url'), - volcanic_token=model_credential.get('volcanic_token'), - volcanic_app_id=model_credential.get('volcanic_app_id'), - volcanic_cluster=model_credential.get('volcanic_cluster'), - **optional_params - ) - - def check_auth(self): - header = self.token_auth() - - async def check(): - async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None, - ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def text_to_speech(self, text): - request_json = { - "app": { - "appid": self.volcanic_app_id, - "token": "access_token", - "cluster": self.volcanic_cluster - }, - "user": { - "uid": "uid" - }, - "audio": { - "voice_type": "BV002_streaming", - "encoding": "mp3", - "speed_ratio": 1.0, - "volume_ratio": 1.0, - "pitch_ratio": 1.0, - }, - "request": { - "reqid": str(uuid.uuid4()), - "text": text, - "text_type": "plain", - "operation": "xxx" - } - } - - return asyncio.run(self.submit(request_json)) - - def token_auth(self): - return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} - - async def submit(self, request_json): - submit_request_json = copy.deepcopy(request_json) - submit_request_json["request"]["reqid"] = str(uuid.uuid4()) - submit_request_json["request"]["operation"] = "submit" - payload_bytes = str.encode(json.dumps(submit_request_json)) - payload_bytes = gzip.compress(payload_bytes) # if no compression, comment this line - full_client_request = bytearray(default_header) - full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - full_client_request.extend(payload_bytes) # payload - header = {"Authorization": f"Bearer; {self.volcanic_token}"} - async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None, - ssl=ssl_context) as ws: - await ws.send(full_client_request) - return await self.parse_response(ws) - - @staticmethod - async def parse_response(ws): - result = b'' - while True: - res = await ws.recv() - protocol_version = res[0] >> 4 - header_size = res[0] & 0x0f - message_type = res[1] >> 4 - message_type_specific_flags = res[1] & 0x0f - serialization_method = res[2] >> 4 - message_compression = res[2] & 0x0f - reserved = res[3] - header_extensions = res[4:header_size * 4] - payload = res[header_size * 4:] - if header_size != 1: - # print(f" Header extensions: {header_extensions}") - pass - if message_type == 0xb: # audio-only server response - if message_type_specific_flags == 0: # no sequence number as ACK - continue - else: - sequence_number = int.from_bytes(payload[:4], "big", signed=True) - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - payload = payload[8:] - result += payload - if sequence_number < 0: - break - else: - continue - elif message_type == 0xf: - code = int.from_bytes(payload[:4], "big", signed=False) - msg_size = int.from_bytes(payload[4:8], "big", signed=False) - error_msg = payload[8:] - if message_compression == 1: - error_msg = gzip.decompress(error_msg) - error_msg = str(error_msg, "utf-8") - break - elif message_type == 0xc: - msg_size = int.from_bytes(payload[:4], "big", signed=False) - payload = payload[4:] - if message_compression == 1: - payload = gzip.decompress(payload) - else: - break - return result diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index 1a0e17d8..48802f6b 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -14,34 +14,18 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel -from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel -from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential -from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText -from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech from smartdoc.conf import PROJECT_DIR volcanic_engine_llm_model_credential = OpenAILLMModelCredential() -volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() -volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', ModelTypeConst.LLM, volcanic_engine_llm_model_credential, VolcanicEngineChatModel - ), - ModelInfo('asr', - '', - ModelTypeConst.STT, - volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText - ), - ModelInfo('tts', - '', - ModelTypeConst.TTS, - volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech - ), + ) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py deleted file mode 100644 index bf051c18..00000000 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py +++ /dev/null @@ -1,46 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): - spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat') - spark_app_id = forms.TextInputField('APP ID', required=True) - spark_api_key = forms.PasswordInputField("API Key", required=True) - spark_api_secret = forms.PasswordInputField('API Secret', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} - - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py deleted file mode 100644 index 309f882b..00000000 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py +++ /dev/null @@ -1,46 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): - spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts') - spark_app_id = forms.TextInputField('APP ID', required=True) - spark_api_key = forms.PasswordInputField("API Key", required=True) - spark_api_secret = forms.PasswordInputField('API Secret', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} - - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py deleted file mode 100644 index f57e6bf1..00000000 --- a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py +++ /dev/null @@ -1,169 +0,0 @@ -# -*- coding:utf-8 -*- -# -# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -import asyncio -import base64 -import datetime -import hashlib -import hmac -import json -from datetime import datetime -from typing import Dict -from urllib.parse import urlencode, urlparse -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_stt import BaseSpeechToText - -STATUS_FIRST_FRAME = 0 # 第一帧的标识 -STATUS_CONTINUE_FRAME = 1 # 中间帧标识 -STATUS_LAST_FRAME = 2 # 最后一帧的标识 - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): - spark_app_id: str - spark_api_key: str - spark_api_secret: str - spark_api_url: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.spark_api_url = kwargs.get('spark_api_url') - self.spark_app_id = kwargs.get('spark_app_id') - self.spark_api_key = kwargs.get('spark_api_key') - self.spark_api_secret = kwargs.get('spark_api_secret') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return XFSparkSpeechToText( - spark_app_id=model_credential.get('spark_app_id'), - spark_api_key=model_credential.get('spark_api_key'), - spark_api_secret=model_credential.get('spark_api_secret'), - spark_api_url=model_credential.get('spark_api_url'), - **optional_params - ) - - # 生成url - def create_url(self): - url = self.spark_api_url - host = urlparse(url).hostname - # 生成RFC1123格式的时间戳 - gmt_format = '%a, %d %b %Y %H:%M:%S GMT' - date = datetime.utcnow().strftime(gmt_format) - - # 拼接字符串 - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + "/v2/iat " + "HTTP/1.1" - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( - self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": host - } - # 拼接鉴权参数,生成url - url = url + '?' + urlencode(v) - # print("date: ",date) - # print("v: ",v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - # print('websocket url :', url) - return url - - def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def speech_to_text(self, file): - async def handle(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - # 发送 full client request - await self.send(ws, file) - return await self.handle_message(ws) - - return asyncio.run(handle()) - - @staticmethod - async def handle_message(ws): - res = await ws.recv() - message = json.loads(res) - code = message["code"] - sid = message["sid"] - if code != 0: - errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) - return errMsg - else: - data = message["data"]["result"]["ws"] - result = "" - for i in data: - for w in i["cw"]: - result += w["w"] - # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False))) - return result - - # 收到websocket连接建立的处理 - async def send(self, ws, file): - frameSize = 8000 # 每一帧的音频大小 - status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 - - while True: - buf = file.read(frameSize) - # 文件结束 - if not buf: - status = STATUS_LAST_FRAME - # 第一帧处理 - # 发送第一帧音频,带business 参数 - # appid 必须带上,只需第一帧发送 - if status == STATUS_FIRST_FRAME: - d = { - "common": {"app_id": self.spark_app_id}, - "business": { - "domain": "iat", - "language": "zh_cn", - "accent": "mandarin", - "vinfo": 1, - "vad_eos": 10000 - }, - "data": { - "status": 0, "format": "audio/L16;rate=16000", - "audio": str(base64.b64encode(buf), 'utf-8'), - "encoding": "lame"} - } - d = json.dumps(d) - await ws.send(d) - status = STATUS_CONTINUE_FRAME - # 中间帧处理 - elif status == STATUS_CONTINUE_FRAME: - d = {"data": {"status": 1, "format": "audio/L16;rate=16000", - "audio": str(base64.b64encode(buf), 'utf-8'), - "encoding": "lame"}} - await ws.send(json.dumps(d)) - # 最后一帧处理 - elif status == STATUS_LAST_FRAME: - d = {"data": {"status": 2, "format": "audio/L16;rate=16000", - "audio": str(base64.b64encode(buf), 'utf-8'), - "encoding": "lame"}} - await ws.send(json.dumps(d)) - break diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py deleted file mode 100644 index 8d775583..00000000 --- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py +++ /dev/null @@ -1,142 +0,0 @@ -# -*- coding:utf-8 -*- -# -# author: iflytek -# -# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -import asyncio -import base64 -import datetime -import hashlib -import hmac -import json -import os -from datetime import datetime -from typing import Dict -from urllib.parse import urlencode, urlparse -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - -STATUS_FIRST_FRAME = 0 # 第一帧的标识 -STATUS_CONTINUE_FRAME = 1 # 中间帧标识 -STATUS_LAST_FRAME = 2 # 最后一帧的标识 - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - spark_app_id: str - spark_api_key: str - spark_api_secret: str - spark_api_url: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.spark_api_url = kwargs.get('spark_api_url') - self.spark_app_id = kwargs.get('spark_app_id') - self.spark_api_key = kwargs.get('spark_api_key') - self.spark_api_secret = kwargs.get('spark_api_secret') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return XFSparkTextToSpeech( - spark_app_id=model_credential.get('spark_app_id'), - spark_api_key=model_credential.get('spark_api_key'), - spark_api_secret=model_credential.get('spark_api_secret'), - spark_api_url=model_credential.get('spark_api_url'), - **optional_params - ) - - # 生成url - def create_url(self): - url = self.spark_api_url - host = urlparse(url).hostname - # 生成RFC1123格式的时间戳 - gmt_format = '%a, %d %b %Y %H:%M:%S GMT' - date = datetime.utcnow().strftime(gmt_format) - - # 拼接字符串 - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( - self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": host - } - # 拼接鉴权参数,生成url - url = url + '?' + urlencode(v) - # print("date: ",date) - # print("v: ",v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - # print('websocket url :', url) - return url - - def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def text_to_speech(self, text): - - # 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"” - # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} - async def handle(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - # 发送 full client request - await self.send(ws, text) - return await self.handle_message(ws) - - return asyncio.run(handle()) - - @staticmethod - async def handle_message(ws): - audio_bytes: bytes = b'' - while True: - res = await ws.recv() - message = json.loads(res) - # print(message) - code = message["code"] - sid = message["sid"] - audio = message["data"]["audio"] - audio = base64.b64decode(audio) - - if code != 0: - errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) - else: - audio_bytes += audio - # 退出 - if message["data"]["status"] == 2: - break - return audio_bytes - - async def send(self, ws, text): - d = { - "common": {"app_id": self.spark_app_id}, - "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, - "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")}, - } - d = json.dumps(d) - await ws.send(d) diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index 8c60083c..d33b944e 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -13,25 +13,16 @@ from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential -from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential -from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM -from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText -from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech from smartdoc.conf import PROJECT_DIR ssl._create_default_https_context = ssl.create_default_context() qwen_model_credential = XunFeiLLMModelCredential() -stt_model_credential = XunFeiSTTModelCredential() -tts_model_credential = XunFeiTTSModelCredential() -model_info_list = [ - ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), - ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), -] +model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM) + ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build() diff --git a/apps/setting/serializers/model_apply_serializers.py b/apps/setting/serializers/model_apply_serializers.py index 2177b5fe..e4b43b6d 100644 --- a/apps/setting/serializers/model_apply_serializers.py +++ b/apps/setting/serializers/model_apply_serializers.py @@ -35,14 +35,16 @@ class EmbedQuery(serializers.Serializer): class ModelApplySerializers(serializers.Serializer): model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) - def embed_documents(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EmbedDocuments(data=instance).is_valid(raise_exception=True) model = get_embedding_model(self.data.get('model_id')) - return model.embed_documents(instance.getlist('texts')) + texts = instance.getlist('texts') + for text in texts: + print("to embed texts:",text) + return model.embed_documents(texts) def embed_query(self, instance, with_valid=True): if with_valid: @@ -50,4 +52,6 @@ class ModelApplySerializers(serializers.Serializer): EmbedQuery(data=instance).is_valid(raise_exception=True) model = get_embedding_model(self.data.get('model_id')) - return model.embed_query(instance.get('text')) + text = instance.get('text') + query = model.embed_query(text) + return query diff --git a/apps/smartdoc/const.py b/apps/smartdoc/const.py index 9b1159aa..c1fba498 100644 --- a/apps/smartdoc/const.py +++ b/apps/smartdoc/const.py @@ -9,4 +9,4 @@ __all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG'] BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) PROJECT_DIR = os.path.dirname(BASE_DIR) VERSION = '1.0.0' -CONFIG = ConfigManager.load_user_config(root_path=os.path.abspath('/opt/maxkb/conf')) +CONFIG = ConfigManager.load_user_config(root_path=os.path.abspath('J:/workspace/openai/MaxKB/conf')) -- Gitee From 9c0131036673ea75fb3589ff91b0980342f4a98a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A9=BA=E4=B8=AD=E6=99=BA=E5=9B=8A?= Date: Thu, 5 Sep 2024 14:35:57 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9es=E6=A3=80=E7=B4=A2?= =?UTF-8?q?=E6=97=B6=EF=BC=8C=E8=AE=BE=E7=BD=AE=E7=9A=84=E6=9C=80=E5=B0=8F?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E5=BA=A6=E6=8E=92=E5=90=8D=E5=BE=97=E5=88=86?= =?UTF-8?q?=E6=97=A0=E6=95=88=E9=97=AE=E9=A2=98\n=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=B7=B7=E5=90=88=E6=A3=80=E7=B4=A2=E3=80=81=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E3=80=81=E5=85=B3=E9=94=AE=E8=AF=8D=E6=A3=80?= =?UTF-8?q?=E7=B4=A2=E5=BC=8F=E7=9A=84=E6=A3=80=E7=B4=A2=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=EF=BC=8C=E6=8F=90=E9=AB=98=E6=A3=80=E7=B4=A2=E6=95=88=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .idea/icon.png | Bin 8154 -> 0 bytes apps/embedding/vector/es_vector.py | 62 +++++++++++++---------------- conf/config_example.yml | 29 ++++++++++++++ main.py | 5 ++- model/base/hub/version.txt | 1 + pyproject.toml | 4 +- ui/src/enums/model.ts | 4 +- 7 files changed, 65 insertions(+), 40 deletions(-) delete mode 100644 .idea/icon.png create mode 100644 conf/config_example.yml create mode 100644 model/base/hub/version.txt diff --git a/.idea/icon.png b/.idea/icon.png deleted file mode 100644 index 7d9781edbb6a38cf876574b7ae0566b2dd45672d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8154 zcmeHsS5#Bqw{EBbX)4mCi3*`h5s?lGij@4Q3PLC%QY6v|p@>pcq$r_Fi3kWN5hqSV|5L^pcZ@sk+kH3>d+fc}nse^8*P3&F-$+7&Ahh=Fi)v9SUzUzLBk0XOys*KP%ZKmsR@E(YBBEjhnr%k}M0b%qH)*l$tE5)H>Y}8BCzS~ngn)scxde@Iigvv$xb)v?3+L`ZQU zVQx@?PAU+E)30~jLF8*K8LcNvdA;MDBe)~+Q_QUG#VT)>aC(V_tv@=sG|<|yA@a7L zs8-$yo9iW5Fz1VXl{bm*JEi5zXem@eXY!X2HYIV#oLRA<<4e}2s%sDP?XL1Lp|#sX z%;e&9ODXFZ$47*e)1TxgDS9zXT9u|$CV!VmnS9#^B=IU#x`BsYun^2Zm zzIP93f7@q8h%h9aPKR<6(-Nj!wLP+gi7(E)>Y)>2`KE^12>QB+`bgMjO~1}VT|^Zz zkVc3do&1N50MczxvvIQNPRO0**PrXpw5^t01H}gP<9>&AubJ!jTb$~ODQ|E#Y%mTL z^|tTf4`BG`lt}9HXi6hBcjXMq=fM~{!Uqx97FiZ^?^(aedA)F*K&lN_Y(T+Gu#C8Z zMjG~`@RQYFCK%6!0cbQkLAQ|#b&mKn=iLWpWp%pnN@r(bOm(X-a>To{l=H{B(SD6S zE3275a0B#+MQ8Ikg-cRtD%XfJPT=F zJS|mldED7m$1_yX<;f|&t=8x$sp{ctDQ{6oPh$xtqms8-o_fV@>q6*hrUD1yk5NHH zT|s?KltLD=V`A89WWUZ!>M+q$*pxsA9={hcLK-P z06e)Lr`}{TKa(Z5vxz=XQ*YSS*UeGT0Sgs+)nyLGV9$B7*x*EhNHG;WnIAIMhis@M zom9)qR1r_~80!qxOxLGaU4IH5u~(q*XOb>DLy@()eX~}}rgte&=Y+h$8j?v}-1+wY{#CW-EsF7& zv3q0q{W|$`6`TxqcGF2#r<4wG>Q)fU%=4)Nkl*q~W$D(1{+f%12rit`_<9VY1GJqp zLi3SAjdlN2d^kGQva&pFx~i^R+HBfXfa|F)CeT3dW>JssZrk80JOx=@*8G&739r>% zHbn5@lu)ulE|I6yEAEh1Yz2|tn~$KZtt{zeh37i8Z%tXpFO8CA-=g5I zt6#bTLr9EoJKY24y!yRy>9-$Xt0=9SAJ2qtA z-uR3{x6)bPBu_}*x4}0-^pW1PqOs(p+bzXvw0!;ed<3cgdsxGRkITRyVgnMW;;_?P zTX%nN?I7ddp~a^wa$oB1UN&qn$|xO{htns1xvTWC$4J^eH}*hWp4gxn%fL)zU=_=- z)GbmJCfJpaE(%R5Nim%f`=Jwi289`=K=au*#0oFW9y%k-iB@TYdrSpY;}vQBMwjp0 z(ArST?s6LPoERZFr3qwRY?*6`H{EwI7dJhD8;UlkxA5XbZE(e4y_WGkrEab0;YPt9 z_w}oPl`Xf0CXtX@WvJ+W0j5GarR7LWzOxCGV?1XholeHwVqgOdw5r-su<#*=R!w`LsFrquj3)LR9bCgs#VSnNIz{gtnre} zlBa|rqQ^?m(i5^a`y*wtafI&b3Hu*wQ$MCxX^t;h7M75INEEkAbS|3Ktt?)%z zAbRcyTOtk+IQ;sK@}Myfh_$DidZ+-gGzm2;3Di1ouh{pRO%xX@X0dQ9e^%D767w76 z^1x=C4j&||g)YsHe>xqqcaWBVuamD`I(cB|!W2BT8)Uz^LEAx1c8sT`w+%mweUGZv zF2Io^Ze%Q#;oLoJpC|DgFi>E4YTkOJ<)~09VSm1fn=th{YS<{AJd;3vf6J6wO_nXg zF+Kb=oOrsI(HrWH8E zWT9{OB<5JL{4OLdmwO2){3GW9x_4Wwv-w{*DI(&Cogo~T|<25subS0q235|;CH4?z^~7-JA@Eg1uFIHAKPV`R-I ztZa?ozc=uO60%8(R#{_^o|)F1;9hv0w~LuC;1e)O;qS>yTjPo?iB{1Or1D$q02c1( zh+^vZgEKi|?$)nBk@Z0(wVR`vYeT@x<;~$0e-h#y-YH?krs7?cMj;!q2!K`Cm857& zc+(>>*@OJ`@J#Rpp9kqGN&dT^BATZgYgd%Fg+*JbF36~$)gJhr#aEe=4Ow#e?|+o} z~?yn`+G=bM{pY9=PrsjJ3rIvUY^g-9>4YpPz7QH2Q&z#LjIkUPH zqMMcwt#%u8s-T7Q*P4#s%Ar=lm;MVHS1@D=*^aTVmS*uE6 zzdK#82y5vhhm(+HA^<#ro^5&kJImYQ#&h_*`4y12P^;g8kwg4OqTbH8pWp5B2kXe% z+neLMJKx+E%Im92`K28Zw0i|p;U$-P7<6gfYAtgOSLfXKgyXMdlzx%7#O{*Mz~dH( z091773Rij~^KVb*5WLFcLuOwR4v~wODUat@2xF`C8zV-@f}7Km<-41{)&yfABTPDW zssYEh=2I0-;!U&$a=Y+&|80uSj8H3gVZG9OdvO1dgW1CSuA1W7aT-#zhX&SFWj3m) z09w=3>jJ6U9U3QBzSqb@A*!g4+E>yYeC0I1VM*G1%UzCsmXqMF2L3*yaxa7gDzo9Z zpfC^p3#QKYgS$Onye`TS(W=xwHBR0r_~T7UpcC95HxSBNOXbj~^+f~a7lY$XGO(fJ z&m?eEVzxX_Ak}{OJChI&@C@zLGoX$!-z##K|Jdom_uoH--9PNX>Rzmwg7b^SF^7qi zmF>oB#F%P=?2~=36hbXLZgS2TojTU0Gn;?@$?FT)LiHq|6MLoV#I1KwF^9l z?!--E7wf4eonho;!k7AiO;X~Y68`P=`KStDVO#c@6jj;AtmC#NcKEDnQR^$qVtwHfM7oQ**Y18 z1B1U``|C5eM$4jwgk|e#Lj!!n$xaQHnxH^i-`EdCo;5SQENkxe2izj_uPgOxR1aPD zE`uzex=H=DSAixiTNIzxSAbcG_6_xto5P|u*Ch&;+#c=@@y;Qsyn!k70j2njwC!wQ zDbfQ#@m=FjuJKKW(&p;|8D@=^6xMs4TMV;C zmbK&{#oo7_?aqIQ$VMLr87yNn`0Y0=efNIggqy$k{k5JUd-#M~&ggfjXu66~S$b(2 zr1(%hx^6=9{%nNi($txZSb0cVt`a3?(Gh(V9wwz6}K5VTGn4X()T zoo*h#rU!fTa-QB|pYzK$Yc(w9@U{vFD8PKkK0EYt%jh%5rNztDos+`mMv#&cotUv6@i?p>(?;_+6NQGjMQry~rrxB6b&oPy3%0`e8+61K|0C6A)x1k6zT zhPrL|+PB^=-wR0|iOUf@q%@t)$+8#ST#hhOt+kXYFyTa3twDh$a;gXp#dZx zk2!0jU}u1s*qV#$9s8s3zE>_#N&uEEG6Bp@mAvH^A7@P77qZ#`?sj0pJ9{#2y4N?O z^!}}U3%j9#F2;{in7*Xvd`rrU(z8!WcC}`x`3{QA-lu}lV;f-RZNz@xwJNvJoj9k= zp0CuWaM(gFB!nYZGbE(?y=YU|zL}PO4{SUbQSQ4`>$j)^^W#*BTwT4Z6{2o~C%6Az z@u*}HUK*e1znY*I7_>wf4u}i!_igek-YF)#g1`2xCW^B?sl5~uvK4pcRmys^_Qdc& zk%wcb^Km|&jryVCf&MIw@4G>xHKSwsrMDM+P8v<@3-r`jk7yCP`h0nJm|+}|mZ9c8 z;^`_NS0Hd75ZLssAU!T7N`;Vr59JR-9QQaQE$!by@7bN$WDr8!zmZ{;h-he>VcnnBFfTTT>?aqE&3-)BP= zZy(zb`wWG^(hD3NSSR+}$v>p*`}_?C391l+BSrJ-+*cKJ-*(*_)LkCkfG*{Wl-^G! zm8bZ~MU$B;>;$Wfk+_n%j3IyDL8~0ivsvhX1hEE{{9;N*)Z%Zn)jF^h!O&gsamV3j zlYh!aStuI?bL#{Y->lxuINuV#tt}cGP^#UZJKL~0o=*;%AL>|3X1-Z!!~>fMAmu-yCv`M%@@1mn?w3v?aiNWyYhe|LEJAzWlb)*?~lI* zAbTr+$#vFjhFdS9xSo%AuiDXdbijV&s~Y4-CQ5XVob26I;0Py&vi96V7<^R#jFgy( z9OeLfvc+_JIFd!L$lo<^Z+>U(?i6kYpTQIw`Vh++oHh1)>?Cf?s`j6<4j!1$3fqqh zqrA6`3Xr1cX|$keUIY<$PE>kDeSavOu%VD_lJwQUJ}_s=!EoPriBtL`zsK>E%*1f{ zQ?aAg5MVnwch5p5Q0|f7MZW`+TpQfj3EbG148w_IzJ3pWp?8igpLvofnq7XA$ZEb* zR~;<{Y(t8MMD!7m2ZY^ zlX0dET1L4>KLK$8e{*|Ekeo}T1#P*>&M8z1)g;rZJ*fMD@Vra!@yY~va;qCJrotn6 z^G7g;;r{oMIPwofh}Uer;PRzuO5H<=fru?9q|@9pXloq2waIH|f{N)w65D@-?a<_| z3OaVORK0y^Zy6 z8!P7`Oo4F2y>qT49YA?y3n}nSnWhIQ5BwRkOd?8;85B2j!}Qcu4)>}2JM4+Gdw(K1 ztn+^`GxdMizRXV8c8r=#{qeSUiR37#WFk4u!~J-rYit%<@6*Bw!{B{L*d^23Ty{UU zb`_pS(puvh0#0P=$A1c@)$VdR!mng3?f+<{f?6NIkhAzjz7^Gf6mzzir3E80txStB zt#JP_W5L$hdbjn!Z-L=sT~R-kwYyo!-DN1bjX0gPsgOP;CiCE*qj;&1_40?%%3gsT z`BS6h!Gxb3Ksjv+lP@+=+EOlwlB0C#tc`ZW--#)?RA41com}3baoGS^ek^h=hzg`o z?tU=i&536H%6!uXaBF?fGhdgtPO?nrrJSOlKjxH~2VC{eZl_qrFu#?~8Ot+V$b$tBd? zpJMM1E}OQG-U0A;@IZ`H5v_feSxqfjLd?(*!6rI0;e6VTnQ{}6^*DYW>Pxv`&syav z!Xt^E23hacRS<{CGKS%~h4CBDK#GsBH{YmhY@?6+F`}nktaqIi4a0qeMy7BX#6Ma? zj(J!a1^VIiw*>xs>gh)(Qw|o)piYVZqp^VC}k zeg84=#TZQf__V-`<1Uc}%~=sNCK)sB`m<<^Mb*Y!rb9O3w~JybTy1j@e&>;p>{--2 z|GK6`YE#x;sUB}1ln>46v>M?cf%)7Cp5yO>0v6=P4^jHgkDsbWD`vSC#i~qq0=se+ z@}&7In#B#}4pugg`hnRV7BZ{Zt8i=Oj&ODl;zm2*RV?Rc5A$8(w~ynX;Y57f$==cI zy(%4!MW10en7cBlD5acS9L`Sev?JD)Qb>@wth`c7djYybvcoQ_>O$H4j$diQP3yB8qVtfL2h z{61>RuZh;)Sj?FPsbwDZlJ7`kK3$~uhib5U2;Y!VwRy&U2qbh*E7PXD ziI_VJ0>VfC@ll&kKwgBc@l~7 zQbYF*9gM7pC|hPCWqIwZ+@j`yETh8a5kJY%qzIQ=f+*0QdA5#QYyk{;@df%U$}6@= zO9^fcN#ZklGZ41+EWy05W$)8sAQv25AM<~Sh5tRt{eL*VKJf-c(>sp*Z8Fea!v!4H PgRWn-xKe58^61|HgYQ similarity] return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), @@ -294,23 +292,21 @@ class BlendSearch(ISearch): top_number: int, similarity: float, search_mode: SearchMode): - filter_query["bool"]["should"] = { - "function_score": { - "query": {'match': {'search_vector': query_text}}, - "functions": [{ - 'script_score': { - 'script': { - 'source': "cosineSimilarity(params.query_vector, 'embedding') + 1.0", - 'params': {'query_vector': []} - } - } - }], - 'boost_mode': 'replace' - } - } + knn_filter = filter_query + filter_query['bool']['must'] = {"multi_match": { + "query": to_query(query_text), + "fields": ["search_vector"] + }} result = self.esVector.search(body={ - "sort": ["_score"], + "min_score": similarity, "query": filter_query, + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_number, + "num_candidates": 2 * top_number, # 默认为2倍size + "filter": knn_filter + }, "_source": ["paragraph_id"] # 添加这一行 }, size=top_number) @@ -325,6 +321,4 @@ class BlendSearch(ISearch): self.esVector = esVector return search_mode.value == SearchMode.blend.value - -# 混合检索 search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()] diff --git a/conf/config_example.yml b/conf/config_example.yml new file mode 100644 index 00000000..e47f5781 --- /dev/null +++ b/conf/config_example.yml @@ -0,0 +1,29 @@ +# 数据库链接信息 +DB_NAME: maxkb +DB_HOST: 192.168.10.245 +DB_PORT: 5432 +DB_USER: postgres +DB_PASSWORD: nmghl@123 +DB_ENGINE: django.db.backends.postgresql_psycopg2 + +DEBUG: false + +TIME_ZONE: Asia/Shanghai + + +EMBEDDING_MODEL_NAME: BAAI/bge-m3 +EMBEDDING_DEVICE: cpu + +LOCAL_MODEL_HOST: 192.168.10.217 +LOCAL_MODEL_PORT: 11636 +LOCAL_MODEL_PROTOCOL: http + +# 使用的向量数据库 pg_vector | es_vector +VECTOR_STORE_NAME : es_vector +# es链接地址 +VECTOR_STORE_ES_ENDPOINT: http://192.168.10.252:9200 +# es库用户名和密码,默认不开启用于验证 +VECTOR_STORE_ES_USERNAME: default +VECTOR_STORE_ES_PASSWORD: default +# es索引库名称 +VECTOR_STORE_ES_INDEX: maxkb_vector \ No newline at end of file diff --git a/main.py b/main.py index a8bd74af..b27e7ae8 100644 --- a/main.py +++ b/main.py @@ -70,7 +70,7 @@ def start_services(): def dev(): services = args.services if isinstance(args.services, list) else args.services if services.__contains__('web'): - management.call_command('runserver', "0.0.0.0:8080") + management.call_command('runserver', "192.168.10.217:8080") elif services.__contains__('celery'): management.call_command('celery', 'celery') elif services.__contains__('local_model'): @@ -81,7 +81,8 @@ def dev(): if __name__ == '__main__': - os.environ['HF_HOME'] = '/opt/maxkb/model/base' + os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + os.environ['HF_HOME'] = 'J:/workspace/openai/maxkb/model/base' parser = argparse.ArgumentParser( description=""" qabot service control tools; diff --git a/model/base/hub/version.txt b/model/base/hub/version.txt new file mode 100644 index 00000000..56a6051c --- /dev/null +++ b/model/base/hub/version.txt @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 37d1f61d..c93ed0be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dashscope = "^1.17.0" zhipuai = "^2.0.1" httpx = "^0.27.0" httpx-sse = "^0.4.0" -websockets = "^13.0" +websocket-client = "^1.7.0" langchain-google-genai = "^1.0.3" openpyxl = "^3.1.2" xlrd = "^2.0.1" @@ -60,6 +60,8 @@ celery = { extras = ["sqlalchemy"], version = "^5.4.0" } eventlet = "^0.36.1" django-celery-beat = "^2.6.0" celery-once = "^3.0.1" +# 加入elastic依赖 +elasticsearch = "^8.15.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/ui/src/enums/model.ts b/ui/src/enums/model.ts index 52ff935d..c27fc61f 100644 --- a/ui/src/enums/model.ts +++ b/ui/src/enums/model.ts @@ -9,9 +9,7 @@ export enum PermissionDesc { export enum modelType { EMBEDDING = '向量模型', - LLM = '大语言模型', - STT = '语音识别', - TTS = '语音合成', + LLM = '大语言模型' } -- Gitee