diff --git a/mxAgent/agent_sdk/agentchain/recipe_agent.py b/mxAgent/agent_sdk/agentchain/recipe_agent.py index ec20af47211a19589cc0a14981392fa4889615fc..c12a1d6f0c39ad7000762ef0df8b55943653a62d 100644 --- a/mxAgent/agent_sdk/agentchain/recipe_agent.py +++ b/mxAgent/agent_sdk/agentchain/recipe_agent.py @@ -200,7 +200,7 @@ SUGGESTION = """1. 翻译结果请严格按照YAML格式输出,不要添加任 5. 每个节点的dependency字段必须准确,能匹配伪代码中的依赖关系逻辑,dependency的节点必须是存在的节点 6. 每个节点的input字段必须有参数,每个节点的input字段名务必准确,必须是工具有的参数名,input字段中的每个参数输入值必须且只能是具体值或者依赖节点的工具输出参数, 不要使用python代码或者其他表达式, -7. 每个节点的input字段的每个参数值,优先使用依赖节点的工具输出参数,若无法通过依赖得到可以问题中提取,若存在多个答案,请使用加号+隔开, +7. 每个节点的input字段的每个参数值,优先使用依赖节点的工具输出参数,若无法通过依赖得到可以从问题中提取,若存在多个答案,请使用加号+隔开,若仍无法得到参数,统一使用【无】 8. 【伪代码】的步骤:一个步骤只能翻译成一个对应的节点 9. 生成的内容请严格遵循YAML的语法和格式 """ diff --git a/mxAgent/agent_sdk/prompts/pre_prompt.py b/mxAgent/agent_sdk/prompts/pre_prompt.py index 140ad9ac50870027b29bfe4228627460d901cdcb..fb40e51fee5d6b618aa491ca7bf14e84b13b03cb 100644 --- a/mxAgent/agent_sdk/prompts/pre_prompt.py +++ b/mxAgent/agent_sdk/prompts/pre_prompt.py @@ -178,7 +178,7 @@ single_action_agent_prompt = PromptTemplate( template=SINGLE_AGENT_ACTION_INSTRUCTION, ) -final_prompt = PromptTemplate( +single_action_final_prompt = PromptTemplate( input_variables=["query", "answer"], template=FINAL_PROMPT, ) diff --git a/mxAgent/samples/basic_demo/intent_router.py b/mxAgent/samples/basic_demo/intent_router.py index e59df3d28cdfcfc1ba5c00b3951ff6d8162d57df..c9aaf3edc740081571b5f3b63aa293f39b000e18 100644 --- a/mxAgent/samples/basic_demo/intent_router.py +++ b/mxAgent/samples/basic_demo/intent_router.py @@ -40,7 +40,7 @@ if __name__ == "__main__": API_KEY = args.pop("api_key") LLM_NAME = args.pop("model_name") llm = get_llm_backend(backend=BACKEND_OPENAI_COMPATIBLE, - api_base=API_BASE, api_key=API_KEY, llm_name=LLM_NAME).run + base_url=API_BASE, api_key=API_KEY, llm_name=LLM_NAME).run agent = RouterAgent(llm=llm, intents=INTENT) for query in querys: response = agent.run(query) diff --git a/mxAgent/samples/tools/common.py b/mxAgent/samples/tools/common.py index 2ed383491743ada62fc19ba12e79c9a91040067f..98fac08fdd68ea823769197870231d2d01aecb5e 100644 --- a/mxAgent/samples/tools/common.py +++ b/mxAgent/samples/tools/common.py @@ -1,8 +1,5 @@ -import json -from samples.tools.web_summary_api import WebSummary - -def get_website_summary(keys, prompt, llm): +def filter_website_keywords(keys): filtered = [] for val in keys: if val is None or len(val) == 0: @@ -16,10 +13,7 @@ def get_website_summary(keys, prompt, llm): if len(filtered) == 0: raise Exception("keywords has no been found") - - webs = WebSummary.web_summary( - filtered, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) - return json.dumps(webs, ensure_ascii=False) + return filtered def flatten(nested_list): diff --git a/mxAgent/samples/tools/duck_search.py b/mxAgent/samples/tools/duck_search.py index 7da42455e442a58bb7707603d33fcf37bf1cc251..9beba59bb656ea9835a10632cfd398ac2e40a6aa 100644 --- a/mxAgent/samples/tools/duck_search.py +++ b/mxAgent/samples/tools/duck_search.py @@ -5,7 +5,7 @@ import re from langchain_community.tools import DuckDuckGoSearchResults from loguru import logger -from toolmngt.api import API +from agent_sdk.toolmngt.api import API class DuckDuckGoSearch(API): diff --git a/mxAgent/samples/tools/tool_query_accommodations.py b/mxAgent/samples/tools/tool_query_accommodations.py index 3c80febebfb9690862ceccb5921bbb30b17f29c2..c4a5d86355c486da5df8d438162aec7b46ec5742 100644 --- a/mxAgent/samples/tools/tool_query_accommodations.py +++ b/mxAgent/samples/tools/tool_query_accommodations.py @@ -8,7 +8,8 @@ from loguru import logger from agent_sdk.toolmngt.api import API from agent_sdk.toolmngt.tool_manager import ToolManager -from samples.tools.common import get_website_summary +from samples.tools.common import filter_website_keywords +from samples.tools.web_summary_api import WebSummary @ToolManager.register_tool() @@ -44,7 +45,7 @@ class QueryAccommodations(API): position = input_parameter.get("position") rank = input_parameter.get("rank") llm = kwargs.get("llm", None) - keys = [destination, position, rank, "住宿"] + keys = [destination, position, rank] logger.debug(f"search accommodation key words: {','.join(keys)}") prompt = """你是一个擅长文字处理和信息总结的智能助手,你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, @@ -59,8 +60,11 @@ class QueryAccommodations(API): 请生成总结: """ try: - content = get_website_summary(keys, prompt, llm) - res = {"accommodation": content} + filtered = filter_website_keywords(keys) + filtered.append("住宿") + webs = WebSummary.web_summary( + filtered, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) + res = {"accommodation": json.dumps(webs)} return self.make_response(input_parameter, results=res, exception="") except Exception as e: logger.error(e) diff --git a/mxAgent/samples/tools/tool_query_attractions.py b/mxAgent/samples/tools/tool_query_attractions.py index de83491c954ec4a38082655c5b73b5a45140cf6b..b9c53b81c1ce9c3e56094fe899fa45a432aa9486 100644 --- a/mxAgent/samples/tools/tool_query_attractions.py +++ b/mxAgent/samples/tools/tool_query_attractions.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +import json import tiktoken -import yaml from loguru import logger from agent_sdk.toolmngt.api import API from agent_sdk.toolmngt.tool_manager import ToolManager -from samples.tools.common import get_website_summary +from samples.tools.common import filter_website_keywords +from samples.tools.web_summary_api import WebSummary @ToolManager.register_tool() @@ -38,7 +39,7 @@ class QueryAttractions(API): "destination": "Paris", "scene": "The Louvre Museum", "type": "Museum", - "requirement": "historical" + "requirement": "free" }""") def __init__(self): @@ -51,7 +52,7 @@ class QueryAttractions(API): requirement = input_parameter.get('requirement') llm = kwargs.get("llm", None) - keys = [destination, scene, scene_type, requirement, "景点"] + keys = [destination, scene, scene_type, requirement] summary_prompt = """你是一个擅长于网页信息总结的智能助手,提供的网页是关于旅游规划的信息,现在已经从网页中获取到了相关的文字内容信息,你需要从网页中找到与**景区**介绍相关的内容,并进行提取, 你务必保证提取的内容都来自所提供的文本,保证结果的客观性,真实性。 网页中可能包含多个景点的介绍,你需要以YAML文件的格式返回,每个景点的返回的参数和格式如下: @@ -68,8 +69,11 @@ class QueryAttractions(API): 请开始生成: """ try: - content = get_website_summary(keys, summary_prompt, llm) - res = {'attractions': content} + filtered = filter_website_keywords(keys) + filtered.append("景点") + webs = WebSummary.web_summary( + filtered, search_num=3, summary_num=3, summary_prompt=summary_prompt, llm=llm) + res = {'attractions': json.dumps(webs)} return self.make_response(input_parameter, results=res, exception="") except Exception as e: logger.error(e) diff --git a/mxAgent/samples/tools/tool_query_city.py b/mxAgent/samples/tools/tool_query_city.py index bc20932ddf00814bd70f9c8a05e9abaea39266bf..e7f32521603d58b62fc03e1827e34b8edc2c9e48 100644 --- a/mxAgent/samples/tools/tool_query_city.py +++ b/mxAgent/samples/tools/tool_query_city.py @@ -25,13 +25,6 @@ class CitySearch(API): "city": {'type': 'str', 'description': "the name of the city in the state"} } - usage = f"""{name}[state]: - Description: This api can be used to retrieve cities in your target state. - Parameter: - state: The name of the state where you're finding cities. - Example: {name}[state: New York] would return cities in New York. - """ - example = ( """ { diff --git a/mxAgent/samples/tools/tool_query_transports.py b/mxAgent/samples/tools/tool_query_transports.py index 74bb5cc12b3716c5f6552448355e7a00c97897f9..bd693028f47bbd8f3adfbc87067f486188af9f30 100644 --- a/mxAgent/samples/tools/tool_query_transports.py +++ b/mxAgent/samples/tools/tool_query_transports.py @@ -8,7 +8,8 @@ import tiktoken from agent_sdk.toolmngt.api import API from agent_sdk.toolmngt.tool_manager import ToolManager -from samples.tools.common import get_website_summary +from samples.tools.common import filter_website_keywords +from samples.tools.web_summary_api import WebSummary @ToolManager.register_tool() @@ -50,7 +51,7 @@ class QueryTransports(API): try: prefix = f"从{origin}出发" if origin else "" prefix += f"前往{destination}" if destination else "" - keys = [prefix, req, travel_mode, "购票"] + keys = [prefix, req, travel_mode] prompt = """你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, 请添加适当的词语,使得语句内容连贯,通顺。输入是为用户查询的航班、高铁等交通数据,请将这些信息总结 @@ -60,8 +61,11 @@ class QueryTransports(API): {input} 请生成总结: """ - content = get_website_summary(keys, prompt, llm) - res = {'transport': content} + filtered = filter_website_keywords(keys) + filtered.append("购票") + webs = WebSummary.web_summary( + filtered, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) + res = {'transport': json.dumps(webs)} return self.make_response(input_parameter, results=res, exception="") except Exception as e: logger.error(e) diff --git a/mxAgent/samples/travel_agent_demo/travelagent.py b/mxAgent/samples/travel_agent_demo/travelagent.py index 06edca9dca94f86d40ac4a63817263da3747fda2..e57a271292aa6b13abffe0a9670107727636644e 100644 --- a/mxAgent/samples/travel_agent_demo/travelagent.py +++ b/mxAgent/samples/travel_agent_demo/travelagent.py @@ -143,15 +143,18 @@ class TalkShowAgent(ToollessAgent, ABC): class TravelAgent: - @classmethod - def route_query(cls, query): - router_agent = RouterAgent(llm=llm, intents=intents) + def __init__(self, base_url, api_key, llm_name): + self.llm = get_llm_backend(backend=BACKEND_OPENAI_COMPATIBLE, + base_url=base_url, api_key=api_key, llm_name=llm_name).run + + def route_query(self, query): + router_agent = RouterAgent(llm=self.llm, intents=intents) classify = router_agent.run(query).answer if classify not in classifer or classify == OTHERS: - return TalkShowAgent(llm=llm) + return TalkShowAgent(llm=self.llm) return RecipeAgent(name=classify, description="你的名字叫昇腾智搜,是一个帮助用户完成旅行规划的助手,你的能力范围包括:目的地推荐、行程规划、交通信息查询、酒店住宿推荐、旅行攻略推荐", - llm=llm, + llm=self.llm, tool_list=TOOL_LIST_MAP[classify], recipe=INST_MAP[classify], max_steps=3, @@ -180,7 +183,7 @@ if __name__ == "__main__": base_url=base_url, api_key=api_key, llm_name=llm_name).run query = "帮我制定一份从北京到上海6天的旅游计划" - travel_agent = TravelAgent() + travel_agent = TravelAgent(base_url, api_key, llm_name) res = travel_agent.run(query, stream=False) if isinstance(res, AgentRunResult): logger.info("-----------run agent success-------------")