1 Star 0 Fork 8

衣沾不足惜/gitee-ai-docs-test

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
api_server.py 25.71 KB
一键复制 编辑 原始数据 按行查看 历史
衣沾不足惜 提交于 2024-07-16 13:46 +08:00 . update
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773
import os
import time
# from asyncio.log import logger
import re
import uvicorn
import gc
import json
import torch
import random
import string
import asyncio
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse
from sentence_transformers import SentenceTransformer
import numpy as np
import tiktoken
import subprocess
import threading
import multiprocessing
import time
from huggingface_hub import snapshot_download
# queue = multiprocessing.Queue()
# Class-level ping interval applied to every EventSourceResponse created below.
# NOTE(review): presumably expressed in seconds, which would make keep-alive
# pings very rare — confirm against the sse_starlette version in use.
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
# Qwen1.5-32B-Chat-GPTQ-Int4 Qwen2-72B-Instruct-GPTQ-Int4 Qwen2-72B-Instruct-AWQ codegeex4-all-9b
# webui_command = ["python", "webui.py"]
# webui_process = subprocess.Popen(
# webui_command, text=True)
# def read_webui_process_output(process):
# """Reads the process output and prints it."""
# while True:
# output = process.stdout.readline()
# if output == '' and process.poll() is not None:
# break
# if output:
# print("Web UI: "+output.strip())
# Start a thread that reads the subprocess output
# webui_thread = threading.Thread(
# target=read_webui_process_output, args=(webui_process,))
# webui_thread.start()
# Repository id of the chat model; overridable through the MODEL_PATH
# environment variable. It is used as `repo_id` for the download below.
MODEL_PATH = os.environ.get(
'MODEL_PATH', 'hf-models/glm-4-9b-chat')
# Maximum context length handed to the engine as `max_model_len`.
MAX_MODEL_LENGTH = 20000 # 131072 65536 32768
# /v1/embeddings
# Repository id of the embedding model; overridable through EMBEDDING_PATH.
EMBEDDING_PATH = os.environ.get(
'EMBEDDING_PATH', 'hf-models/bge-m3')
# Model id reported by the /v1/models endpoint.
MODEL_ID = "gitee"
# Switched to downloading through snapshot_download.
# Both names are rebound to the value snapshot_download returns, which is what
# the tokenizer, engine and embedding model are later loaded from.
# NOTE(review): the local_dir values are hard-coded to the default model names
# even when the environment variables select a different repository — confirm
# this is intended.
MODEL_PATH = snapshot_download(
repo_id=MODEL_PATH, repo_type="model", local_dir="./glm-4-9b-chat", etag_timeout=60, max_workers=2)
EMBEDDING_PATH = snapshot_download(repo_id=EMBEDDING_PATH,
repo_type="model", local_dir="./bge-m3", etag_timeout=60, max_workers=2)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Nothing is prepared on startup. On shutdown the CUDA caching allocator
    is emptied and CUDA IPC memory is collected, when CUDA is available.
    """
    # Startup: hand control straight back to the server.
    yield

    # Shutdown: nothing to release on a CPU-only host.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# The ASGI application; `lifespan` runs its cleanup when the server stops.
app = FastAPI(lifespan=lifespan)
# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this server is never exposed to untrusted browsers.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def generate_id(prefix: str, k=29) -> str:
    """Return *prefix* followed by *k* random ASCII letters and digits.

    Uses the non-cryptographic `random` module; the result is an opaque
    identifier, not a secret.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=k)
    return prefix + "".join(picked)
class ModelCard(BaseModel):
    # One entry of the model listing returned by /v1/models.
    id: str = ""
    object: str = "model"
    # Whole seconds since the epoch, evaluated when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Union[str, None] = None
    parent: Union[str, None] = None
    permission: Union[list, None] = None
class ModelList(BaseModel):
    # Envelope returned by /v1/models.
    object: str = "list"
    # The default used to be `[MODEL_ID]`, i.e. a list holding a bare string,
    # which contradicts the declared element type. Build a real ModelCard
    # instead, and build it per instance so no default list is shared.
    data: List[ModelCard] = Field(
        default_factory=lambda: [ModelCard(id=MODEL_ID)])
