diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/case.py b/jiuwen/agent_builder/prompt_builder/tune/base/case.py new file mode 100644 index 0000000000000000000000000000000000000000..d43d516e486334d02df4b165ff97f9cdb970c37b --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/base/case.py @@ -0,0 +1,137 @@ +#!/usr/bin/python3.10 +# coding: utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved + +from typing import List, Dict, Optional, Any, Callable + +from pydantic import BaseModel, field_validator, Field, FieldValidationInfo + +from jiuwen.agent_builder.prompt_builder.tune.common.exception import ParamCheckFailedException +from jiuwen.agent_builder.prompt_builder.tune.base.exception import CaseValidationException +from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant + + +class Case(BaseModel): + """Definition of prompt optimization user case""" + messages: List[Dict] = Field(default=[]) + tools: Optional[List[Dict]] = Field(default=None) + + @field_validator("messages") + @classmethod + def check_message_list_content(cls, value: List[Dict], info: FieldValidationInfo) -> List[Dict]: + """check message list content is valid""" + if not isinstance(value, list): + raise ParamCheckFailedException(f"input value field name {info.field_name} check failed! " + f"value type not correct, expected list") + + if not value: + raise ParamCheckFailedException(f"input value field name {info.field_name} check failed! " + f"value is empty") + if value[-1].get(TuneConstant.MESSAGE_ROLE_KEY) != TuneConstant.ASSISTANT_ROLE: + raise ParamCheckFailedException(f"the last message role should be {TuneConstant.ASSISTANT_ROLE}") + return value + + +class CaseInfo(BaseModel): + """Definition of case information""" + role: str + name: str + content: str + tools: Any + tool_calls: Any + variable: Any + + +class CaseManager: + """Definition of case manager""" + VALID_ROLES_SET = { + TuneConstant.USER_ROLE, + TuneConstant.SYSTEM_ROLE, + TuneConstant.ASSISTANT_ROLE, + TuneConstant.TOOL_ROLE + } + + @staticmethod + def validate_with_convert(data: List[Dict[str, Any]], + convertor: Optional[Callable] = None, + default_tools: Optional[List[Dict]] = None) -> List: + """validate and convert cases to optimizer acceptable input""" + optimizer_input = [] + if not data: + raise CaseValidationException(f"input data is empty", TuneConstant.ROOT_CASE_INDEX) + + if len(data) > TuneConstant.DEFAULT_MAX_CASE_NUM: + raise CaseValidationException( + f"The number of input cases should be less than {TuneConstant.DEFAULT_MAX_CASE_NUM}", + TuneConstant.DEFAULT_MAX_CASE_NUM + ) + + if not isinstance(data, list) or not isinstance(data[0], dict): + raise CaseValidationException("Input type is not json-list", TuneConstant.ROOT_CASE_INDEX) + + for case_idx, case in enumerate(data): + messages = CaseManager._get_case_value_with_check(case, TuneConstant.MESSAGE_KEY, case_idx) + if not isinstance(messages, list): + raise CaseValidationException("Input type is not json-list", TuneConstant.ROOT_CASE_INDEX) + tools = case.get(TuneConstant.TOOL_ROLE, None) or default_tools + role = "" + for msg_idx, msg in enumerate(messages): + if not isinstance(msg, dict): + raise CaseValidationException("Message in 'messages' is not json-dict", + TuneConstant.ROOT_CASE_INDEX) + role = CaseManager._get_case_value_with_check(msg, TuneConstant.MESSAGE_ROLE_KEY, case_idx) + if role not in CaseManager.VALID_ROLES_SET: + raise CaseValidationException(f"Invalid role type '{role}' in case-{case_idx}", 
case_idx) + content = CaseManager._get_case_value_with_check(msg, TuneConstant.MESSAGE_CONTENT_KEY, msg_idx) + tool_calls = msg.get(TuneConstant.MESSAGE_TOOL_CALLS_KEY, []) + if not tool_calls and not content.strip(): + raise CaseValidationException(f"Empty message content in case-{case_idx}", case_idx) + + tool_name = msg.get(TuneConstant.NAME_KEY, "") + variable = msg.get(TuneConstant.VARIABLE_KEY, dict()) + if convertor: + convertor(optimizer_input, case_idx, msg_idx == (len(messages) - 1), + CaseInfo(role=role, tools=tools, content=content, + name=tool_name, tool_calls=tool_calls, variable=variable)) + + if role != TuneConstant.ASSISTANT_ROLE: + raise CaseValidationException( + f"The last message role should be {TuneConstant.ASSISTANT_ROLE} in case-{case_idx}", case_idx + ) + return optimizer_input + + @staticmethod + def default_convertor(cases: List, idx: int, is_last: bool, info: CaseInfo) -> None: + """default optimizer input convertor""" + if len(cases) <= idx: + cases.append(dict( + question="", + label="", + tools=info.tools, + variable=dict() + )) + case = cases[-1] + if info.variable: + case[TuneConstant.VARIABLE_KEY] = info.variable + if info.role == TuneConstant.ASSISTANT_ROLE: + content = str(info.tool_calls) if info.tool_calls else info.content + if is_last: + if not case.get(TuneConstant.QUESTION_KEY, ""): + raise CaseValidationException(f"case-{idx} is not a question-label pair", idx) + case[TuneConstant.LABEL_KEY] = content + else: + case[TuneConstant.QUESTION_KEY] = f"{info.role}: {content}\n" + elif info.role == TuneConstant.TOOL_ROLE: + content = f"{info.role}: name={info.name}, content={info.content}\n" + case[TuneConstant.QUESTION_KEY] += content + else: + content = str(info.tool_calls) if info.tool_calls else info.content + case[TuneConstant.QUESTION_KEY] += f"{info.role}: {content}\n" + + @staticmethod + def _get_case_value_with_check(data: Dict[str, Any], key: str, index: int) -> Any: + """get case value with existence check""" + value = data.get(key, None) + if value is None: + raise CaseValidationException(f"{key} is not in \"case={index}\"", index) + return value \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/constant.py b/jiuwen/agent_builder/prompt_builder/tune/base/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee98023c5c4d00dba197a088a26ba4352c0c8a6 --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/base/constant.py @@ -0,0 +1,72 @@ +#!/usr/bin/python3.10 +# coding: utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. 
All rights reserved + +class TuneConstant: + """prompt tuning constants""" + + """message roles constant""" + SYSTEM_ROLE = "system" + ASSISTANT_ROLE = "assistant" + USER_ROLE = "user" + TOOL_ROLE = "tools" + + """message keys constant""" + MESSAGE_KEY = "messages" + TOOLS_KEY = "tools" + MESSAGE_ROLE_KEY = "role" + MESSAGE_CONTENT_KEY = "content" + MESSAGE_TOOL_CALLS_KEY = "tool_calls" + QUESTION_KEY = "question" + REASON_KEY = "reason" + LABEL_KEY = "label" + NAME_KEY = "name" + VARIABLE_KEY = "variable" + PREDICT_KEY = "predict" + + """optimizer parameters type constant""" + ROOT_CASE_INDEX: int = -1 + EVALUATION_METHOD_TEXT = "TEXT" + EVALUATION_METHOD_LLM = "LLM" + OPTIMIZATION_METHOD_JOINT = "JOINT" + RAW_PROMPT_TAG = "" + + """optimizer parameters default value constant""" + DEFAULT_EXAMPLE_NUM: int = 0 + DEFAULT_COT_EXAMPLE_NUM: int = 0 + DEFAULT_ITERATION_NUM: int = 3 + DEFAULT_MAX_SAMPLED_EXAMPLE_NUM: int = 10 + DEFAULT_LLM_PARALLEL_DEGREE: int = 1 + DEFAULT_LLM_CALL_RETRY_NUM: int = 5 + DEFAULT_COT_EXAMPLE_RATIO: float = 0.25 + DEFAULT_MAX_CASE_NUM: int = 300 + DEFAULT_OPTIMIZATION_METHOD = OPTIMIZATION_METHOD_JOINT + DEFAULT_EVALUATION_METHOD = EVALUATION_METHOD_LLM + DEFAULT_MAX_RUNNING_TASK_NUM: int = 64 + + """optimizer parameters threshold constant""" + MIN_ITERATION_NUM: int = 1 + MAX_ITERATION_NUM: int = 20 + MIN_LLM_CALL_RETRY_NUM: int = 1 + MAX_LLM_CALL_RETRY_NUM: int = 10 + MIN_LLM_PARALLEL_DEGREE: int = 1 + MAX_LLM_PARALLEL_DEGREE: int = 10 + MIN_EXAMPLE_NUM: int = 0 + MAX_EXAMPLE_NUM: int = 10 + MIN_COT_EXAMPLE_NUM: int = 0 + MAX_COT_EXAMPLE_NUM: int = 5 + + DEFAULT_TOOL_CALL_PROMPT_PREFIX: str = """ + API/工具说明:\n{{APIS_DESCRIPTION}} + """ + +class TaskStatus: + """optimizer task status""" + TASK_STATUS = "status" + TASK_RUNNING = "running" + TASK_FINISHED = "finished" + TASK_FAILED = "failed" + TASK_STOPPED = "stopped" + TASK_STOPPING = "stopping" + TASK_DELETED = "deleted" + TASK_QUEUED = "queued" \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py b/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..961d859b0324d8f87ddedbb738feae4b70e1c382 --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/base/context_manager.py @@ -0,0 +1,110 @@ +#!/usr/bin/python3.10 +# coding: utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. 
All rights reserved + +import copy +import os +import threading +import hashlib +import time +from typing import Dict, Optional, Any, Callable, List, Set +from datetime import datetime, timezone, timedelta + +from cacheout import Cache +from pydantic import BaseModel, Field + +from jiuwen.agent_builder.prompt_builder.tune.common.singleton import Singleton + +from logging import getLogger +logger = getLogger(__name__) + +from jiuwen.agent_builder.prompt_builder.tune.common.exception import JiuWenBaseException, StatusCode +from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant, TaskStatus +from jiuwen.agent_builder.prompt_builder.tune.base.utils import OptimizeInfo + +Context = Dict[str, Any] +STOP_EVENT = "stop_event" + +class ContextManager(metaclass=Singleton): + """manage optimizer train context""" + def __init__(self): + self._cache: Cache = Cache(maxsize=65536) # context memory cache + self._check_points: Cache = Cache(maxsize=65536) # context checkpoint + self._lock: threading.Lock = threading.Lock() + + self._n_max_running_task: int = TuneConstant.DEFAULT_MAX_RUNNING_TASK_NUM + + def __len__(self): + """get length of current contexts""" + with self._lock: + return self._cache.size() + + @staticmethod + def _generate_context_id(): + """generate unique context id""" + timestamp = str(int(time.time() * 1000)) + return f"JNT_{hashlib.sha256(timestamp.encode()).hexdigest()}" + + def get_context_attr(self, ctx_id: str, attr_name: str): + context = self._cache.get(ctx_id) + if context: + return context.get(attr_name, None) + + def set_context_attr(self, ctx_id: str, attr_name: str, value: Any): + context = self._cache.get(ctx_id) + if context: + context[attr_name] = value + checkpoint = self._check_points.get(ctx_id) + if checkpoint: + checkpoint[attr_name] = value + + def set(self, ctx_id: str, context: Context): + """set context to cache""" + with self._lock: + self._cache.set(ctx_id, context) + + def get(self, ctx_id: str) -> Optional[Context]: + """get context from cache""" + with self._lock: + return self._cache.get(ctx_id, None) + + def is_executable(self): + n_running_task = 0 + for task in self._cache.values(): + status = task.status + if status in (TaskStatus.TASK_STOPPING, TaskStatus.TASK_RUNNING): + n_running_task += 1 + return n_running_task <= self._n_max_running_task + + def set_checkpoint(self, ctx_id: str, context: Context): + """set checkpoint to cache""" + with self._lock: + stop_event = context.get(STOP_EVENT) + context[STOP_EVENT] = None + copy_context = copy.deepcopy(context) + context[STOP_EVENT] = stop_event + self._check_points.set(ctx_id, copy_context) + + def get_checkpoint(self, ctx_id: str) -> Optional[Context]: + """get checkpoint from cache""" + with self._lock: + context = copy.deepcopy(self._check_points.get(ctx_id), None) + if context: + context[STOP_EVENT] = threading.Event() + return context + + def delete(self, ctx_id: str): + """delete checkpoint from cache""" + with self._lock: + self._cache.delete(ctx_id) + self._check_points.delete(ctx_id) + + def clear(self): + """clear checkpoint from cache""" + with self._lock: + self._cache.clear() + self._check_points.clear() + + def items(self): + with self._lock: + return self._cache.items() \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/exception.py b/jiuwen/agent_builder/prompt_builder/tune/base/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf8e98d1f6f70da9993840ebb66654b3b30faff --- /dev/null +++ 
b/jiuwen/agent_builder/prompt_builder/tune/base/exception.py @@ -0,0 +1,23 @@ +from jiuwen.agent_builder.prompt_builder.tune.common.exception import JiuWenBaseException, StatusCode +from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant + +class CaseValidationException(JiuWenBaseException): + """Definition of cases validation exception""" + def __init__(self, message: str, error_index: int = TuneConstant.ROOT_CASE_INDEX): + super().__init__(StatusCode.PROMPT_OPTIMIZE_CASE_VALIDATION_ERROR.code, + StatusCode.PROMPT_OPTIMIZE_CASE_VALIDATION_ERROR.errmsg.format( + error_msg=message + )) + + self._error_index = error_index + + @property + def error_index(self): + """index of the error case""" + return self._error_index + + +class OnStopException(JiuWenBaseException): + """Definition of on-stop exception""" + def __init__(self, message: str): + super().__init__(StatusCode.SUCCESS.code, message) \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/base/utils.py b/jiuwen/agent_builder/prompt_builder/tune/base/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08207c3c47d8615bbe00f3333e5d2c4919e95028 --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/base/utils.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- + +""" +prompt optimization utils +""" + +import os +from typing import List, Dict, Optional, Any, Union +from datetime import datetime, timezone, timedelta + +from pydantic import BaseModel, field_validator, Field, FieldValidationInfo +import yaml + +from jiuwen.agent_builder.prompt_builder.tune.base.exception import OnStopException +from jiuwen.agent_builder.prompt_builder.tune.common.exception import JiuWenBaseException, StatusCode +from jiuwen.agent_builder.prompt_builder.tune.base.case import Case +from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant, TaskStatus + + +class TaskInfo(BaseModel): + """prompt optimization input task info""" + task_id: str = Field(...) + task_name: str = Field(...) + task_description: str = Field(default="") + create_time: str = Field(...) 
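+    # note: calculate_runtime() later in this module parses create_time with the
+    # "%Y-%m-%d %H:%M:%S" format, e.g. "2025-01-01 08:00:00" (illustrative value only)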
+ + +class OptimizeInfo(BaseModel): + """definition of prompt optimization info""" + cases: List[Case] = Field(default=[]) + num_iterations: int = Field(default=TuneConstant.DEFAULT_ITERATION_NUM) + num_parallel: int = Field(default=TuneConstant.DEFAULT_LLM_PARALLEL_DEGREE, + ge=TuneConstant.MIN_LLM_PARALLEL_DEGREE, le=TuneConstant.MAX_LLM_PARALLEL_DEGREE) + num_examples: int = Field(default=TuneConstant.DEFAULT_EXAMPLE_NUM, + ge=TuneConstant.MIN_EXAMPLE_NUM, le=TuneConstant.MAX_EXAMPLE_NUM) + num_cot_examples: int = Field(default=TuneConstant.DEFAULT_COT_EXAMPLE_NUM, + ge=TuneConstant.MIN_COT_EXAMPLE_NUM, le=TuneConstant.MAX_COT_EXAMPLE_NUM) + num_retires: int = Field(default=TuneConstant.DEFAULT_LLM_CALL_RETRY_NUM, + ge=TuneConstant.MIN_LLM_CALL_RETRY_NUM, le=TuneConstant.MAX_LLM_CALL_RETRY_NUM) + optimize_method: str = Field(default=TuneConstant.OPTIMIZATION_METHOD_JOINT) + placeholder: Optional[List] = Field(default=[]) + evaluation_method: str = Field(default=TuneConstant.DEFAULT_EVALUATION_METHOD) + tools: Union[List[Dict[str, Any]], Any] = Field(default=[]) + user_compare_rules: str = Field(default="None") + + +class JointParameters(BaseModel): + """Joint optimization parameters""" + num_examples: int = Field(default=TuneConstant.DEFAULT_EXAMPLE_NUM) + num_cot_examples: int = Field(default=TuneConstant.DEFAULT_COT_EXAMPLE_NUM) + base_instructions: str = Field(default="") + filled_instructions: str = Field(default="") + full_prompt: str = Field(default="") + answer_format: str = Field(default="") + task_description: str = Field(default="") + num_iterations: int = Field(default=TuneConstant.DEFAULT_ITERATION_NUM) + opt_placeholder_names: List[str] = Field(default=[]) + placeholders: List = Field(default=[]) + original_placeholders: List = Field(default=[]) + examples: List = Field(default=[]) + cot_examples: List = Field(default=[]) + + +class History(BaseModel): + """optimization history for every round""" + optimized_prompt: str = Field(default="") + original_placeholder: Dict[str, str] = Field(default={}) + optimized_placeholder: Dict[str, str] = Field(default={}) + examples: List[str] = Field(default=[]) + filled_prompt: str = Field(default="") + success_rate: float = Field(default=0.0) + iteration_round: int = Field(default=0) + + +class LLMModelInfo(BaseModel): + """LLM model config info""" + url: str = Field(default="", min_length=0, max_length=256) + model: str = Field(default="", min_length=0, max_length=256) + type: str = Field(default="", min_length=0, max_length=256) + headers: Optional[Dict] = Field(default={}) + model_source: str = Field(default="", min_length=0, max_length=256) + api_key: str = Field(default="", min_length=0, max_length=256) + + +class QwenLLM: + def __init__(self, model_info: LLMModelInfo): + self.url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions" + self.api_key = "sk-f9410e0700c94022a16b78341e860c45" + self.model_name = model_info.model + + def chat(self, messages): + import requests + body = { + "messages": messages, + "model": self.model_name, + "stream": False + } + headers = { + "Authorization": f"{self.api_key}", + } + response = requests.post(self.url, json=body, headers=headers) + if response.status_code != 200: + print(response.text) + raise Exception(response.text) + content = response.json().get("choices")[0].get("message").get("content") + return dict(content=content) + +class LLMModelProcess: + """LLM invoke process""" + def __init__(self, llm_model_info: LLMModelInfo): + if llm_model_info.headers is None: + raise 
JiuWenBaseException( + error_code=StatusCode.LLM_CONFIG_MISS_ERROR.code, + message=StatusCode.LLM_CONFIG_MISS_ERROR.errmsg.format( + error_msg="prompt optimization llm config is missing" + ) + ) + self.chat_llm = None + self.model_info = llm_model_info + + def chat(self, messages: List[Any]) -> Dict: + """chat""" + return QwenLLM(self.model_info).chat(messages) + + +def load_yaml_to_dict(file_path: str) -> Dict: + """load yaml file""" + with open(file_path, "r", encoding="utf-8") as f: + try: + yaml_content = f.read() + parsed_dict = yaml.safe_load(yaml_content) + except Exception as e: + raise JiuWenBaseException( + error_code=StatusCode.LLM_CONFIG_MISS_ERROR.code, + message=StatusCode.LLM_CONFIG_MISS_ERROR.errmsg + ) + return parsed_dict + +def calculate_runtime(start_time: str) -> int: + """calculate task runtime""" + if not start_time: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="invalid start time" + ) + ) + try: + start_time = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S") + cur_time = datetime.now(tz=timezone(timedelta(hours=8))).strftime("%Y-%m-%d %H:%M:%S") + cur_time = datetime.strptime(cur_time, "%Y-%m-%d %H:%M:%S") + except Exception as e: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="invalid time format" + ) + ) from e + return int((cur_time - start_time).total_seconds()) + + +def placeholder_to_dict(placeholder_list: List, select_all: bool = False) -> Dict: + """convert placeholder list to dict""" + if not placeholder_list: + return {} + + placeholder_dict = {} + for placeholder in placeholder_list: + if select_all or placeholder.get("need_optimize"): + placeholder_dict[placeholder["name"]] = placeholder["content"] + return placeholder_dict + + +def examples_to_string_list(example_list: List) -> List[str]: + """convert example list to string list""" + if not example_list: + return [] + example_string_format = (f"[query]: {TuneConstant.QUESTION_KEY}\n" + f"[assistant answer]: {TuneConstant.LABEL_KEY}") + example_string_list = [] + for example in example_list: + example_string_list.append(example_string_format.format( + question=example.get(TuneConstant.QUESTION_KEY, ""), + label=example.get(TuneConstant.LABEL_KEY, "") + )) + return example_string_list + +def get_example_question(example): + """get example question""" + content = example.get(TuneConstant.QUESTION_KEY, "") + if TuneConstant.RAW_PROMPT_TAG in content: + return str(example.get(TuneConstant.VARIABLE_KEY, "")) + return content \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/common/exception.py b/jiuwen/agent_builder/prompt_builder/tune/common/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..04d306dbeb094224c9e0435b433951e78466468c --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/common/exception.py @@ -0,0 +1,60 @@ +from enum import Enum + +class JiuWenException(Exception): + def __init__(self, + message, + *args, **kwargs): + super().__init__(message, args, kwargs) + +class JiuWenBaseException(Exception): + def __init__(self, + error_code, + message): + super().__init__(error_code, message) + self._error_code = error_code + self._message = message + + def __str__(self): + return f"[{self._error_code}]{self._message}" + + @property + def error_code(self): + return 
self._error_code + + @property + def message(self): + return self._message + + +class ParamCheckFailedException(JiuWenBaseException): + def __init__(self, message: str): + super().__init__(error_code=StatusCode.PARAM_CHECK_FAILED_ERROR.code, + message=f"{StatusCode.PARAM_CHECK_FAILED_ERROR.errmsg}, root cause = {message}") + + +class StatusCode(Enum): + SUCCESS = (200, "success") + + PARAM_CHECK_FAILED_ERROR = (100002, "Error occur when input parameter varification failed") + LLM_CONFIG_MISS_ERROR = (100021, "LLM service configuration is missing: {error_msg}") + LLM_FALSE_RESULT_ERROR = (102003, "LLM service return false result due to {error_msg}") + + PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR = ( + 102162, "Prompt optimization failed to refine instruction, root cause: {error_msg}" + ) + + PROMPT_OPTIMIZE_RESTART_TASK_ERROR = (102159, "Prompt optimization restart task error: {error_msg}") + PROMPT_OPTIMIZE_EVALUATE_ERROR = (102157, "Prompt optimization evaluate failed, root cause: {error_msg}") + PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR = ( + 102154, "Prompt optimization parameters are invalid, root cause = {error_msg}") + PROMPT_OPTIMIZE_CASE_VALIDATION_ERROR = ( + 102161, "Prompt optimization validate input case failed, root cause = {error_msg}" + ) + + @property + def code(self) -> int: + return self.value[0] + + @property + def errmsg(self) -> str: + return self.value[1] \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/common/singleton.py b/jiuwen/agent_builder/prompt_builder/tune/common/singleton.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb4d79ede583d0d24f592611c86f17dfee814af --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/common/singleton.py @@ -0,0 +1,14 @@ +import abc +import threading + +singleton_lock = threading.Lock() + + +class Singleton(abc.ABCMeta, type): + _instances = {} + + def __call__(cls, *args, **kwargs): + with singleton_lock: + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/joint_evaluator.py b/jiuwen/agent_builder/prompt_builder/tune/joint_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..fd00e6e1fbc43734ec4b83c23e49691af6af0e34 --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/joint_evaluator.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- + +""" +prompt optimization evaluators +""" + +import re +import json +import threading +import copy +import random +from typing import Dict, Any +from concurrent.futures import ThreadPoolExecutor, as_completed + +from jiuwen.agent_builder.prompt_builder.tune.base.exception import JiuWenBaseException +from logging import getLogger +logger = getLogger(__name__) +from jiuwen.agent_builder.prompt_builder.tune.base.utils import OptimizeInfo, LLMModelProcess, LLMModelInfo +from jiuwen.agent_builder.prompt_builder.tune.base.exception import OnStopException +from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant + + +class JointEvaluatorWithRef: + """Prompt evaluator, evaluate the response of model""" + def __init__(self, opt_model_info: LLMModelInfo, infer_model_info: LLMModelInfo, + optimize_info: OptimizeInfo, compare_prompt: str): + self.opt_model = LLMModelProcess(opt_model_info) + self.infer_model = LLMModelProcess(infer_model_info) + self._num_retires = optimize_info.num_retires + self._llm_parallel = optimize_info.num_parallel + 
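+        # the evaluation method below selects exact-text matching vs LLM-based comparison in evaluate_result()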
self._evaluation_method = optimize_info.evaluation_method + self._compare_answer_prompt = compare_prompt + + @staticmethod + def parse_json(json_like_string: str) -> Dict[str, Any]: + """Parse json string""" + pattern = r"```json(.*?)```" + match = re.search(pattern, json_like_string, re.DOTALL) + + if match: + json_string = match.group(1).strip() + try: + parsed_data = json.loads(json_string) + except json.decoder.JSONDecodeError: + logger.warning("Failed to decode json string") + return None + return parsed_data + + logger.warning("No valid json string found") + return None + + @staticmethod + def compare_text(label: str, predict: str): + """Compare the predicted text with label""" + if not isinstance(label, str) or not isinstance(predict, str): + return 0, "Failed to compare non-string result" + result = label.strip() == predict.strip() + return int(result), "Same" if result else "Different" + + def compare_llm(self, question, label, predict): + """Compare the predicted text with label by LLM""" + prompt_template = self._compare_answer_prompt.format( + question=question, + answer_match=predict, + actual_answer=label + ) + try: + response = self.handle_inference_with_retry(prompt_template, is_assistant=False) + except JiuWenBaseException: + return 0, "Failed to compare result due to LLM calling" + + if not response or "content" not in response: + return 0, "Failed to get response from LLM" + + compare_result = self.parse_json(response.get("content")) + if not compare_result: + return 0, "Parse llm compare result failed" + + result = compare_result.get("result", False) + score = 0 + reason = compare_result.get("reason", "") + if result is True or (isinstance(result, str) and result.strip().lower() == "true"): + score = 1 + return score, reason + + def evaluate_result(self, question, label, predict): + """evaluate result""" + if self._evaluation_method == TuneConstant.EVALUATION_METHOD_TEXT: + return self.compare_text(label, predict) + return self.compare_llm(question, label, predict) + + def chat_completion(self, user_prompt, system_prompt, is_assistant: bool = True): + """get llm response""" + messages = [] + if system_prompt: + messages.append({TuneConstant.MESSAGE_ROLE_KEY: TuneConstant.SYSTEM_ROLE, + TuneConstant.MESSAGE_CONTENT_KEY: system_prompt}) + messages.append({TuneConstant.MESSAGE_ROLE_KEY: TuneConstant.USER_ROLE, + TuneConstant.MESSAGE_CONTENT_KEY: user_prompt}) + return self.infer_model.chat(messages) if is_assistant else self.opt_model.chat(messages) + + def handle_inference_with_retry(self, user_prompt, system_prompt=None, is_assistant: bool = True): + """chat llm with retry""" + for i in range(self._num_retires): + try: + return self.chat_completion(user_prompt, system_prompt, is_assistant) + except JiuWenBaseException as e: + logger.info(f"Inference failed at round {i}/{self._num_retires}: {str(e)}") + if i == self._num_retires - 1: + raise e + return None + + def infer_and_compare_example(self, prompt, example, stop_event: threading.Event): + """inference and compare example""" + if stop_event.is_set(): + raise OnStopException("Task stop from evaluation") + question, label = example.get(TuneConstant.QUESTION_KEY), example.get(TuneConstant.LABEL_KEY) + example_evaluation = copy.deepcopy(example) + tools = example_evaluation.pop(TuneConstant.TOOLS_KEY, None) + variable = example_evaluation.get(TuneConstant.VARIABLE_KEY, None) + if variable: + for var_name, var_content in variable.items(): + prompt = prompt.replace(f"{{{{{var_name}}}}}", var_content) + if 
TuneConstant.RAW_PROMPT_TAG in question: + full_question = question.replace(TuneConstant.RAW_PROMPT_TAG, prompt) + response = self.handle_inference_with_retry(full_question) + question = str(variable) + else: + if tools: + system_prompt_prefix = TuneConstant.DEFAULT_TOOL_CALL_PROMPT_PREFIX.replace( + "{{APIS_DESCRIPTION}}", str(tools) if tools else "" + ) + response = self.handle_inference_with_retry(question, system_prompt_prefix + prompt) + else: + response = self.handle_inference_with_retry(question, prompt) + predict = response.get(TuneConstant.MESSAGE_CONTENT_KEY, "") + if not predict: + example_evaluation["score"] = 0 + example_evaluation["predict"] = None + return example_evaluation, 0, "Failed to get predict" + + compare_result, reason = self.evaluate_result(question, label, predict) + example_evaluation["score"] = compare_result + example_evaluation["predict"] = predict + return example_evaluation, compare_result, reason + + def evaluate(self, prompt, dataset, stop_event: threading.Event): + """evaluate dataset""" + eval_results = [] + score_results = [] + num_workers = min(self._llm_parallel, len(dataset)) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [ + executor.submit(self.infer_and_compare_example, prompt, example, stop_event) + for example in dataset + ] + + for future in as_completed(futures): + if stop_event.is_set(): + raise OnStopException("Task stop from evaluation") + + try: + example_evaluation, compare_result, _ = future.result() + except Exception as e: + logger.error(f"Error occur during evaluation: {str(e)}") + raise e + score_results.append(compare_result) + eval_results.append(example_evaluation) + accuracy = sum(score_results) / len(score_results) if score_results else 0 + eval_results = [item for item in eval_results if item.get("score", 0) == 0] + eval_results = random.sample(eval_results, min(TuneConstant.DEFAULT_MAX_SAMPLED_EXAMPLE_NUM, + len(eval_results))) + return accuracy, eval_results \ No newline at end of file diff --git a/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py b/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbaebbada5c10568858b2aa8c0f421e2b171d91 --- /dev/null +++ b/jiuwen/agent_builder/prompt_builder/tune/joint_optimizer.py @@ -0,0 +1,809 @@ +# -*- coding: utf-8 -*- + +""" +prompt optimization evaluators +""" + +import re +import random +import threading +import copy +from os.path import dirname, join +from dataclasses import dataclass +from typing import List, Dict, Optional, Any, Tuple + +from logging import getLogger +logger = getLogger(__name__) + +from jiuwen.agent_builder.prompt_builder.tune.common.exception import JiuWenBaseException, StatusCode +from jiuwen.agent_builder.prompt_builder.tune.base.exception import OnStopException +from jiuwen.agent_builder.prompt_builder.tune.base.constant import TuneConstant, TaskStatus +from jiuwen.agent_builder.prompt_builder.tune.base.utils import (LLMModelProcess, LLMModelInfo, TaskInfo, load_yaml_to_dict, JointParameters, + placeholder_to_dict, calculate_runtime, History, OptimizeInfo, get_example_question) +from jiuwen.agent_builder.prompt_builder.tune.base.context_manager import ContextManager, STOP_EVENT, Context +from jiuwen.agent_builder.prompt_builder.tune.joint_evaluator import JointEvaluatorWithRef +from jiuwen.agent_builder.prompt_builder.tune.base.case import CaseManager + +@dataclass +class SpecificMatch: + QUESTION_KEY = "[query]:" + ANSWER_KEY = 
"[assistant answer]:" + EXAMPLE_DELIMITER_PATTERN = r"(?s)(?<=)(.*?)(?=)" + +class JointOptimizer: + def __init__(self): + self._opt_model = None + self._infer_model = None + self.evaluator = None + self.dataset = None + self.params = JointParameters() + self.cur_iteration = 0 + self.best_accuracy = 0.0 + self.sampled_incorrect_data = [] + self.variable = [] + default_opt_prompt_config_path = join(dirname(__file__), "joint_prompt_pool.yaml") + self.prompt_pool = load_yaml_to_dict(default_opt_prompt_config_path) + self.instr_optimizer = None + + @staticmethod + def get_optimize_placeholder(placeholder): + """return a list of placeholders requiring optimization""" + try: + if not placeholder: + return None + return [p.get("name") for p in placeholder if p.get("need_optimize")] + except Exception as e: + logger.warning(f"get optimize placeholder error: {e}") + return None + + @staticmethod + def extract_optimized_prompt_from_response(content) -> Optional[str]: + """extract optimized prompt from response""" + optimized_prompt_pattern = r"(.*?)" + match = re.search(optimized_prompt_pattern, content, re.DOTALL) + if not match: + return None + optimized_prompt = match.group(1) + return optimized_prompt.replace("", "").replace("", "") + + @staticmethod + def extract_optimized_placeholder_from_response(content) -> Optional[List]: + """extract optimized placeholder from response""" + optimized_placeholder_pattern = r"(.*?)" + match = re.search(optimized_placeholder_pattern, content, re.DOTALL) + if not match: + return None + optimized_placeholder = match.group(1) + single_placeholder_pattern = re.compile(r"<(.*?)>(.*?)", re.S) + matches = single_placeholder_pattern.findall(optimized_placeholder) + return [{TuneConstant.NAME_KEY: tag.strip(), + TuneConstant.MESSAGE_CONTENT_KEY: content.strip() + } for tag, content in matches] + + @staticmethod + def fill_prompt(instruction: str, placeholder: List) -> str: + if not placeholder: + return instruction + for p in placeholder: + name = p.get(TuneConstant.NAME_KEY) + content = p.get(TuneConstant.MESSAGE_CONTENT_KEY) + if name and content: + instruction = instruction.replace(f"{{{{{name}}}}}", content) + return instruction + + @staticmethod + def extract_examples_from_response(content): + """extract examples from response""" + optimized_examples = [] + question_pattern = re.compile( + re.escape(SpecificMatch.QUESTION_KEY) + r"(.*?)" + + re.escape(SpecificMatch.ANSWER_KEY) + r"(.*)", + re.DOTALL + ) + + for text in re.findall(SpecificMatch.EXAMPLE_DELIMITER_PATTERN, content): + match = question_pattern.search(text.strip()) + if match: + question = match.group(1).strip() + answer_with_reason = match.group(2).strip() + optimized_examples.append({TuneConstant.QUESTION_KEY: question, + TuneConstant.LABEL_KEY: answer_with_reason}) + return optimized_examples + + @staticmethod + def prepare_optimization_template(template, instruction, placeholders, error_cases, reflection, tools): + """generate a critique prompt""" + prompt = template.replace("{{PROMPT_META_TEMPLATE}}", instruction) + prompt = prompt.replace("{{PLACEHOLDER_CONTENTS}}", placeholders) + prompt = prompt.replace("{{ERROR_CASES}}", error_cases) + prompt = prompt.replace("{{REFLECTIONS_ON_ERROR_CASES}}", reflection) + prompt = prompt.replace("{{API_TOOLS_DESCRIPTION}}", tools) + return prompt + + @staticmethod + def validate_placeholder(prompt, placeholders): + """validate the placeholder""" + placeholders_in_prompt = re.findall(r"\{\{([\w_]+)\}\}", prompt) + placeholder_names = [item.get("name", "") for 
item in placeholders] + for item in placeholders: + if not item.get("name"): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="Placeholder item 'name' cannot be empty")) + if not item.get("content"): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="Placeholder item 'content' cannot be empty")) + if "need_optimize" not in item: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="Placeholder item 'need_optimize' cannot be empty")) + + if not isinstance(item["name"], str): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="Placeholder item 'name' must be a string")) + if not isinstance(item["content"], str): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="Placeholder item 'content' must be a string")) + if not isinstance(item["need_optimize"], bool): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg="Placeholder item 'need_optimize' must be a boolean")) + if not all(name in placeholders_in_prompt for name in placeholder_names): + missing_names = ", ".join([name for name in placeholder_names if name not in placeholders_in_prompt]) + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.code, + message=StatusCode.PROMPT_OPTIMIZE_INVALID_PARAMS_ERROR.errmsg.format( + error_msg=f"Placeholder names {missing_names} not found in prompt")) + logger.info("Schema validation succeeded.") + + @staticmethod + def _check_stop_event(context: Context) -> bool: + """check optimization stop event""" + task_id = context.get("id", "") + if context and context.get(STOP_EVENT) and context.get(STOP_EVENT).is_set(): + raise OnStopException(f"Task {task_id} stopped") + return True + + @staticmethod + def get_variable_from_dataset(dataset, prompt): + """get variable from dataset""" + variable = set() + for case in dataset: + vars = case.get(TuneConstant.VARIABLE_KEY, {}) + variable.update(vars) + + reordered_variable = [(prompt.find(f"{{{{{v}}}}}"), v) for v in variable] + reordered_variable = [v[1] for v in sorted(reordered_variable, key=lambda x: x[0], reverse=False)] + return reordered_variable + + def chat_completion(self, user_prompt, system_prompt=None, is_assistant=False): + """generate chat completion""" + messages = [] + if system_prompt: + messages.append({TuneConstant.MESSAGE_ROLE_KEY: TuneConstant.SYSTEM_ROLE, + TuneConstant.MESSAGE_CONTENT_KEY: system_prompt}) + messages.append({TuneConstant.MESSAGE_ROLE_KEY: TuneConstant.USER_ROLE, + TuneConstant.MESSAGE_CONTENT_KEY: user_prompt}) + retries = TuneConstant.DEFAULT_LLM_CALL_RETRY_NUM + for i in range(retries): + try: + response = self._opt_model.chat(messages) if not is_assistant else self._infer_model.chat(messages) + if not response or response.get(TuneConstant.MESSAGE_CONTENT_KEY) is None: + raise JiuWenBaseException( + 
StatusCode.LLM_FALSE_RESULT_ERROR.code,
+                        StatusCode.LLM_FALSE_RESULT_ERROR.errmsg.format(
+                            error_msg="call llm service get empty response"
+                        )
+                    )
+                return response.get(TuneConstant.MESSAGE_CONTENT_KEY)
+            except (JiuWenBaseException, KeyError, AttributeError, TypeError) as e:
+                logger.info(f"Inference failed at round {i}/{retries}: {str(e)}")
+                if i == retries - 1:
+                    raise e
+        return None
+
+    def resample_examples(self, sampled_incorrect_data):
+        """resampling examples"""
+        dataset = copy.deepcopy(self.dataset)
+        num_samples = self.params.num_examples + self.params.num_cot_examples
+        if len(sampled_incorrect_data) >= num_samples:
+            return sampled_incorrect_data
+        if len(dataset) <= num_samples:
+            return dataset
+
+        num_to_add = num_samples - len(sampled_incorrect_data)
+        unique_data = []
+        query_set = set()
+        for data in sampled_incorrect_data:
+            if not self.variable:
+                if data.get(TuneConstant.QUESTION_KEY, "") not in query_set:
+                    query_set.add(data[TuneConstant.QUESTION_KEY])
+                    unique_data.append(data)
+                continue
+            if str(data[TuneConstant.VARIABLE_KEY]) not in query_set:
+                query_set.add(str(data[TuneConstant.VARIABLE_KEY]))
+                unique_data.append(data)
+        remaining_data = [data for data in dataset if data not in unique_data]
+        sampled_new_data = random.sample(remaining_data, num_to_add)
+        return sampled_new_data + sampled_incorrect_data
+
+    def evaluate(self, prompt, context: Context):
+        """evaluate dataset"""
+        try:
+            accuracy, sampled_incorrect_data = self.evaluator.evaluate(prompt, self.dataset,
+                                                                       context.get(STOP_EVENT))
+        except JiuWenBaseException as e:
+            raise e
+        if not self._check_stop_event(context) or accuracy is None or sampled_incorrect_data is None:
+            return None, None
+        sampled_data = self.resample_examples(sampled_incorrect_data)
+        return accuracy, sampled_data
+
+    def get_task_description(self):
+        """get task description"""
+        prompt = self.prompt_pool.get("get_task_description").format(
+            instruction=self.params.base_instructions
+        )
+        return self.chat_completion(prompt)
+
+    def get_answer_format(self):
+        """get answer format"""
+        prompt = self.prompt_pool.get("get_answer_format").format(
+            instruction=self.params.base_instructions
+        )
+        return self.chat_completion(prompt)
+
+    def init_parameters(self, optimize_info: OptimizeInfo, raw_prompt, context: Context):
+        """init parameters"""
+        self.params.num_iterations = optimize_info.num_iterations
+        self.params.num_examples = optimize_info.num_examples
+        self.params.num_cot_examples = optimize_info.num_cot_examples
+        self.params.placeholders = copy.deepcopy(optimize_info.placeholder)
+        self.params.original_placeholders = copy.deepcopy(optimize_info.placeholder)
+        self.params.base_instructions = raw_prompt[0]
+        self.params.opt_placeholder_names = self.get_optimize_placeholder(self.params.placeholders)
+        self.params.filled_instructions = self.fill_prompt(self.params.base_instructions,
+                                                           self.params.placeholders)
+        self.save_state(context, force_save=True)
+        try:
+            self.params.task_description = self.get_task_description()
+            self.params.answer_format = self.get_answer_format()
+        except JiuWenBaseException as e:
+            raise e
+
+    def prompt_combine(self, instruction: str, example_string=None, cot_example_string=None):
+        """prompt combine"""
+        variable_section = "\n".join([f"【{key}】: {{{{{key}}}}}" for key in self.variable])
+
+        prompt_parts = [
+            instruction,
+            f"\n## 示例\n{example_string}" if example_string else "",
+            f"\n## 包含思维链示例\n{cot_example_string}" if cot_example_string else "",
+            f"\n{self.params.answer_format.strip()}" if
self.params.answer_format else "", + f"\n\n用户输入:\n{variable_section}" if variable_section else "", + f"\n Output:" + ] + return "".join(prompt_parts) + + def update_placeholder(self, new_placeholders: List, placeholders_to_update: List): + """update placeholders""" + if not new_placeholders or not self.params.opt_placeholder_names: + return + for update_ph in new_placeholders: + name = update_ph.get(TuneConstant.NAME_KEY) + content = update_ph.get(TuneConstant.MESSAGE_CONTENT_KEY) + if not name or not content or name not in self.params.opt_placeholder_names: + continue + for i, base_ph in enumerate(placeholders_to_update): + if name == base_ph.get(TuneConstant.NAME_KEY): + placeholders_to_update[i][TuneConstant.MESSAGE_CONTENT_KEY] = content + break + + def optimize_instruction_by_gradient(self, tools: Optional[List]) -> Tuple[Optional[str], Optional[List]]: + """optimize instruction by text gradient""" + instruction = self.params.base_instructions + + error_example_string = "".join( + self.prompt_pool["quest_reason_ans_error"].format( + question=get_example_question(example), + answer=example.get(TuneConstant.LABEL_KEY, ""), + predict=example.get(TuneConstant.PREDICT_KEY, ""), + ) for example in self.sampled_incorrect_data + ) + + if tools: + tool_prefix = TuneConstant.DEFAULT_TOOL_CALL_PROMPT_PREFIX.replace( + "{{APIS_DESCRIPTION}}", str(tools) if tools else "None" + ) + prompt = tool_prefix + instruction + else: + prompt = instruction + prompt_critique_template = self.prompt_pool["prompt_critique_template"].format( + instruction=prompt, + examples=error_example_string + ) + text_gradient = self.chat_completion(prompt_critique_template) + if not self.params.placeholders: + prompt_update_result = self.optimize_instruction_without_placeholder(instruction, error_example_string, + text_gradient, tools) + placeholder_update_result = None + else: + prompt_update_result, placeholder_update_result = self.optimize_instruction_with_placeholder( + instruction, error_example_string, text_gradient, tools + ) + return prompt_update_result, placeholder_update_result + + def optimize_instruction_without_placeholder(self, instruction, error_example_string, text_gradient, tools): + """update instruction""" + optimize_prompt_template = self.prompt_pool["optimize_prompt_instruction_template"] + optimize_prompt_template = optimize_prompt_template.replace("{{PROMPT_META_TEMPLATE}}", instruction) + optimize_prompt_template = optimize_prompt_template.replace("{{ERROR_CASES}}", error_example_string) + optimize_prompt_template = optimize_prompt_template.replace("{{REFLECTIONS_ON_ERROR_CASES}}", text_gradient) + optimize_prompt_template = optimize_prompt_template.replace("{{API_TOOLS_DESCRIPTION}}", str(tools) if tools else "None") + optimized_instruction_response = self.chat_completion(optimize_prompt_template) + return self.extract_optimized_prompt_from_response(optimized_instruction_response) + + def optimize_instruction_with_placeholder(self, instruction, error_example_string, text_gradient, tools): + """update instruction with placeholder""" + # optimize instruction + placeholder_content = "".join( + f"<{p.get(TuneConstant.NAME_KEY)}>\n" + f"{p.get(TuneConstant.MESSAGE_CONTENT_KEY)}\n" + f"\n" for p in self.params.placeholders + ) + placeholder_content += "".join( + f"<{v}>\n{{{{{v}}}}}\n\n" for v in self.variable + ) + prompt_critique_template = self.prepare_optimization_template( + self.prompt_pool["optimize_prompt_instruction_template_with_placeholder"], + instruction, placeholder_content, 
error_example_string, text_gradient, str(tools) if tools else "None" + ) + optimized_instruction_response = self.chat_completion(prompt_critique_template) + optimized_instruction = self.extract_optimized_prompt_from_response(optimized_instruction_response) + + # optimize placeholder + opt_placeholder_names = str(self.params.opt_placeholder_names) if self.params.opt_placeholder_names else "" + optimized_placeholder = copy.deepcopy(self.params.placeholders) + if not self.params.opt_placeholder_names: + return optimized_instruction, optimized_placeholder + fixed_placeholders = [p for p in self.params.placeholders + if p.get(TuneConstant.NAME_KEY) not in self.params.opt_placeholder_names] + need_opt_placeholders = [p for p in self.params.placeholders + if p.get(TuneConstant.NAME_KEY) in self.params.opt_placeholder_names] + instruction = self.fill_prompt(instruction, fixed_placeholders) + placeholder_content = "".join( + f"<{p.get(TuneConstant.NAME_KEY)}>\n" + f"{p.get(TuneConstant.MESSAGE_CONTENT_KEY)}\n" + f"\n" for p in need_opt_placeholders + ) + placeholder_critique_template = self.prepare_optimization_template( + self.prompt_pool["optimize_prompt_placeholder_template"], + instruction, placeholder_content, error_example_string, text_gradient, str(tools) if tools else "None" + ) + placeholder_critique_template = placeholder_critique_template.replace( + "{{PLACEHOLDER_TO_OPTIMIZE}}", opt_placeholder_names + ) + updated_placeholders_response = self.chat_completion(placeholder_critique_template) + updated_placeholders = self.extract_optimized_placeholder_from_response(updated_placeholders_response) + self.update_placeholder(updated_placeholders, optimized_placeholder) + return optimized_instruction, optimized_placeholder + + def select_best_examples(self, context: Context) -> List: + """select best examples""" + if not self._check_stop_event(context) or self.params.num_examples == 0: + return [] + + try: + error_example_string = "".join( + self.prompt_pool["quest_reason_ans"].format( + question=get_example_question(example), + answer=example.get(TuneConstant.LABEL_KEY, ""), + ) for example in self.sampled_incorrect_data + ) + gt_example = random.sample(self.dataset, 1)[0] + gt_example_string = self.prompt_pool["gt_quest_reason_ans"].format( + question=get_example_question(gt_example), + answer=gt_example.get(TuneConstant.LABEL_KEY, "") + ) + examples_select_template = self.prompt_pool["examples_select_template"].format( + gt_example=gt_example_string, + task_description=self.params.base_instructions, + num_examples=self.params.num_examples, + error_examples=error_example_string + ) + response = self.chat_completion(examples_select_template) + return self.extract_examples_from_response(response) + + except (KeyError, TypeError, AttributeError, IndexError) as e: + logger.warning(f"Error occur while selecting best examples: {e}") + return [] + + def generate_best_reasoning_examples(self, context: Context) -> List: + """generate best reasoning examples""" + if not self._check_stop_event(context) or self.params.num_cot_examples == 0: + return [] + + try: + examples_string = self._get_example_string(self.params.cot_examples) + error_example_string = "".join( + self.prompt_pool["quest_reason_ans_error"].format( + question=get_example_question(example), + answer=example.get(TuneConstant.LABEL_KEY, ""), + predict=example.get(TuneConstant.PREDICT_KEY, "") + ) for example in self.sampled_incorrect_data + ) + + examples_critique_template = self.prompt_pool["error_examples_critique_template"].format( + 
examples=examples_string, + task_description=self.params.filled_instructions, + num_examples=self.params.num_cot_examples, + error_examples=error_example_string + ) + response = self.chat_completion(examples_critique_template) + examples_optimization_template = self.prompt_pool["examples_optimization_template"].format( + examples=examples_string, + critique=response, + task_description=self.params.filled_instructions, + num_examples=self.params.num_cot_examples + ) + response = self.chat_completion(examples_optimization_template) + return self.extract_examples_from_response(response) + except (KeyError, TypeError, AttributeError) as e: + import traceback + traceback.print_exc() + logger.warning(f"Error occur while generating best reasoning examples: {e}") + return [] + + def get_example_reasoning(self, question, answer): + """get example reason from question and answer""" + try: + reasoning_template = self.prompt_pool["get_example_reasoning_template"].format( + task_descriotion=self.params.task_description, + instruction=self.params.base_instructions, + question=question, + answer=answer + ) + response = self.chat_completion(reasoning_template) + return response + except (KeyError, TypeError, AttributeError) as e: + logger.warning(f"Error occur while getting example reason from question: {e}") + return "" + + def evaluate_baseline(self, context: Context) -> History: + """evaluate baseline""" + if not self._check_stop_event(context): + raise JiuWenBaseException(StatusCode.PROMPT_OPTIMIZE_EVALUATE_ERROR.code, + StatusCode.PROMPT_OPTIMIZE_EVALUATE_ERROR.errmsg.format( + error_msg="evaluation baseline stopped" + )) + prompt = self.fill_prompt(self.params.base_instructions, self.params.placeholders) + self.best_accuracy, self.sampled_incorrect_data = self.evaluate(prompt, context) + if self.best_accuracy is None or self.sampled_incorrect_data is None: + raise JiuWenBaseException(StatusCode.PROMPT_OPTIMIZE_EVALUATE_ERROR.code, + StatusCode.PROMPT_OPTIMIZE_EVALUATE_ERROR.errmsg.format( + error_msg="evaluation baseline failed" + )) + return History(optimized_prompt=prompt, iteration_round=0, success_rate=self.best_accuracy) + + def sample_example(self, num_examples: int): + """sample example""" + dataset = copy.deepcopy(self.dataset) + error_cases = self.sampled_incorrect_data + if num_examples >= len(dataset): + return [{ + TuneConstant.QUESTION_KEY: get_example_question(data), + TuneConstant.LABEL_KEY: data.get(TuneConstant.LABEL_KEY, "") + } for data in dataset] + + sampled_examples = [] + if error_cases: + num_error_examples = min(num_examples, len(error_cases)) + sampled_examples.extend(random.sample(error_cases, num_error_examples)) + + if len(sampled_examples) < num_examples: + num_remaining_examples = num_examples - len(sampled_examples) + remaining_examples = [ex for ex in dataset if ex not in sampled_examples] + sampled_examples.extend(random.sample(remaining_examples, num_remaining_examples)) + else: + sampled_examples.extend(random.sample(dataset, num_examples)) + + return [{ + TuneConstant.QUESTION_KEY: get_example_question(data), + TuneConstant.LABEL_KEY: data.get(TuneConstant.LABEL_KEY, "") + } for data in sampled_examples] + + def prepare_fewshot_examples(self): + """prepare fewshot examples""" + self.params.examples = self.sample_example(self.params.num_examples) + raw_cot_examples = self.sample_example(self.params.num_cot_examples) + if not self.params.cot_examples: + return + self.params.cot_examples = [ + {**example, + TuneConstant.LABEL_KEY: self.get_example_reasoning( + 
get_example_question(example), example.get(TuneConstant.LABEL_KEY, "") + )} + for example in raw_cot_examples + ] + + def do_optimize(self, + task_info: TaskInfo, + raw_templates: List[str], + optimize_info: OptimizeInfo, + opt_model_info: LLMModelInfo, + infer_model_info: LLMModelInfo + ): + """do prompt optimization""" + original_prompt = raw_templates[0] + infer_model_info = infer_model_info or opt_model_info + context = dict(id=task_info.task_id, name=task_info.task_name, desc=task_info.task_description, + create_time=task_info.create_time, run_time=0, + raw_templates=raw_templates, optimize_info=optimize_info, + opt_model_info=opt_model_info, infer_model_info=infer_model_info, + error_msg="", stop_event=threading.Event(), status=TaskStatus.TASK_RUNNING, + history=[], cur_iteration=0, best_accuracy=0.0) + ContextManager().set(task_info.task_id, context) + try: + self._opt_model = LLMModelProcess(opt_model_info) + self._infer_model = LLMModelProcess(infer_model_info) + result_compare_template = self.prompt_pool["result_compare_template"].replace( + "{user_compare_rules}", optimize_info.user_compare_rules + ) + self.evaluator = JointEvaluatorWithRef(opt_model_info, infer_model_info, + optimize_info, result_compare_template) + self.dataset = CaseManager.validate_with_convert([case.model_dump() for case in optimize_info.cases], + CaseManager.default_convertor, + default_tools=optimize_info.tools) + self.variable = self.get_variable_from_dataset(self.dataset, original_prompt) + self.validate_placeholder(original_prompt, optimize_info.placeholder) + self.init_parameters(optimize_info, raw_templates, context) + baseline_history = self.evaluate_baseline(context) + for v in self.variable: + self.params.base_instructions = self.params.base_instructions.replace(f"{{{{{v}}}}}", v) + self.params.filled_instructions = self.fill_prompt(self.params.base_instructions, optimize_info.placeholder) + self.prepare_fewshot_examples() + self.save_state(context, baseline_history) + self.optimize_prompt_iteratively(context, 0) + except OnStopException: + ContextManager().set_context_attr(task_info.task_id, TaskStatus.TASK_STATUS, TaskStatus.TASK_STOPPED) + logger.info(f"Joint optimization task {task_info.task_id} stopped.") + return + except Exception as e: + import traceback + traceback.print_exc() + context[TaskStatus.TASK_STATUS] = TaskStatus.TASK_FAILED + context["run_time"] = calculate_runtime(context.get("create_time", "")) + checkpoint = ContextManager().get_checkpoint(task_info.task_id) or context + error_reason = str(e) if isinstance(e, JiuWenBaseException) else "other reason" + checkpoint["error_msg"] = f"Joint optimization task failed, reason: {error_reason}" + checkpoint["run_time"] = calculate_runtime(context.get("create_time", "")) + checkpoint[TaskStatus.TASK_STATUS] = TaskStatus.TASK_FAILED + ContextManager().set_checkpoint(task_info.task_id, checkpoint) + + def continue_optimize(self, task_id: str): + """continue optimization""" + context = ContextManager().get_checkpoint(task_id) + context["id"] = task_id + ContextManager().set(task_id, context) + run_time = calculate_runtime(context.get("create_time", "")) + ContextManager().set_context_attr(task_id, TaskStatus.TASK_STATUS, TaskStatus.TASK_RUNNING) + ContextManager().set_context_attr(task_id, "run_time", run_time) + raw_templates: List[str] = context.get("raw_templates", []) + optimize_info: OptimizeInfo = context.get("optimize_info", None) + opt_model_info: LLMModelInfo = context.get("opt_model_info", None) + infer_model_info: LLMModelInfo = 
context.get("infer_model_info", None) + is_loaded = self.load_state(context) + if optimize_info is None or not is_loaded: + raise JiuWenBaseException(StatusCode.PROMPT_OPTIMIZE_RESTART_TASK_ERROR.code, + StatusCode.PROMPT_OPTIMIZE_RESTART_TASK_ERROR.errmsg.format( + error_msg="load context failed" + )) + if self.cur_iteration == 0: + return self.do_optimize( + TaskInfo(task_id=task_id, task_name=context.get("name"), + task_description=context.get("desc", ""), create_time=context.get("create_time", "")), + raw_templates, optimize_info, opt_model_info, infer_model_info + ) + try: + self._opt_model = LLMModelProcess(opt_model_info) + self._infer_model = LLMModelProcess(infer_model_info) + result_compare_template = self.prompt_pool["result_compare_template"].replace( + "{user_compare_rules}", optimize_info.user_compare_rules + ) + self.evaluator = JointEvaluatorWithRef(opt_model_info, infer_model_info, + optimize_info, result_compare_template) + + self.dataset = CaseManager.validate_with_convert([case.model_dump() for case in optimize_info.cases], + CaseManager.default_convertor, + default_tools=optimize_info.tools) + self.variable = self.get_variable_from_dataset(self.dataset, raw_templates[0]) + logger.info(f"Prompt optimization task {task_id} restarted.") + self.optimize_prompt_iteratively(context, self.cur_iteration) + except OnStopException: + ContextManager().set_context_attr(task_id, TaskStatus.TASK_STATUS, TaskStatus.TASK_STOPPED) + logger.info(f"Joint optimization task {task_id} stopped.") + return + except Exception as e: + context[TaskStatus.TASK_STATUS] = TaskStatus.TASK_FAILED + context["run_time"] = calculate_runtime(context.get("create_time", "")) + checkpoint = ContextManager().get_checkpoint(task_id) or context + error_reason = str(e) if isinstance(e, JiuWenBaseException) else "other reason" + checkpoint["error_msg"] = f"Joint optimization task failed, reason: {error_reason}" + checkpoint["run_time"] = calculate_runtime(context.get("create_time", "")) + checkpoint[TaskStatus.TASK_STATUS] = TaskStatus.TASK_FAILED + ContextManager().set_checkpoint(task_id, checkpoint) + + def load_state(self, context: Context): + """load task state""" + if "params" not in context: + return False + self.params: JointParameters = context.get("params") + self.cur_iteration = context.get("cur_iteration", 0) + self.best_accuracy = self.params.get("best_accuracy", 0.0) + self.params.filled_instructions = self.fill_prompt(self.params.base_instructions, self.params.placeholders) + self.sampled_incorrect_data = context.get("sampled_incorrect_data", []) + return True + + def save_state(self, context: Context, history: Optional[History] = None, force_save: bool = False): + """save task state""" + if not force_save and not self._check_stop_event(context): + return + + example_string = self._get_example_string(self.params.examples) + cot_example_string = self._get_example_string(self.params.cot_examples) + self.params.full_prompt = self.prompt_combine(self.params.filled_instructions, + example_string, cot_example_string) + context["cur_iteration"] = self.cur_iteration + context["params"] = self.params + context["best_accuracy"] = self.best_accuracy + context["sampled_incorrect_data"] = self.sampled_incorrect_data + context["run_time"] = calculate_runtime(context.get("create_time", "")) + if history: + original_placeholder = {} + if self.params.original_placeholders: + for ph in self.params.original_placeholders: + original_placeholder[ph.get(TuneConstant.NAME_KEY)] = ph.get(TuneConstant.MESSAGE_CONTENT_KEY) 
+                history.original_placeholder = original_placeholder
+            context.get("history", []).append(history)
+        task_id = context.get("id")
+        ContextManager().set_checkpoint(task_id, context)
+
+    def optimize_prompt_iteratively(self, context: Context, begin_iteration: int):
+        """optimize prompt iteratively"""
+        logger.info("Optimizing prompt instruction and examples iteratively...")
+        history = None
+        need_optimize_example = self.params.num_examples > 0 or self.params.num_cot_examples > 0
+        for iteration in range(begin_iteration, self.params.num_iterations):
+            self.cur_iteration = iteration + 1
+            is_optimize_instruction = random.choice([True, False]) if need_optimize_example else True
+            logger.info(f"Task-{context.get('id')} at iteration {iteration} / {self.params.num_iterations}, "
+                        f"start optimizing {'instruction' if is_optimize_instruction else 'examples'}")
+            if is_optimize_instruction:
+                history = self._optimize_instruction(context)
+            else:
+                history = self._optimize_examples(context)
+
+            if iteration < self.params.num_iterations - 1:
+                self.save_state(context, history)
+        context[TaskStatus.TASK_STATUS] = TaskStatus.TASK_FINISHED
+        context["run_time"] = calculate_runtime(context.get("create_time", ""))
+        self.save_state(context, history)
+        logger.info(f"Joint optimization task-{context.get('id')} finished.")
+
+    def _optimize_instruction(self, context: Context) -> Optional[History]:
+        """optimize instruction"""
+        optimize_info: OptimizeInfo = context.get("optimize_info", None)
+        tools = optimize_info.tools
+        optimized_instruction, optimized_placeholder = self.optimize_instruction_by_gradient(tools)
+        if not optimized_instruction:
+            raise JiuWenBaseException(
+                StatusCode.PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR.code,
+                StatusCode.PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR.errmsg.format(
+                    error_msg="optimize instruction failed, got an empty result"
+                )
+            )
+        full_prompt = self._get_full_prompt(optimized_instruction, optimized_placeholder)
+        accuracy, error_cases = self.evaluate(full_prompt, context)
+        if accuracy is None:
+            raise JiuWenBaseException(
+                StatusCode.PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR.code,
+                StatusCode.PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR.errmsg.format(
+                    error_msg="failed to get evaluation result"
+                )
+            )
+        if accuracy > self.best_accuracy:
+            self.sampled_incorrect_data = error_cases
+            self.best_accuracy = accuracy
+            self.params.base_instructions = optimized_instruction
+            self.update_placeholder(optimized_placeholder, self.params.placeholders)
+            self.params.filled_instructions = self.fill_prompt(optimized_instruction, optimized_placeholder)
+            self.params.full_prompt = full_prompt
+        opt_placeholder_dict = {}
+        if optimized_placeholder and self.params.opt_placeholder_names:
+            for p in optimized_placeholder:
+                placeholder_name = p.get(TuneConstant.NAME_KEY)
+                if placeholder_name in self.params.opt_placeholder_names:
+                    opt_placeholder_dict[placeholder_name] = p.get(TuneConstant.MESSAGE_CONTENT_KEY, "")
+        return History(optimized_prompt=optimized_instruction,
+                       optimized_placeholder=opt_placeholder_dict,
+                       examples=self._get_examples_string_list(self.params.examples) +
+                                self._get_examples_string_list(self.params.cot_examples),
+                       filled_prompt=full_prompt,
+                       success_rate=accuracy,
+                       iteration_round=self.cur_iteration)
+
+    def _optimize_examples(self, context: Context) -> Optional[History]:
+        """optimize examples"""
+        if self.params.num_examples == 0 and self.params.num_cot_examples == 0:
+            return None
+        selected_examples = []
+        if self.params.num_examples > 0:
+            for _ in range(TuneConstant.DEFAULT_LLM_CALL_RETRY_NUM):
+                selected_examples = self.select_best_examples(context)
+                if selected_examples:
+                    break
+        cot_examples = self.generate_best_reasoning_examples(context)
+        examples_string = self._get_example_string(selected_examples)
+        cot_examples_string = self._get_example_string(cot_examples)
+
+        full_prompt = self.prompt_combine(self.params.filled_instructions, examples_string, cot_examples_string)
+        accuracy, error_cases = self.evaluate(full_prompt, context)
+
+        if accuracy is None:
+            raise JiuWenBaseException(
+                StatusCode.PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR.code,
+                StatusCode.PROMPT_OPTIMIZE_REFINE_INSTRUCTION_ERROR.errmsg.format(
+                    error_msg="failed to get evaluation result"
+                )
+            )
+
+        if accuracy > self.best_accuracy:
+            self.sampled_incorrect_data = error_cases
+            self.best_accuracy = accuracy
+            self.params.examples = selected_examples or []
+            self.params.cot_examples = cot_examples or []
+            self.params.full_prompt = full_prompt
+        return History(optimized_prompt=self.params.base_instructions,
+                       optimized_placeholder=placeholder_to_dict(self.params.placeholders),
+                       examples=self._get_examples_string_list(self.params.examples) +
+                                self._get_examples_string_list(self.params.cot_examples),
+                       filled_prompt=full_prompt,
+                       success_rate=accuracy,
+                       iteration_round=self.cur_iteration)
+
+    def _get_example_string(self, examples: List):
+        """get example string"""
+        if not examples:
+            return ""
+        return "\n".join(self._get_examples_string_list(examples))
+
+    def _get_examples_string_list(self, examples: List) -> List[str]:
+        """get example string list"""
+        if not examples:
+            return []
+        example_string_template = self.prompt_pool["quest_reason_ans"]
+        example_string_list = []
+        for i, example in enumerate(examples):
+            formatted_example_string = example_string_template.format(
+                question=get_example_question(example),
+                answer=example.get(TuneConstant.LABEL_KEY, "")
+            )
+            example_string_list.append(f"示例{i + 1}:\n{formatted_example_string}")
+        return example_string_list
+
+    def _get_full_prompt(self, instruction, placeholders):
+        """get full prompt"""
+        filled_instruction = self.fill_prompt(instruction, placeholders)
+        example_string = self._get_example_string(self.params.examples)
+        cot_example_string = self._get_example_string(self.params.cot_examples)
+        full_prompt = self.prompt_combine(filled_instruction, example_string, cot_example_string)
+        return full_prompt
\ No newline at end of file
diff --git a/jiuwen/agent_builder/prompt_builder/tune/joint_prompt_pool.yaml b/jiuwen/agent_builder/prompt_builder/tune/joint_prompt_pool.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ab75ff735aec4fd6d454f572a4bab1c9b5d0fa7c
--- /dev/null
+++ b/jiuwen/agent_builder/prompt_builder/tune/joint_prompt_pool.yaml
@@ -0,0 +1,238 @@
+quest_reason_ans: |
+  [query]: {question}
+  [assistant answer]: {answer}
+  ===
+
+gt_quest_reason_ans: |
+  [query]: {question}
+  [assistant answer]: {answer}
+
+quest_reason_ans_error: |
+  [query]: {question}
+  [expected answer]: {answer}
+  [assistant answer]: {predict}
+  ===
+
+prompt_critique_template: |
+  作为提示词优化专家,我的目标是帮助代理高效且成功地完成任务
+  当前的提示词是:“{instruction}”
+  然而,这个提示词在以下实例中并未能给出正确的结果:“{examples}”
+  请提供详细的反馈,分析指令可能出错的原因。
+  针对每个实例,具体说明指令中的问题,解释代理为何会误解指令,并提出如何让指令更加清晰和精确的建议
+  每个反馈信息请用包裹
+
+optimize_prompt_instruction_template: |
+  你是一位提示词优化专家,你的任务是根据提供的信息对提示词进行优化。具体信息如下:
+  首先,请阅读以下提示词:
+
+  {{PROMPT_META_TEMPLATE}}
+
+
+  你拥有的工具或API说明如下:
+
+  {{API_TOOLS_DESCRIPTION}}
+
+
+  提示词在应用的过程中出现的错误case如下:
+
+  {{ERROR_CASES}}
+
+
+  对这些错误case的反思如下:
+
+  {{REFLECTIONS_ON_ERROR_CASES}}
+
+
+  在优化提示词模板时,请遵循如下要求:
+  1. 在`<思考>`标签中,请根据错误示例及其对应的反思内容,深入、全面地分析提示词中可能导致错误的部分。分析应覆盖:错误原因的识别、原始提示词中存在的问题,以及通过哪些具体修改可以有效规避这些问题。
+  2. 在``标签中,基于上述分析,输出优化后的提示词版本。
+  3. 分析过程中应聚焦于问题的具体成因,结合模板结构、语意表达和格式规范等方面,系统性地进行优化。
+  4. 优化过程中务必信息表达完整、逻辑严谨,不可遗漏重要内容或引入模糊表达
+
+  输出格式:
+  <思考>
+  [在此详细说明你对提示词的优化分析]
+
+
+  [在此输出优化后的提示词]
+
+  请确保优化后的内容能够有效避免之前出现的错误case。
+
+optimize_prompt_instruction_template_with_placeholder: |
+  你是一位提示词优化专家,你的任务是根据提供的信息对提示词进行优化。具体信息如下:
+  你拥有的工具或API说明如下:
+
+  {{API_TOOLS_DESCRIPTION}}
+
+
+  以下是提示词中出现的占位符名称以及对应内容:
+
+  {{PLACEHOLDER_CONTENTS}}
+
+
+  提示词在应用的过程中出现的错误case如下:
+
+  {{ERROR_CASES}}
+
+
+  以下是对这些错误case的反思:
+
+  {{REFLECTIONS_ON_ERROR_CASES}}
+
+
+  以下是待优化的提示词模板,注意模板内容边界从开始,到结束:
+
+  {{PROMPT_META_TEMPLATE}}
+
+
+  在优化提示词时,请遵循以下要求:
+  1. 在`<思考>`标签中,请根据错误示例及其对应的反思内容,深入、全面地分析提示词中可能导致错误的部分。分析应覆盖:错误原因的识别、原始提示词中存在的问题,以及通过哪些具体修改可以有效规避这些问题。
+  2. 在``标签中,基于上述分析,输出优化后的提示词版本。
+  3. 分析过程中应聚焦于问题的具体成因,结合模板结构、语意表达和格式规范等方面,系统性地进行优化。
+  4. 优化过程中务必信息表达完整、逻辑严谨,不可遗漏重要内容或引入模糊表达
+  5. 【重要】优化后的模板不能丢弃任何原有的占位符信息,务必保留占位符的双花括号
+
+  输出格式:
+
+  [在此输出优化后的提示词]
+
+  请确保优化后的内容能够有效避免之前出现的错误case。
+
+optimize_prompt_placeholder_template: |
+  你是一位提示词优化专家,你的任务是根据提供的信息对指定占位符进行优化,优化后的占位符内容应详细具体,包含参数描述、注意事项等
+
+  优化要求如下:
+  严格根据需要优化的占位符列表进行优化。不需要优化的占位符则保持不变。
+  优化结果格式:每个占位符需要用包裹,xxx代表占位符的名称或key
+
+  以下是带占位符的提示词模板,注意模板内容边界从开始,到结束:
+
+  {{PROMPT_META_TEMPLATE}}
+
+
+  你拥有的工具或API说明如下:
+
+  {{API_TOOLS_DESCRIPTION}}
+
+
+  以下是提示词中出现的占位符名称以及对应内容:
+
+  {{PLACEHOLDER_CONTENTS}}
+
+
+  以下是需要优化的占位符:
+
+  {{PLACEHOLDER_TO_OPTIMIZE}}
+
+
+  提示词在应用的过程中出现的错误case如下:
+
+  {{ERROR_CASES}}
+
+
+  以下是对这些错误case的反思:
+
+  {{REFLECTIONS_ON_ERROR_CASES}}
+
+
+  请按照以下要求对占位符进行优化:
+  1. 仔细分析错误case和对其的反思、明确要改进的方向
+  2. 结合prompt模板和占位符对应内容,考虑占位符在整个提示词中的作用
+  3. 根据以上分析,对需要优化的占位符进行详细、具体的优化,包含参数描述,注意事项等。
+
+  输出格式:
+
+  [在此输出优化后的占位符及其对应内容,例如]
+
+  请确保优化后的内容能够有效避免之前出现的错误case。
+
+error_examples_critique_template: |
+  作为提示词优化专家,你的目标是帮助代理高效且成功地完成任务。
+  当前任务的描述:
+  [任务描述]:{task_description}
+  尝试通过使用{num_examples}个具备思维链的示例,编写一个有效的提示词,以解决上述任务中的任何问题。
+  我当前的{num_examples}个思维链示例是:{examples}
+  请从示例的多样性、复杂性以及与整个示例集的相关性、兼容性角度分析、理解以上任务示例。
+  但这个提示词在以下示例中得出了错误的结果:{error_examples}
+  请通过反思这些错误的原因,提出修正建议,以提高示例集的有效性。
+  输出所有可以改进的建议,帮助优化每个示例和整个示例选择集。
+
+examples_optimization_template: |
+  作为提示词优化专家,我的任务是帮助代理高效且成功地完成任务。
+  当前任务描述:
+  [任务描述]:{task_description}
+  我正在尝试通过使用{num_examples}个具备思维链的示例,编写一个有效的提示词,以解决上述任务中的任何问题。
+  我当前的{num_examples}个思维链示例是:{examples}
+  此外,我还有一组针对示例的建议/改进意见,以帮助每个示例和示例选择集。
+  [建议/改进]:{critique}
+  请根据上述信息,巧妙且认真地运用所有信息,精心创建一个新的示例集,要求示例数量严格为{num_examples}个,确保遵循这些建议和改进。
+  请确保每个示例都用包裹。
+
+  新的示例应遵循以下格式:
+  [query]:紧随其后的为示例中的问题部分
+  [assistant answer]: 紧随其后的为与答案相关的所有逻辑推理步骤,包含最终答案。
+
+  请确保每个示例都用包裹。
+  [以下是新生成的{num_examples}个示例]
+
+examples_select_template: |
+  作为提示词优化专家,我的任务是帮助代理高效且成功地完成任务。
+  当前任务描述:
+  [任务描述]:{task_description}
+  请从以下做错的数据中选择最具代表性的{num_examples}个示例,以解决上述任务中的任何问题。
+  当前的错误示例集是:{error_examples}
+  请确保每个示例都用包裹。
+
+  例如:
+  {gt_example}
+
+  [选择示例]
+
+get_example_reasoning_template: |
+  你将获得一个任务描述和指令,后面跟着一组任务的正确示例。
+  [任务描述]:{task_description}
+  [指令]:{instruction}
+  示例包含一个问题,标记为问题[问题]和最终答案[答案]
+  [问题]:{question}
+  [答案]:{answer}
+  请根据任务描述和示例,生成一个简洁明了的思维链。
+  思维链应包含关键步骤,展示出正确答案的核心逻辑,最后给出最终结果。
+  确保每个步骤简洁且易于理解,避免过多细节。
+
+  思维链:
+
+get_task_description: |
+  [提示词]:{instruction}
+  请简明扼要地总结以上提示词的任务,精炼其核心目标和要点,确保信息简洁、清晰。
+
+get_answer_format: |
+  [提示词]:{instruction}
+  根据给定提示词,用一句话总结用户要求的输出格式,不包含“用户要求的输出格式”之类的前缀。请简洁明了地表述。
+
+result_compare_template: |
+  你是一个答案校验专家,负责校验给定的用户答案和标准答案之间的含义和结论一致性。请根据以下标准判断用户答案是否与标准答案的含义和结论一致。
+
+  - 如果用户答案和标准答案含义一致,返回`true`。
+  - 如果用户答案和标准答案含义不一致,返回`false`。
+  - 注意区分对话和工具调用,两者通常不能按语意判断为一致
+  - 结合用户问题和标准答案,简要分析用户答案和标准答案不一致的理由
+
+  以下是用户补充的自定义校验规则,如果与上述规则冲突,则优先遵从用户自定义规则,请严格遵守:
+  {user_compare_rules}
+
+  输出JSON格式:
+  ```json
+  {{
+    "result": true/false,
+    "reason": "校验理由"
+  }}
+  ```
+
+  [问题]:{question}
+
+  以下是需要比对的用户答案和标准答案:
+  [标准答案]:{actual_answer}
+
+  [用户答案]:{answer_match}
+
+  请校验并返回结果:
\ No newline at end of file
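
The `result_compare_template` above asks the judge model to reply with a JSON object of the form `{"result": true/false, "reason": "..."}`, typically wrapped in a ```json fence. A minimal caller-side sketch for extracting that verdict is shown below; the helper name `parse_compare_verdict` and its fallback behaviour are illustrative assumptions and are not part of this changeset.

```python
import json
import re
from typing import Tuple


def parse_compare_verdict(model_reply: str) -> Tuple[bool, str]:
    """Extract the {"result": ..., "reason": ...} verdict from a judge-model reply.

    Hypothetical helper: tolerates an optional ```json fence around the object
    and falls back to an "inconsistent" verdict when no valid JSON is found.
    """
    # Locate the first-to-last brace span so a surrounding code fence is ignored.
    match = re.search(r"\{.*\}", model_reply, re.DOTALL)
    if not match:
        return False, "no JSON object found in model reply"
    try:
        verdict = json.loads(match.group(0))
    except json.JSONDecodeError as err:
        return False, f"invalid JSON verdict: {err}"
    return bool(verdict.get("result", False)), str(verdict.get("reason", ""))


if __name__ == "__main__":
    reply = '```json\n{"result": true, "reason": "含义一致"}\n```'
    print(parse_compare_verdict(reply))  # (True, '含义一致')
```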