From 7255760a8c5dfceb8fba637452966b2fd8ae5d74 Mon Sep 17 00:00:00 2001 From: huxinjia Date: Fri, 29 Nov 2024 18:11:12 +0800 Subject: [PATCH 01/11] add app.py --- .../samples/travel_agent_demo/front/app.py | 105 ++++++++++++++++++ .../samples/travel_agent_demo/travelagent.py | 13 ++- 2 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 mxAgent/samples/travel_agent_demo/front/app.py diff --git a/mxAgent/samples/travel_agent_demo/front/app.py b/mxAgent/samples/travel_agent_demo/front/app.py new file mode 100644 index 000000000..96c8e8f9b --- /dev/null +++ b/mxAgent/samples/travel_agent_demo/front/app.py @@ -0,0 +1,105 @@ + +import argparse +import gradio as gr +from mxAgent. samples.travel_agent_demo.travelagent import TravelAgent + + +global mx_agent + +def agent_run(messgae): + query = messgae.get("content") + TravelAgent + + +def user_query(user_message, history): + return "", history + [{"role": "user", "content":user_message}] + + +def clear_history(history): + return [] + +def bot_response(history): + # 将最新的问题传给RAG + try: + response = agent_run(history[-1]) + + # 返回迭代器 + history += [{"role": "assistant", "content":""}] + + history[-1]["content"] = '推理错误' + for res in response: + history[-1]["content"] = '推理错误' if res['result'] is None else res['result'] + yield history + yield history + print(history) + except Exception as err: + history[-1]["content"] = "推理错误" + yield history + +def build_demo(): + with gr.Blocks() as demo: + gr.HTML("""

旅行规划Agent

+

例如:从北京到西安旅游规划

+

例如:西安有哪些免费的博物馆景点

+

例如:查一下西安的酒店