class FunctionCall(BaseModel):
    # Function name and its arguments (a string) carried inside a tool call.
    name: Union[str, None] = None
    arguments: Union[str, None] = None
class ChoiceDeltaToolCallFunction(BaseModel):
    # Same two fields as FunctionCall; used for the `function_call` field of
    # chat messages and stream deltas.
    name: Union[str, None] = None
    arguments: Union[str, None] = None
class UsageInfo(BaseModel):
    # Token accounting attached to a chat completion response.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Union[int, None] = 0
class ChatCompletionMessageToolCall(BaseModel):
    # One tool call attached to an assistant message or stream delta.
    index: Union[int, None] = 0
    id: Union[str, None] = None
    function: FunctionCall
    type: Union[Literal["function"], None] = 'function'
class ChatMessage(BaseModel):
    # About the "function" role:
    # clients on older OpenAI API versions need "function" added to the
    # literal below, together with role-conversion logic in process_messages
    # that maps it to "observation".
    role: Literal["user", "assistant", "system", "tool"]
    content: Union[str, None] = None
    function_call: Union[ChoiceDeltaToolCallFunction, None] = None
    tool_calls: Union[List[ChatCompletionMessageToolCall], None] = None
class DeltaMessage(BaseModel):
    # Incremental message fragment carried by one streaming chunk.
    role: Union[Literal["user", "assistant", "system"], None] = None
    content: Union[str, None] = None
    function_call: Union[ChoiceDeltaToolCallFunction, None] = None
    tool_calls: Union[List[ChatCompletionMessageToolCall], None] = None
class ChatCompletionResponseChoice(BaseModel):
    # A complete (non-streaming) answer choice.
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "tool_calls"]
class ChatCompletionResponseStreamChoice(BaseModel):
    # A streaming answer choice; finish_reason has no default, so it must be
    # passed explicitly (None is accepted).
    delta: DeltaMessage
    finish_reason: Union[Literal["stop", "length", "tool_calls"], None]
    index: int
class ChatCompletionResponse(BaseModel):
    # Envelope shared by full responses and streaming chunks; `object` tells
    # them apart. id, created and system_fingerprint are generated per
    # instance unless the caller supplies them.
    model: str
    id: Union[str, None] = Field(
        default_factory=lambda: generate_id('chatcmpl-', 29))
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[
        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
    ]
    created: Union[int, None] = Field(
        default_factory=lambda: int(time.time()))
    system_fingerprint: Union[str, None] = Field(
        default_factory=lambda: generate_id('fp_', 9))
    usage: Union[UsageInfo, None] = None