+ + """) + with gr.Row(): + with gr.Column(scale=200): + + initial_msg = [ {"role": "assistant", + "content": "这条消息下想说明的是:如果 Chatbot 的 type 参数为 'messages',那么发送到/从 Chatbot " + "的数据将是一个包含 role 和 content 键的字典列表。这种格式符合大多数 LLM API(如 " + "HuggingChat、OpenAI、Claude)期望的格式。role 键可以是 'user' 或 " + "'assistant',content 键可以是一个字符串(支持 markdown/html 格式),一个 " + "FileDataDict(用于表示在 Chatbot 中显示的文件),或者一个 gradio 组件。"} + ], + + # chatbot = gr.Chatbot(initial_msg, type="messages",) + chatbot = gr.Chatbot( + [ + + {"role": "assistant", "content": "你好,我是你的AI小助手,这是你自己预制的一个信息。"}, + {"role": "assistant", "content": "这条消息下想说明的是:如果 Chatbot 的 type 参数为 'messages',那么发送到/从 Chatbot " + "的数据将是一个包含 role 和 content 键的字典列表。这种格式符合大多数 LLM API(如 " + "HuggingChat、OpenAI、Claude)期望的格式。role 键可以是 'user' 或 " + "'assistant',content 键可以是一个字符串(支持 markdown/html 格式),一个 " + "FileDataDict(用于表示在 Chatbot 中显示的文件),或者一个 gradio 组件。"} + ], + type="messages", + show_label=False, + height=500, + show_copy_button=True + + + ) + + with gr.Row(): + msg = gr.Textbox(placeholder="在此输入问题...", container=False) + with gr.Row(): + send_btn = gr.Button(value="发送", variant="primary") + clean_btn = gr.Button(value="清空历史") + send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False).then(bot_response, + [chatbot], chatbot) + clean_btn.click(clear_history, chatbot, chatbot) + return demo + + + + +def get_args(): + parse = argparse.ArgumentParser() + parse.add_argument("--model_name", type=str, default="Qwen1.5-32B-Chat", help="OpenAI客户端模型名") + parse.add_argument("--base_url", type=str, default="http://10.44.115.108:1055/v1", help="OpenAI客户端模型地址") + parse.add_argument("--api_key", type=str, default="EMPTY", help="OpenAI客户端api key") + return parse.parse_args().__dict__ + +if __name__ == "__main__": + args = get_args() + base_url = args.pop("base_url") + api_key = args.pop("api_key") + llm_name = args.pop("model_name") + + mx_agent = TravelAgent(base_url, api_key, llm_name) + demo = build_demo() + demo.launch(share=True) \ No newline at end of file diff --git a/mxAgent/samples/travel_agent_demo/travelagent.py b/mxAgent/samples/travel_agent_demo/travelagent.py index 06edca9dc..86081e4a5 100644 --- a/mxAgent/samples/travel_agent_demo/travelagent.py +++ b/mxAgent/samples/travel_agent_demo/travelagent.py @@ -143,15 +143,18 @@ class TalkShowAgent(ToollessAgent, ABC): class TravelAgent: - @classmethod - def route_query(cls, query): - router_agent = RouterAgent(llm=llm, intents=intents) + def __init__(self, base_url, api_key, llm_name): + self.llm = get_llm_backend(backend=BACKEND_OPENAI_COMPATIBLE, + base_url=base_url, api_key=api_key, llm_name=llm_name).run + + def route_query(self, query): + router_agent = RouterAgent(llm=self.llm, intents=intents) classify = router_agent.run(query).answer if classify not in classifer or classify == OTHERS: - return TalkShowAgent(llm=llm) + return TalkShowAgent(llm=self.llm) return RecipeAgent(name=classify, description="你的名字叫昇腾智搜,是一个帮助用户完成旅行规划的助手,你的能力范围包括:目的地推荐、行程规划、交通信息查询、酒店住宿推荐、旅行攻略推荐", - llm=llm, + llm=self.llm, tool_list=TOOL_LIST_MAP[classify], recipe=INST_MAP[classify], max_steps=3, -- Gitee From 39e76213eaccb887d6508598781b42d4d1b41319 Mon Sep 17 00:00:00 2001 From: huxinjia Date: Sat, 30 Nov 2024 15:08:10 +0800 Subject: [PATCH 02/11] optimize tools --- mxAgent/samples/tools/common.py | 10 ++-------- mxAgent/samples/tools/duck_search.py | 2 +- mxAgent/samples/tools/tool_query_accommodations.py | 13 ++++++++----- mxAgent/samples/tools/tool_query_attractions.py | 14 +++++++++----- mxAgent/samples/tools/tool_query_city.py | 7 ------- mxAgent/samples/tools/tool_query_transports.py | 12 ++++++++---- 6 files changed, 28 insertions(+), 30 deletions(-) diff --git a/mxAgent/samples/tools/common.py b/mxAgent/samples/tools/common.py index 2ed383491..5e1ca6542 100644 --- a/mxAgent/samples/tools/common.py +++ b/mxAgent/samples/tools/common.py @@ -1,8 +1,5 @@ -import json -from samples.tools.web_summary_api import WebSummary - -def get_website_summary(keys, prompt, llm): +def filter_website_keywords(keys): filtered = [] for val in keys: if val is None or len(val) == 0: @@ -16,10 +13,7 @@ def get_website_summary(keys, prompt, llm): if len(filtered) == 0: raise Exception("keywords has no been found") - - webs = WebSummary.web_summary( - filtered, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) - return json.dumps(webs, ensure_ascii=False) + return filtered def flatten(nested_list): diff --git a/mxAgent/samples/tools/duck_search.py b/mxAgent/samples/tools/duck_search.py index 7da42455e..9beba59bb 100644 --- a/mxAgent/samples/tools/duck_search.py +++ b/mxAgent/samples/tools/duck_search.py @@ -5,7 +5,7 @@ import re from langchain_community.tools import DuckDuckGoSearchResults from loguru import logger -from toolmngt.api import API +from agent_sdk.toolmngt.api import API class DuckDuckGoSearch(API): diff --git a/mxAgent/samples/tools/tool_query_accommodations.py b/mxAgent/samples/tools/tool_query_accommodations.py index 3c80febeb..be2615bf7 100644 --- a/mxAgent/samples/tools/tool_query_accommodations.py +++ b/mxAgent/samples/tools/tool_query_accommodations.py @@ -8,8 +8,8 @@ from loguru import logger from agent_sdk.toolmngt.api import API from agent_sdk.toolmngt.tool_manager import ToolManager -from samples.tools.common import get_website_summary - +from samples.tools.common import filter_website_keywords +from samples.tools.web_summary_api import WebSummary @ToolManager.register_tool() class QueryAccommodations(API): @@ -44,7 +44,7 @@ class QueryAccommodations(API): position = input_parameter.get("position") rank = input_parameter.get("rank") llm = kwargs.get("llm", None) - keys = [destination, position, rank, "住宿"] + keys = [destination, position, rank] logger.debug(f"search accommodation key words: {','.join(keys)}") prompt = """你是一个擅长文字处理和信息总结的智能助手,你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, @@ -59,8 +59,11 @@ class QueryAccommodations(API): 请生成总结: """ try: - content = get_website_summary(keys, prompt, llm) - res = {"accommodation": content} + filtered = filter_website_keywords(keys) + filtered.append("住宿") + webs = WebSummary.web_summary( + filtered, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) + res = {"accommodation": json.dumps(webs)} return self.make_response(input_parameter, results=res, exception="") except Exception as e: logger.error(e) diff --git a/mxAgent/samples/tools/tool_query_attractions.py b/mxAgent/samples/tools/tool_query_attractions.py index de83491c9..ccc193ec8 100644 --- a/mxAgent/samples/tools/tool_query_attractions.py +++ b/mxAgent/samples/tools/tool_query_attractions.py @@ -2,12 +2,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. import tiktoken -import yaml +import json from loguru import logger from agent_sdk.toolmngt.api import API from agent_sdk.toolmngt.tool_manager import ToolManager -from samples.tools.common import get_website_summary +from samples.tools.common import filter_website_keywords +from samples.tools.web_summary_api import WebSummary @ToolManager.register_tool() @@ -51,7 +52,7 @@ class QueryAttractions(API): requirement = input_parameter.get('requirement') llm = kwargs.get("llm", None) - keys = [destination, scene, scene_type, requirement, "景点"] + keys = [destination, scene, scene_type, requirement] summary_prompt = """你是一个擅长于网页信息总结的智能助手,提供的网页是关于旅游规划的信息,现在已经从网页中获取到了相关的文字内容信息,你需要从网页中找到与**景区**介绍相关的内容,并进行提取, 你务必保证提取的内容都来自所提供的文本,保证结果的客观性,真实性。 网页中可能包含多个景点的介绍,你需要以YAML文件的格式返回,每个景点的返回的参数和格式如下: @@ -68,8 +69,11 @@ class QueryAttractions(API): 请开始生成: """ try: - content = get_website_summary(keys, summary_prompt, llm) - res = {'attractions': content} + filtered = filter_website_keywords(keys) + filtered.append("景点") + webs = WebSummary.web_summary( + filtered, search_num=3, summary_num=3, summary_prompt=summary_prompt, llm=llm) + res = {'attractions': json.dumps(webs)} return self.make_response(input_parameter, results=res, exception="") except Exception as e: logger.error(e) diff --git a/mxAgent/samples/tools/tool_query_city.py b/mxAgent/samples/tools/tool_query_city.py index bc20932dd..e7f325216 100644 --- a/mxAgent/samples/tools/tool_query_city.py +++ b/mxAgent/samples/tools/tool_query_city.py @@ -25,13 +25,6 @@ class CitySearch(API): "city": {'type': 'str', 'description': "the name of the city in the state"} } - usage = f"""{name}[state]: - Description: This api can be used to retrieve cities in your target state. - Parameter: - state: The name of the state where you're finding cities. - Example: {name}[state: New York] would return cities in New York. - """ - example = ( """ { diff --git a/mxAgent/samples/tools/tool_query_transports.py b/mxAgent/samples/tools/tool_query_transports.py index 74bb5cc12..bd693028f 100644 --- a/mxAgent/samples/tools/tool_query_transports.py +++ b/mxAgent/samples/tools/tool_query_transports.py @@ -8,7 +8,8 @@ import tiktoken from agent_sdk.toolmngt.api import API from agent_sdk.toolmngt.tool_manager import ToolManager -from samples.tools.common import get_website_summary +from samples.tools.common import filter_website_keywords +from samples.tools.web_summary_api import WebSummary @ToolManager.register_tool() @@ -50,7 +51,7 @@ class QueryTransports(API): try: prefix = f"从{origin}出发" if origin else "" prefix += f"前往{destination}" if destination else "" - keys = [prefix, req, travel_mode, "购票"] + keys = [prefix, req, travel_mode] prompt = """你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, 请添加适当的词语,使得语句内容连贯,通顺。输入是为用户查询的航班、高铁等交通数据,请将这些信息总结 @@ -60,8 +61,11 @@ class QueryTransports(API): {input} 请生成总结: """ - content = get_website_summary(keys, prompt, llm) - res = {'transport': content} + filtered = filter_website_keywords(keys) + filtered.append("购票") + webs = WebSummary.web_summary( + filtered, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) + res = {'transport': json.dumps(webs)} return self.make_response(input_parameter, results=res, exception="") except Exception as e: logger.error(e) -- Gitee From 423ad0f1f63c4dfda91e517a3e796a3f6b218a5a Mon Sep 17 00:00:00 2001 From: huxinjia Date: Sat, 30 Nov 2024 15:15:50 +0800 Subject: [PATCH 03/11] fix samples --- mxAgent/agent_sdk/agentchain/recipe_agent.py | 2 +- mxAgent/samples/tools/tool_query_attractions.py | 2 +- mxAgent/samples/travel_agent_demo/travelagent.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mxAgent/agent_sdk/agentchain/recipe_agent.py b/mxAgent/agent_sdk/agentchain/recipe_agent.py index ec20af472..c12a1d6f0 100644 --- a/mxAgent/agent_sdk/agentchain/recipe_agent.py +++ b/mxAgent/agent_sdk/agentchain/recipe_agent.py @@ -200,7 +200,7 @@ SUGGESTION = """1. 翻译结果请严格按照YAML格式输出,不要添加任 5. 每个节点的dependency字段必须准确,能匹配伪代码中的依赖关系逻辑,dependency的节点必须是存在的节点 6. 每个节点的input字段必须有参数,每个节点的input字段名务必准确,必须是工具有的参数名,input字段中的每个参数输入值必须且只能是具体值或者依赖节点的工具输出参数, 不要使用python代码或者其他表达式, -7. 每个节点的input字段的每个参数值,优先使用依赖节点的工具输出参数,若无法通过依赖得到可以问题中提取,若存在多个答案,请使用加号+隔开, +7. 每个节点的input字段的每个参数值,优先使用依赖节点的工具输出参数,若无法通过依赖得到可以从问题中提取,若存在多个答案,请使用加号+隔开,若仍无法得到参数,统一使用【无】 8. 【伪代码】的步骤:一个步骤只能翻译成一个对应的节点 9. 生成的内容请严格遵循YAML的语法和格式 """ diff --git a/mxAgent/samples/tools/tool_query_attractions.py b/mxAgent/samples/tools/tool_query_attractions.py index ccc193ec8..eb54cb042 100644 --- a/mxAgent/samples/tools/tool_query_attractions.py +++ b/mxAgent/samples/tools/tool_query_attractions.py @@ -39,7 +39,7 @@ class QueryAttractions(API): "destination": "Paris", "scene": "The Louvre Museum", "type": "Museum", - "requirement": "historical" + "requirement": "free" }""") def __init__(self): diff --git a/mxAgent/samples/travel_agent_demo/travelagent.py b/mxAgent/samples/travel_agent_demo/travelagent.py index 86081e4a5..e57a27129 100644 --- a/mxAgent/samples/travel_agent_demo/travelagent.py +++ b/mxAgent/samples/travel_agent_demo/travelagent.py @@ -183,7 +183,7 @@ if __name__ == "__main__": base_url=base_url, api_key=api_key, llm_name=llm_name).run query = "帮我制定一份从北京到上海6天的旅游计划" - travel_agent = TravelAgent() + travel_agent = TravelAgent(base_url, api_key, llm_name) res = travel_agent.run(query, stream=False) if isinstance(res, AgentRunResult): logger.info("-----------run agent success-------------") -- Gitee From 922e3aa603101407e1c021ccd1c2573b52e3d6e5 Mon Sep 17 00:00:00 2001 From: huxinjia Date: Mon, 2 Dec 2024 11:06:17 +0800 Subject: [PATCH 04/11] add openmind --- mxAgent/app.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 mxAgent/app.py diff --git a/mxAgent/app.py b/mxAgent/app.py new file mode 100644 index 000000000..170296a09 --- /dev/null +++ b/mxAgent/app.py @@ -0,0 +1,113 @@ +import os +from threading import Thread +from typing import Iterator, List, Tuple + +import gradio as gr +import torch +from openmind import AutoModelForCausalLM, AutoTokenizer +from transformers import TextIteratorStreamer + +MAX_MAX_NEW_TOKENS = 2048 +DEFAULT_MAX_NEW_TOKENS = 1024 +MAX_INPUT_TOKEN_LENGTH = 4096 +MAX_HISTORY_LENGTH = 5 # 限制历史对话的长度 + +DESCRIPTION = """\ +# Qwen 中文对话模型 +这个space使用模型 [OpenSource/Qwen2-0.5B-Instruct](https://modelers.cn/models/OpenSource/Qwen2-0.5B-Instruct)运行在cpu上的一个demo +""" + +device = torch.device('npu') +model_name = 'KunLun/Qwen1.5-32B-chat' +model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device) +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + +def generate( + user_message: str, + chat_history: List[Tuple[str, str]], + system_prompt: str, + max_new_tokens: int = 1024, + temperature: float = 0.6, + top_p: float = 0.9, + top_k: int = 50, + repetition_penalty: float = 1.2, +) -> Iterator[str]: + # 限制历史记录的长度 + print(user_message) + if len(chat_history) > MAX_HISTORY_LENGTH: + chat_history = chat_history[-MAX_HISTORY_LENGTH:] + + conversation = [] + if system_prompt: + conversation.append({"role": "system", "content": system_prompt}) + conversation.append(user_message) + + + input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") + if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: + input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] + gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") + input_ids = input_ids.to(model.device) + + streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + generate_kwargs = dict( + {"input_ids": input_ids}, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_beams=1, + repetition_penalty=repetition_penalty, + ) + + t = Thread(target=model.generate, kwargs=generate_kwargs) + t.start() + + outputs = [] + for text in streamer: + outputs.append(text) + yield "".join(outputs) + + +chat_interface = gr.ChatInterface( + fn=generate, + stop_btn=None, + examples=[ + ["作为程序员该如何防止脱发?"], + ["请给我简短的介绍一下python语言的历史?"], + ["写一篇关于'如何参与到开源贡献社区'的博客"], + ["给我一份自驾游西藏的攻略"], + ["写一篇赞美老师的诗词"], + ], + cache_examples=False, +) + +css = """ +h1 { + text-align: center; + display: block; +} + +#duplicate-button { + margin: auto; + color: white; + background: #1565c0; + border-radius: 100vh; +} + +.contain { + max-width: 900px; + margin: auto; + padding-top: 1.5rem; +} +""" + +with gr.Blocks(css=css, fill_height=True) as demo: + gr.Markdown(DESCRIPTION) + chat_interface.render() + +if __name__ == "__main__": + demo.queue(max_size=20).launch() \ No newline at end of file -- Gitee From 436017b9579cf9ebe7f51305fb25e5fc44727951 Mon Sep 17 00:00:00 2001 From: huxinjia Date: Tue, 3 Dec 2024 16:18:27 +0800 Subject: [PATCH 05/11] support the local model --- mxAgent/app.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mxAgent/app.py b/mxAgent/app.py index 170296a09..0722c8b28 100644 --- a/mxAgent/app.py +++ b/mxAgent/app.py @@ -26,12 +26,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) def generate( user_message: str, chat_history: List[Tuple[str, str]], - system_prompt: str, + system_prompt: str = "", max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2, + stream: bool = False ) -> Iterator[str]: # 限制历史记录的长度 print(user_message) @@ -41,7 +42,7 @@ def generate( conversation = [] if system_prompt: conversation.append({"role": "system", "content": system_prompt}) - conversation.append(user_message) + conversation.append({"role": "user", "content": user_message}) input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") @@ -49,11 +50,8 @@ def generate( input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) - - streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( {"input_ids": input_ids}, - streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, @@ -62,7 +60,18 @@ def generate( num_beams=1, repetition_penalty=repetition_penalty, ) + if stream: + return stream_generate(generate_kwargs) + output_ids = model.generate(generate_kwargs) + output_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, output_ids) + ] + return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + +def stream_generate(generate_kwargs): + streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + generate_kwargs['streamer'] = streamer t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() @@ -75,6 +84,7 @@ def generate( chat_interface = gr.ChatInterface( fn=generate, stop_btn=None, + type="messages" examples=[ ["作为程序员该如何防止脱发?"], ["请给我简短的介绍一下python语言的历史?"], -- Gitee From 546318edd9fce5e122bce6399f8cd086d0adb683 Mon Sep 17 00:00:00 2001 From: huxinjia Date: Tue, 10 Dec 2024 10:23:53 +0800 Subject: [PATCH 06/11] fix grammar --- mxAgent/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxAgent/app.py b/mxAgent/app.py index 0722c8b28..3a52554db 100644 --- a/mxAgent/app.py +++ b/mxAgent/app.py @@ -84,7 +84,7 @@ def stream_generate(generate_kwargs): chat_interface = gr.ChatInterface( fn=generate, stop_btn=None, - type="messages" + type="messages", examples=[ ["作为程序员该如何防止脱发?"], ["请给我简短的介绍一下python语言的历史?"], -- Gitee From 2832000480fbeca60c5d53340e42b8123a3e3318 Mon Sep 17 00:00:00 2001 From: hu-xinjia Date: Fri, 13 Dec 2024 17:07:02 +0800 Subject: [PATCH 07/11] fix varible name --- mxAgent/agent_sdk/prompts/pre_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxAgent/agent_sdk/prompts/pre_prompt.py b/mxAgent/agent_sdk/prompts/pre_prompt.py index 140ad9ac5..fb40e51fe 100644 --- a/mxAgent/agent_sdk/prompts/pre_prompt.py +++ b/mxAgent/agent_sdk/prompts/pre_prompt.py @@ -178,7 +178,7 @@ single_action_agent_prompt = PromptTemplate( template=SINGLE_AGENT_ACTION_INSTRUCTION, ) -final_prompt = PromptTemplate( +single_action_final_prompt = PromptTemplate( input_variables=["query", "answer"], template=FINAL_PROMPT, ) -- Gitee From 859fe0e1881ee531ea92b734cc522e6be99135a7 Mon Sep 17 00:00:00 2001 From: hu-xinjia Date: Mon, 16 Dec 2024 16:42:12 +0800 Subject: [PATCH 08/11] delete gradio app --- mxAgent/app.py | 123 ------------------------------------------------- 1 file changed, 123 deletions(-) delete mode 100644 mxAgent/app.py diff --git a/mxAgent/app.py b/mxAgent/app.py deleted file mode 100644 index 3a52554db..000000000 --- a/mxAgent/app.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -from threading import Thread -from typing import Iterator, List, Tuple - -import gradio as gr -import torch -from openmind import AutoModelForCausalLM, AutoTokenizer -from transformers import TextIteratorStreamer - -MAX_MAX_NEW_TOKENS = 2048 -DEFAULT_MAX_NEW_TOKENS = 1024 -MAX_INPUT_TOKEN_LENGTH = 4096 -MAX_HISTORY_LENGTH = 5 # 限制历史对话的长度 - -DESCRIPTION = """\ -# Qwen 中文对话模型 -这个space使用模型 [OpenSource/Qwen2-0.5B-Instruct](https://modelers.cn/models/OpenSource/Qwen2-0.5B-Instruct)运行在cpu上的一个demo -""" - -device = torch.device('npu') -model_name = 'KunLun/Qwen1.5-32B-chat' -model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device) -tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - - -def generate( - user_message: str, - chat_history: List[Tuple[str, str]], - system_prompt: str = "", - max_new_tokens: int = 1024, - temperature: float = 0.6, - top_p: float = 0.9, - top_k: int = 50, - repetition_penalty: float = 1.2, - stream: bool = False -) -> Iterator[str]: - # 限制历史记录的长度 - print(user_message) - if len(chat_history) > MAX_HISTORY_LENGTH: - chat_history = chat_history[-MAX_HISTORY_LENGTH:] - - conversation = [] - if system_prompt: - conversation.append({"role": "system", "content": system_prompt}) - conversation.append({"role": "user", "content": user_message}) - - - input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") - if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: - input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] - gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") - input_ids = input_ids.to(model.device) - generate_kwargs = dict( - {"input_ids": input_ids}, - max_new_tokens=max_new_tokens, - do_sample=True, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_beams=1, - repetition_penalty=repetition_penalty, - ) - if stream: - return stream_generate(generate_kwargs) - - output_ids = model.generate(generate_kwargs) - output_ids = [ - output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, output_ids) - ] - return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] - -def stream_generate(generate_kwargs): - streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) - generate_kwargs['streamer'] = streamer - t = Thread(target=model.generate, kwargs=generate_kwargs) - t.start() - - outputs = [] - for text in streamer: - outputs.append(text) - yield "".join(outputs) - - -chat_interface = gr.ChatInterface( - fn=generate, - stop_btn=None, - type="messages", - examples=[ - ["作为程序员该如何防止脱发?"], - ["请给我简短的介绍一下python语言的历史?"], - ["写一篇关于'如何参与到开源贡献社区'的博客"], - ["给我一份自驾游西藏的攻略"], - ["写一篇赞美老师的诗词"], - ], - cache_examples=False, -) - -css = """ -h1 { - text-align: center; - display: block; -} - -#duplicate-button { - margin: auto; - color: white; - background: #1565c0; - border-radius: 100vh; -} - -.contain { - max-width: 900px; - margin: auto; - padding-top: 1.5rem; -} -""" - -with gr.Blocks(css=css, fill_height=True) as demo: - gr.Markdown(DESCRIPTION) - chat_interface.render() - -if __name__ == "__main__": - demo.queue(max_size=20).launch() \ No newline at end of file -- Gitee From 051a0bbb13ee15f1d4bb681d02331ae39582b614 Mon Sep 17 00:00:00 2001 From: hu-xinjia Date: Mon, 16 Dec 2024 17:38:18 +0800 Subject: [PATCH 09/11] delete front end file --- .../samples/travel_agent_demo/front/app.py | 105 ------------------ 1 file changed, 105 deletions(-) delete mode 100644 mxAgent/samples/travel_agent_demo/front/app.py diff --git a/mxAgent/samples/travel_agent_demo/front/app.py b/mxAgent/samples/travel_agent_demo/front/app.py deleted file mode 100644 index 96c8e8f9b..000000000 --- a/mxAgent/samples/travel_agent_demo/front/app.py +++ /dev/null @@ -1,105 +0,0 @@ - -import argparse -import gradio as gr -from mxAgent. samples.travel_agent_demo.travelagent import TravelAgent - - -global mx_agent - -def agent_run(messgae): - query = messgae.get("content") - TravelAgent - - -def user_query(user_message, history): - return "", history + [{"role": "user", "content":user_message}] - - -def clear_history(history): - return [] - -def bot_response(history): - # 将最新的问题传给RAG - try: - response = agent_run(history[-1]) - - # 返回迭代器 - history += [{"role": "assistant", "content":""}] - - history[-1]["content"] = '推理错误' - for res in response: - history[-1]["content"] = '推理错误' if res['result'] is None else res['result'] - yield history - yield history - print(history) - except Exception as err: - history[-1]["content"] = "推理错误" - yield history - -def build_demo(): - with gr.Blocks() as demo: - gr.HTML("""

旅行规划Agent

-

例如:从北京到西安旅游规划

-

例如:西安有哪些免费的博物馆景点

-

例如:查一下西安的酒店

- - """) - with gr.Row(): - with gr.Column(scale=200): - - initial_msg = [ {"role": "assistant", - "content": "这条消息下想说明的是:如果 Chatbot 的 type 参数为 'messages',那么发送到/从 Chatbot " - "的数据将是一个包含 role 和 content 键的字典列表。这种格式符合大多数 LLM API(如 " - "HuggingChat、OpenAI、Claude)期望的格式。role 键可以是 'user' 或 " - "'assistant',content 键可以是一个字符串(支持 markdown/html 格式),一个 " - "FileDataDict(用于表示在 Chatbot 中显示的文件),或者一个 gradio 组件。"} - ], - - # chatbot = gr.Chatbot(initial_msg, type="messages",) - chatbot = gr.Chatbot( - [ - - {"role": "assistant", "content": "你好,我是你的AI小助手,这是你自己预制的一个信息。"}, - {"role": "assistant", "content": "这条消息下想说明的是:如果 Chatbot 的 type 参数为 'messages',那么发送到/从 Chatbot " - "的数据将是一个包含 role 和 content 键的字典列表。这种格式符合大多数 LLM API(如 " - "HuggingChat、OpenAI、Claude)期望的格式。role 键可以是 'user' 或 " - "'assistant',content 键可以是一个字符串(支持 markdown/html 格式),一个 " - "FileDataDict(用于表示在 Chatbot 中显示的文件),或者一个 gradio 组件。"} - ], - type="messages", - show_label=False, - height=500, - show_copy_button=True - - - ) - - with gr.Row(): - msg = gr.Textbox(placeholder="在此输入问题...", container=False) - with gr.Row(): - send_btn = gr.Button(value="发送", variant="primary") - clean_btn = gr.Button(value="清空历史") - send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False).then(bot_response, - [chatbot], chatbot) - clean_btn.click(clear_history, chatbot, chatbot) - return demo - - - - -def get_args(): - parse = argparse.ArgumentParser() - parse.add_argument("--model_name", type=str, default="Qwen1.5-32B-Chat", help="OpenAI客户端模型名") - parse.add_argument("--base_url", type=str, default="http://10.44.115.108:1055/v1", help="OpenAI客户端模型地址") - parse.add_argument("--api_key", type=str, default="EMPTY", help="OpenAI客户端api key") - return parse.parse_args().__dict__ - -if __name__ == "__main__": - args = get_args() - base_url = args.pop("base_url") - api_key = args.pop("api_key") - llm_name = args.pop("model_name") - - mx_agent = TravelAgent(base_url, api_key, llm_name) - demo = build_demo() - demo.launch(share=True) \ No newline at end of file -- Gitee From 8373c3338508402250489a4989b78fbcb90a3b87 Mon Sep 17 00:00:00 2001 From: hu-xinjia Date: Mon, 16 Dec 2024 17:41:44 +0800 Subject: [PATCH 10/11] fix intent router --- mxAgent/samples/basic_demo/intent_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxAgent/samples/basic_demo/intent_router.py b/mxAgent/samples/basic_demo/intent_router.py index e59df3d28..c9aaf3edc 100644 --- a/mxAgent/samples/basic_demo/intent_router.py +++ b/mxAgent/samples/basic_demo/intent_router.py @@ -40,7 +40,7 @@ if __name__ == "__main__": API_KEY = args.pop("api_key") LLM_NAME = args.pop("model_name") llm = get_llm_backend(backend=BACKEND_OPENAI_COMPATIBLE, - api_base=API_BASE, api_key=API_KEY, llm_name=LLM_NAME).run + base_url=API_BASE, api_key=API_KEY, llm_name=LLM_NAME).run agent = RouterAgent(llm=llm, intents=INTENT) for query in querys: response = agent.run(query) -- Gitee From 726e5c0ed1dd032ad285b9a3be2b713024c59d43 Mon Sep 17 00:00:00 2001 From: hu-xinjia Date: Mon, 16 Dec 2024 19:42:02 +0800 Subject: [PATCH 11/11] fix clean code --- mxAgent/samples/tools/common.py | 2 +- mxAgent/samples/tools/tool_query_accommodations.py | 1 + mxAgent/samples/tools/tool_query_attractions.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mxAgent/samples/tools/common.py b/mxAgent/samples/tools/common.py index 5e1ca6542..98fac08fd 100644 --- a/mxAgent/samples/tools/common.py +++ b/mxAgent/samples/tools/common.py @@ -13,7 +13,7 @@ def filter_website_keywords(keys): if len(filtered) == 0: raise Exception("keywords has no been found") - return filtered + return filtered def flatten(nested_list): diff --git a/mxAgent/samples/tools/tool_query_accommodations.py b/mxAgent/samples/tools/tool_query_accommodations.py index be2615bf7..c4a5d8635 100644 --- a/mxAgent/samples/tools/tool_query_accommodations.py +++ b/mxAgent/samples/tools/tool_query_accommodations.py @@ -11,6 +11,7 @@ from agent_sdk.toolmngt.tool_manager import ToolManager from samples.tools.common import filter_website_keywords from samples.tools.web_summary_api import WebSummary + @ToolManager.register_tool() class QueryAccommodations(API): name = "QueryAccommodations" diff --git a/mxAgent/samples/tools/tool_query_attractions.py b/mxAgent/samples/tools/tool_query_attractions.py index eb54cb042..b9c53b81c 100644 --- a/mxAgent/samples/tools/tool_query_attractions.py +++ b/mxAgent/samples/tools/tool_query_attractions.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. -import tiktoken import json +import tiktoken from loguru import logger from agent_sdk.toolmngt.api import API -- Gitee