class ChatCompletionRequest(BaseModel):
    # Body of POST /v1/chat/completions.
    model: str
    messages: List[ChatMessage]
    # Sampling controls.
    temperature: Union[float, None] = 0.8
    top_p: Union[float, None] = 0.8
    max_tokens: Union[int, None] = None
    stream: Union[bool, None] = False
    # Tool definitions: a single dict or a list of dicts.
    tools: Union[dict, List[dict], None] = None
    tool_choice: Union[str, dict, None] = None
    repetition_penalty: Union[float, None] = 1.1
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that repairs score tensors containing NaN or Inf."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return *scores*, reset in place when any entry is NaN or infinite.

        On a bad tensor every score is zeroed and index 5 of the last
        dimension is set to 5e4, so that entry dominates.
        """
        all_finite = bool(torch.isfinite(scores).all())
        if all_finite:
            return scores
        scores.zero_()
        scores[..., 5] = 5e4
        return scores
def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]:
lines = output.strip().split("\n")
arguments_json = None
special_tools = ["cogview", "simple_browser"]
tools = {tool['function']['name'] for tool in tools} if tools else {}
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
# TODO 如果你希望做更多判断,可以在这里进行逻辑完善。
if len(lines) >= 2 and lines[1].startswith("{"):
function_name = lines[0].strip()
arguments = "\n".join(lines[1:]).strip()
if function_name in tools or function_name in special_tools:
try:
arguments_json = json.loads(arguments)
is_tool_call = True
except json.JSONDecodeError:
is_tool_call = function_name in special_tools
if is_tool_call and use_tool:
content = {
"name": function_name,
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
ensure_ascii=False)
}
if function_name == "simple_browser":
search_pattern = re.compile(
r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
match = search_pattern.match(arguments)
if match:
content["arguments"] = json.dumps({
"query": match.group(1),
"recency_days": int(match.group(2))
}, ensure_ascii=False)
elif function_name == "cogview":
content["arguments"] = json.dumps({
"prompt": arguments
}, ensure_ascii=False)
return content
return output.strip()
@torch.inference_mode()
async def generate_stream(params):
    """Stream cumulative generation results for one chat request.

    Args:
        params: Dict with ``messages`` (required) and optionally ``tools``,
            ``tool_choice``, ``temperature``, ``top_p``,
            ``repetition_penalty`` and ``max_tokens``. Numeric entries that
            are missing or ``None`` fall back to 1.0 (8192 for
            ``max_tokens``).

    Yields:
        Dicts with ``text`` (everything generated so far, not a delta),
        ``usage`` (prompt/completion/total token counts) and
        ``finish_reason``.
    """
    # Local import: keeps the block self-contained without touching the
    # file's import section.
    import uuid

    def _number(key, default, cast):
        # Optional request fields reach us as an explicit None; dict.get's
        # default does not cover that case, and float(None) raises.
        value = params.get(key)
        return cast(default if value is None else value)

    messages = params["messages"]
    tools = params.get("tools")
    tool_choice = params.get("tool_choice")
    temperature = _number("temperature", 1.0, float)
    repetition_penalty = _number("repetition_penalty", 1.0, float)
    top_p = _number("top_p", 1.0, float)
    max_new_tokens = _number("max_tokens", 8192, int)

    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False)
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        # [151329, 151336, 151338], [tokenizer.eos_token_id]
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)

    # A timestamp can repeat when two requests arrive together; a UUID keeps
    # the engine's request ids distinct.
    request_id = uuid.uuid4().hex
    try:
        async for output in engine.generate(inputs, sampling_params, request_id):
            output_len = len(output.outputs[0].token_ids)
            input_len = len(output.prompt_token_ids)
            yield {
                "text": output.outputs[0].text,
                "usage": {
                    "prompt_tokens": input_len,
                    "completion_tokens": output_len,
                    "total_tokens": output_len + input_len
                },
                "finish_reason": output.outputs[0].finish_reason,
            }
    finally:
        # Run the cleanup even when the consumer stops iterating early.
        gc.collect()
        torch.cuda.empty_cache()
def process_messages(messages, tools=None, tool_choice="none"):
    """Convert OpenAI-style chat messages into the dicts fed to the chat template.

    Args:
        messages: Objects exposing ``role``, ``content`` and optionally
            ``tool_calls``.
        tools: Tool definitions (list of dicts) or None.
        tool_choice: ``"none"`` disables tools; a dict selects one tool by
            ``tool_choice["function"]["name"]``; any other value keeps all
            tools.

    Returns:
        A list of message dicts. When tools are active, a system entry
        carrying them comes first and replaces the first system message.
        Otherwise the first system message, if any, is moved to the front.
    """
    processed_messages = []
    msg_has_sys = False

    def filter_tools(tool_choice, tools):
        # Keep only the tool whose name matches the forced choice.
        function_name = tool_choice.get('function', {}).get('name', None)
        if not function_name:
            return []
        return [
            tool for tool in (tools or [])  # tools may legitimately be None
            if tool.get('function', {}).get('name') == function_name
        ]

    if tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
            processed_messages.append(
                {"role": "system", "content": None, "tools": tools})
            msg_has_sys = True

    if isinstance(tool_choice, dict) and tools:
        # Pre-seed an assistant turn that names the forced tool.
        processed_messages.append(
            {
                "role": "assistant",
                "metadata": tool_choice["function"]["name"],
                "content": ""
            }
        )

    for m in messages:
        role, content = m.role, m.content
        tool_calls = getattr(m, 'tool_calls', None)

        if role == "function":
            processed_messages.append(
                {"role": "observation", "content": content})
        elif role == "tool":
            processed_messages.append(
                {"role": "observation", "content": content, "function_call": True})
        elif role == "assistant":
            if tool_calls:
                for tool_call in tool_calls:
                    processed_messages.append(
                        {
                            "role": "assistant",
                            "metadata": tool_call.function.name,
                            "content": tool_call.function.arguments
                        }
                    )
            else:
                # An assistant message may carry content=None; treat it as
                # empty text instead of crashing on None.split.
                # NOTE(review): each line of the text becomes its own
                # assistant entry with empty metadata, as before. The old
                # `"\n" in response` branch could never fire after splitting
                # on "\n" — confirm whether another delimiter was intended.
                for response in (content or "").split("\n"):
                    processed_messages.append(
                        {
                            "role": role,
                            "metadata": "",
                            "content": response.strip()
                        }
                    )
        else:
            if role == "system" and msg_has_sys:
                # The tools system entry replaces the first system message.
                msg_has_sys = False
                continue
            processed_messages.append({"role": role, "content": content})

    if not tools or tool_choice == "none":
        # Without active tools the system prompt belongs at the very front.
        # Move the existing entry; inserting a second copy (as the code used
        # to) duplicated the system prompt.
        for position, entry in enumerate(processed_messages):
            if entry["role"] == "system":
                if position:
                    processed_messages.insert(
                        0, processed_messages.pop(position))
                break
    return processed_messages
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers with an empty HTTP 200."""
    ok = Response(status_code=200)
    return ok
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return the model listing, which holds the single card for MODEL_ID."""
    cards = [ModelCard(id=MODEL_ID)]
    return ModelList(data=cards)
async def _replay_stream(first_chunk, remaining):
    """Yield *first_chunk* (when non-empty), then everything left in *remaining*."""
    if first_chunk:
        yield first_chunk
    async for chunk in remaining:
        yield chunk


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Handle an OpenAI-style chat completion request.

    Streaming requests return a server-sent-event stream; other requests
    return one ChatCompletionResponse, with ``tool_calls`` populated when
    the model output parses as a call to one of the requested tools.

    Raises:
        HTTPException: 400 when there are no messages or the last message
            is from the assistant; 500 when the model produced no output.
    """
    if not request.messages or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        # Pull the first chunk before the response starts, so a generation
        # failure surfaces as an HTTP error instead of a broken stream.
        try:
            first_chunk = await anext(predict_stream_generator)
        except StopAsyncIteration:
            first_chunk = ""
        # The chunk consumed above used to be dropped, losing the opening
        # delta (or a lone '[DONE]'). Replay it ahead of the rest; an empty
        # first chunk is only a placeholder and stays dropped.
        return EventSourceResponse(
            _replay_stream(first_chunk, predict_stream_generator),
            media_type="text/event-stream")

    # Non-streaming: drain the generator; the last item is the full result.
    response = None
    async for response in generate_stream(gen_params):
        pass
    if response is None:
        raise HTTPException(
            status_code=500, detail="Model produced no output")
    response["text"] = response["text"].strip()

    function_call, finish_reason = None, "stop"
    tool_calls = None
    if request.tools:
        try:
            function_call = process_response(
                response["text"], request.tools, use_tool=True)
        except Exception as e:
            # Best effort: fall back to returning the text as-is.
            print(f"Failed to parse tool call: {e}")

    if isinstance(function_call, dict):
        finish_reason = "tool_calls"
        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
        function_call_instance = FunctionCall(
            name=function_call_response.name,
            arguments=function_call_response.arguments
        )
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=generate_id('call_', 24),
                function=function_call_instance,
                type="function")]

    message = ChatMessage(
        role="assistant",
        content=None if tool_calls else response["text"],
        function_call=None,
        tool_calls=tool_calls,
    )
    print(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    # Only one generation contributes, so its usage is the total.
    usage = UsageInfo.model_validate(response["usage"])
    return ChatCompletionResponse(
        model=request.model,
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )
class CompletionUsage(BaseModel):
    # Token accounting of an embeddings response; every field is required.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class EmbeddingResponse(BaseModel):
    # Response body of POST /v1/embeddings.
    data: list
    model: str
    object: str
    usage: CompletionUsage
class EmbeddingRequest(BaseModel):
    # Body of POST /v1/embeddings. `input` accepts one string or a list of
    # strings; the handler already had a branch for a bare string, but the
    # previous List[str] annotation rejected it during validation.
    input: Union[str, List[str]]
    model: str


def num_tokens_from_string(string: str) -> int:
    """Return the number of cl100k_base tokens in *string*."""
    encoding = tiktoken.get_encoding('cl100k_base')
    return len(encoding.encode(string))


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Embed the input text(s) with the sentence-transformers model.

    Returns:
        A dict shaped like EmbeddingResponse: one ``embedding`` entry per
        input, in input order, plus token usage.
    """
    # Normalise to a list so the rest of the handler has a single code path.
    if isinstance(request.input, str):
        texts = [request.input]
    else:
        texts = list(request.input)

    embeddings = [embedding_model.encode(text).tolist() for text in texts]

    # Embeddings have no completion, so prompt and total counts must agree.
    # The prompt count used to be a whitespace word count while the total
    # used tiktoken; both now use tiktoken.
    # NOTE(review): cl100k_base is not the embedding model's own tokenizer,
    # so the figure is an approximation.
    token_count = sum(num_tokens_from_string(text) for text in texts)

    return {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            prompt_tokens=token_count,
            completion_tokens=0,
            total_tokens=token_count,
        )
    }
def _stream_chunk_json(model_id, response_id, created_time, system_fingerprint,
                       delta, finish_reason, **extra_fields):
    """Serialise one ``chat.completion.chunk`` carrying *delta* to JSON.

    *extra_fields* are passed to ChatCompletionResponse as given, so that an
    explicitly supplied field (for example ``usage=None``) still counts as
    set under ``exclude_unset``.
    """
    choice = ChatCompletionResponseStreamChoice(
        index=0,
        delta=delta,
        finish_reason=finish_reason
    )
    chunk = ChatCompletionResponse(
        model=model_id,
        id=response_id,
        choices=[choice],
        created=created_time,
        system_fingerprint=system_fingerprint,
        object="chat.completion.chunk",
        **extra_fields
    )
    return chunk.model_dump_json(exclude_unset=True)


def _tool_call_delta(name, arguments, call_id, role):
    """Build the delta for one tool-call fragment."""
    tool_call = ChatCompletionMessageToolCall(
        index=0,
        id=call_id,
        function=FunctionCall(name=name, arguments=arguments),
        type="function"
    )
    return DeltaMessage(
        content=None,
        role=role,
        function_call=None,
        tool_calls=[tool_call]
    )


async def predict_stream(model_id, gen_params):
    """Stream a chat completion as JSON chunk strings, ending with ``[DONE]``.

    Args:
        model_id: Model name echoed in every chunk.
        gen_params: Generation parameters forwarded to generate_stream;
            ``tools`` and ``tool_choice`` are also read here.

    Yields:
        JSON strings of ``chat.completion.chunk`` objects. A tool call is
        preceded by one empty string, which the caller consumes as a
        placeholder. The final item is the literal ``[DONE]``.
    """
    output = ""
    is_function_call = False
    has_send_first_chunk = False
    created_time = int(time.time())
    function_name = None
    last_finish_reason = None
    response_id = generate_id('chatcmpl-', 29)
    system_fingerprint = generate_id('fp_', 9)

    requested_tools = gen_params['tools']
    if isinstance(requested_tools, dict):
        # A single tool definition may arrive as a bare dict.
        requested_tools = [requested_tools]
    tools = {tool['function']['name'] for tool in (requested_tools or [])}

    def emit(delta, finish_reason, **extra_fields):
        return _stream_chunk_json(
            model_id, response_id, created_time, system_fingerprint,
            delta, finish_reason, **extra_fields)

    async for new_response in generate_stream(gen_params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(output):]
        output = decoded_unicode
        last_finish_reason = new_response.get("finish_reason", None)
        lines = output.strip().split("\n")

        # Check whether the output is a tool call.
        # This is a simple heuristic and cannot catch every non-tool output,
        # e.g. arguments that do not line up with the declared schema.
        # TODO: extend the checks here if stricter handling is required.
        if not is_function_call and len(lines) >= 2:
            first_line = lines[0].strip()
            if first_line in tools:
                is_function_call = True
                function_name = first_line

        # Tool call.
        if is_function_call:
            if not has_send_first_chunk:
                # Placeholder consumed by the endpoint's initial anext().
                yield ""
                # Header chunk: announces the call id and function name.
                yield emit(
                    _tool_call_delta(
                        function_name, "", generate_id('call_', 24), "assistant"),
                    None)
                has_send_first_chunk = True
                # The delta that completed the detection may still contain
                # part of the name line. The first arguments fragment is
                # everything after that line (guaranteed to exist, since
                # detection requires at least two lines).
                delta_text = output.strip().split("\n", 1)[1]
            yield emit(_tool_call_delta(None, delta_text, None, None), None)

        # Tools were requested but it is not yet known whether this is a
        # tool call: hold the text back.
        elif gen_params["tools"] and gen_params["tool_choice"] != "none":
            continue

        # Regular text.
        else:
            finish_reason = last_finish_reason
            if not has_send_first_chunk:
                yield emit(
                    DeltaMessage(
                        content="",
                        role="assistant",
                        function_call=None,
                    ),
                    finish_reason)
                has_send_first_chunk = True
            yield emit(
                DeltaMessage(
                    content=delta_text,
                    role="assistant",
                    function_call=None,
                ),
                finish_reason)

    if is_function_call:
        # A tool call needs one extra closing chunk to match the OpenAI API.
        yield emit(
            DeltaMessage(
                content=None,
                role=None,
                function_call=None,
            ),
            "tool_calls",
            usage=None)
    elif not has_send_first_chunk and output:
        # Tools were offered but the model answered in plain text. That text
        # was held back above and used to be dropped entirely; send it now.
        yield emit(
            DeltaMessage(content="", role="assistant", function_call=None),
            None)
        yield emit(
            DeltaMessage(content=output, role="assistant", function_call=None),
            last_finish_reason)
    yield '[DONE]'
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
    """Stream *value* as one assistant chunk followed by ``[DONE]``.

    When *function_call* is given it is attached to the chunk's delta.
    """
    message = DeltaMessage(role="assistant", content=value)
    if function_call is not None:
        message.function_call = function_call

    payload = ChatCompletionResponse(
        model=model_id,
        choices=[
            ChatCompletionResponseStreamChoice(
                index=0,
                delta=message,
                finish_reason=None
            )
        ],
        object="chat.completion.chunk"
    )
    yield payload.model_dump_json(exclude_unset=True)
    yield '[DONE]'
# Module-level singletons used by the request handlers above.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, trust_remote_code=True, use_fast=True)
embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
engine_args = AsyncEngineArgs(
    model=MODEL_PATH,
    tokenizer=MODEL_PATH,
    # With several GPUs, set this to the number of GPUs.
    tensor_parallel_size=2,
    dtype="float16",
    # quantization="gptq",
    trust_remote_code=True,
    gpu_memory_utilization=0.7,
    enforce_eager=True,
    worker_use_ray=False,
    engine_use_ray=False,
    disable_log_requests=True,
    max_model_len=MAX_MODEL_LENGTH,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)

# In Jupyter, use the approach below to avoid "asyncio.run() cannot be called
# from a running event loop"; Jupyter already runs its own event loop.
# config = uvicorn.Config(app, host='0.0.0.0', port=8000, workers=4)
# server = uvicorn.Server(config)

# uvicorn is given the import string "api_server:app", so every worker imports
# this module. Without the guard each worker would hit uvicorn.run() again
# during that import and try to start a nested server.
# NOTE(review): each worker (and the launching process) still loads its own
# tokenizer, embedding model and engine at import time — confirm the GPUs can
# hold that, or reduce `workers`.
if __name__ == "__main__":
    print("Starting server...")
    uvicorn.run(app="api_server:app", host='0.0.0.0', port=8000, workers=2)
# await server.serve()
# queue.put('STARTED')
# The vllm CLI performs better, but the vllm 0.3.3 that this hardware supports
# ships an OpenAI-compatible server without /v1/embeddings.
# Normal run:
# async def main():
# uvicorn_config = {
# 'app': app,
# 'host': '0.0.0.0',
# 'port': 8000,
# 'workers': 2,
# 'loop': 'asyncio',
# }
# await server.serve()
# server = uvicorn.Server(**uvicorn_config)
# asyncio.run(main())
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/stringify/gitee-ai-docs-test.git
git@gitee.com:stringify/gitee-ai-docs-test.git
stringify
gitee-ai-docs-test
gitee-ai-docs-test
master

搜索帮助