From 84ef7846f041ba8cfe566d4347cd46f5b3247fc3 Mon Sep 17 00:00:00 2001
From: tianyiliu_9999
Date: Wed, 30 Apr 2025 17:40:32 +0800
Subject: [PATCH] add bisheng migration agent

---
 migration-agent/.gitignore                   |  17 +
 migration-agent/README.md                    |  24 +
 migration-agent/requirements.txt             |   6 +
 migration-agent/setup_compiler_driver.sh     |  16 +
 migration-agent/src/compiler_driver.py       | 332 +++++++++++++++++++
 migration-agent/src/global_config.py         | 116 +++++++
 migration-agent/src/inference.py             | 192 +++++++++++
 migration-agent/src/prompt_engineering.py    |  60 ++++
 migration-agent/src/utilities/display.py     | 192 +++++++++++
 migration-agent/src/utilities/filemanager.py |  40 +++
 migration-agent/src/utilities/llms.py        |  57 ++++
 migration-agent/src/utilities/utilities.py   |  49 +++
 12 files changed, 1101 insertions(+)
 create mode 100644 migration-agent/.gitignore
 create mode 100644 migration-agent/README.md
 create mode 100644 migration-agent/requirements.txt
 create mode 100644 migration-agent/setup_compiler_driver.sh
 create mode 100644 migration-agent/src/compiler_driver.py
 create mode 100644 migration-agent/src/global_config.py
 create mode 100644 migration-agent/src/inference.py
 create mode 100644 migration-agent/src/prompt_engineering.py
 create mode 100644 migration-agent/src/utilities/display.py
 create mode 100644 migration-agent/src/utilities/filemanager.py
 create mode 100644 migration-agent/src/utilities/llms.py
 create mode 100644 migration-agent/src/utilities/utilities.py

diff --git a/migration-agent/.gitignore b/migration-agent/.gitignore
new file mode 100644
index 000000000000..10b8bdb4e545
--- /dev/null
+++ b/migration-agent/.gitignore
@@ -0,0 +1,17 @@
+.git/
+__pycache__/
+*.vscode/
+
+*.out
+*.o
+*.log
+test
+tmp*/
+*_llmRepaired.*
+
+build/
+dist/
+bishengai.spec
+
+
+
diff --git a/migration-agent/README.md b/migration-agent/README.md
new file mode 100644
index 000000000000..5c6719a9dd91
--- /dev/null
+++ b/migration-agent/README.md
@@ -0,0 +1,24 @@
+# llm4codemigration
+
+LLM for code migration project
+
+# Steps to use compiler_driver.py
+## 1. Build the executable binary by running
+```source setup_compiler_driver.sh```
+
+## 2. Evaluate on the test cases
+Go to test_examples,
+step into the folder of each test case,
+then run
+```make```
+
+Note: make sure you rebuild the executable binary every time you change any of the Python scripts.
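+
+## Example environment configuration
+A minimal sketch of the variables the driver reads; the values are illustrative, see setup_compiler_driver.sh for the authoritative list:
+```sh
+export CC=$(pwd)/dist/bishengai    # use the built wrapper as the C compiler
+export CXX=$CC                     # and as the C++ compiler
+export LLM_API_TOKEN=<your token>  # required for querying the LLM service
+export AUTO_ACCEPT=1               # apply suggested code changes without the interactive dialog
+```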
diff --git a/migration-agent/requirements.txt b/migration-agent/requirements.txt
new file mode 100644
index 000000000000..177eb8c2e20d
--- /dev/null
+++ b/migration-agent/requirements.txt
@@ -0,0 +1,6 @@
+clang
+colorama
+Pygments
+Requests
+termcolor
+demjson3
diff --git a/migration-agent/setup_compiler_driver.sh b/migration-agent/setup_compiler_driver.sh
new file mode 100644
index 000000000000..07c9469ee343
--- /dev/null
+++ b/migration-agent/setup_compiler_driver.sh
@@ -0,0 +1,16 @@
+#!/bin/sh
+pip install pyinstaller
+rm -rf build/
+rm -rf dist/
+rm -f bishengai.spec
+TARGET="bishengai"
+SOURCE="src/compiler_driver.py"
+pyinstaller --onefile $SOURCE --name $TARGET
+export CC=`pwd`/dist/$TARGET
+export CXX=$CC
+export LLM_DEVELOPMENT=1
+# export COMPILER_CHOICE="clang++"  # "clang" for C; "clang++" for C++
+export LLM_DEBUG=1
+export AUTO_ACCEPT=1  # accept LLM-suggested code changes automatically instead of opening the interactive dialog
+export LLM_API_TOKEN=  # add your api token here
+
diff --git a/migration-agent/src/compiler_driver.py b/migration-agent/src/compiler_driver.py
new file mode 100644
index 000000000000..1c10dff73aff
--- /dev/null
+++ b/migration-agent/src/compiler_driver.py
@@ -0,0 +1,332 @@
+"""This defines the compiler driver routines"""
+
+import sys
+import os
+import subprocess
+import re
+import shutil
+import logging
+import tempfile
+from pathlib import Path
+import signal
+import traceback
+from inference import *
+from global_config import *
+from utilities.utilities import *
+from utilities.display import *
+from prompt_engineering import *
+from utilities.filemanager import *
+
+
+class CompilerDriver:
+    def __init__(self, CC="clang", CXX="clang++"):
+        self.argv = sys.argv
+        self.CC = CC
+        self.CXX = CXX
+        self.compiler = self.CXX
+        self.attempts = 0
+        self.max_attempts = get_llm_retry_times()
+        self.user_defined_compiler_choice = False
+        self.temp_dir = None
+        self.compiler_outputs = []
+        self.compiler_choice = os.getenv("NATIVE_COMPILER") or os.getenv(
+            "COMPILER_CHOICE"
+        )
+        self.initialize_compiler()
+        self.temp_dir = tempfile.mkdtemp()
+
+    def initialize_compiler(self):
+        """Initialize the compiler based on environment variables or command-line arguments."""
+        if self.compiler_choice:
+            self.compiler = self.compiler_choice
+            self.user_defined_compiler_choice = True
+        else:
+            self.set_compiler_from_args()
+
+        if not self.check_compiler(self.compiler):
+            sys.exit(-1)
+
+        logging.info(f"Using {self.compiler} as the native compiler.")
+
+    def set_compiler_from_args(self):
+        """Set the compiler based on file extensions in the arguments."""
+        cfiles = self.get_compile_target(self.argv[1:])
+        if cfiles:
+            for f in cfiles:
+                self.compiler = self.CC if f.endswith(".c") else self.CXX
+                break
+        else:
+            self.compiler = self.CXX
+
+    def check_compiler(self, compiler: str):
+        """Check if the compiler is accessible."""
+        try:
+            subprocess.run(
+                [compiler, "--version"],
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+            return True
+        except Exception as e:
+            logging.error(f"Failed to verify {compiler}. 
Is it in $PATH?") + print(e) + return False + + def get_compile_target(self, command): + """Retrieve the compile target from the command.""" + return [e for e in command if re.match(r".*\.(c|cpp)$", e)] + + +class CompilerErrorHandler: + def __init__(self, max_attempts, compiler, temp_dir): + self.max_attempts = max_attempts + self.compiler = compiler + self.attempts = 0 + self.temp_dir = temp_dir + self.compiler_outputs = [] + self.return_code = None + self.source_paths = {} + self.backup_path = None + self.new_folders = [] + + def compile_and_repair(self, compile_targets_options): + """Attempt to compile and repair errors.""" + if self.compiler not in compile_targets_options: + command = [self.compiler] + compile_targets_options + else: + command = compile_targets_options + + compiler_output, success = self.compile(command) + if success: + return + + self.compiler_outputs.append(compiler_output) + first_error = compiler_output + prompt_template = None + + while self.attempts < self.max_attempts: + self.attempts += 1 + command, repair_prompt_template = self.repair( + command, compiler_output, prompt_template + ) + if command: + compiler_output, success = self.compile(command) + self.compiler_outputs.append(compiler_output) + if success: + logging.info("Successfully fixed the error.") + break + + pe = PromptEngine(repair_prompt_template) + prompt_template = pe.update_template( + previous_compile_log=first_error, + current_compile_log=compiler_output, + ) + else: + break + else: + logging.info(f"Exceeded max attempts ({self.max_attempts}).") + if self.backup_path: + FileManager.restore_files( + source_paths=self.source_paths, backup_path=self.backup_path + ) + + def compile(self, command): + """Call the compiler and handle errors.""" + try: + result = subprocess.run(command, capture_output=True, text=True, check=True) + print(result.stdout) + return result.stdout, True + except Exception as e: + print(e.stderr, file=sys.stderr) + if not self.return_code: + self.return_code = e.returncode + return e.stderr, False + + def repair(self, command, error, prompt_template=None): + """Repair compilation errors using suggestions from LLM.""" + compile_files = [e for e in command if e.endswith((".c", ".cpp"))] + if not compile_files: + logging.error("No valid source files for compilation.") + return None, None + + target_code = {file: Path(file).read_text() for file in compile_files} + code = self.check_relevant_files(error, target_code) + + # Repair using LLM + lr = LLMRepair() + llm_suggestions = lr.query_llm_for_fix( + code, " ".join(command), error, prompt_template, get_model_id() + ) + + if not llm_suggestions: + logging.debug("Failed to get LLM response.") + return command, lr.get_repair_prompt_template() + + return ( + self.apply_llm_suggestions(command, llm_suggestions), + lr.get_repair_prompt_template(), + ) + + def apply_llm_suggestions(self, command, llm_suggestions): + """Apply LLM suggestions for compilation options and code changes.""" + if llm_suggestions.get("compiler_options"): + command = self.repair_via_compiler_options(llm_suggestions) + + reason = llm_suggestions.get("reasoning", "NA") + if llm_suggestions.get("code"): + for target, code_content in llm_suggestions["code"].items(): + self.repair_via_code_alternation(target, code_content, reason=reason) + + return command + + def repair_via_compiler_options(self, llm_suggestions): + """Apply LLM's suggested compiler options.""" + compiler_options = llm_suggestions["compiler_options"].split(" ") + return ( + [self.compiler] + 
compiler_options + if self.compiler not in compiler_options + else compiler_options + ) + + def repair_via_code_alternation(self, target_file, repaired_code, reason): + """Apply code changes suggested by LLM.""" + + def backup(file_name): + # Set up the backup path + if not self.backup_path: + self.backup_path = self.temp_dir + backup_file_path = os.path.join(self.backup_path, file_name) + + if os.path.exists(target_file) and not os.path.exists(backup_file_path): + self.source_paths[file_name] = target_file + FileManager.backup_file(source=target_file, dest=backup_file_path) + else: + logging.debug(f"A backup of {target_file} already exists.") + + file_name = Path(target_file).name + if file_name == "": + logging.debug(f"Failed to extract the file name from {target_file}") + + source_code = ( + Path(target_file).read_text() if os.path.exists(target_file) else "" + ) + + try: + if not is_auto_accept_code_change(): + # Reasons may be presented as a list + if isinstance(reason, list): + reason = " ".join(reason) + + ret = code_dialog( + old_code=source_code, + new_code=repaired_code, + reason=reason, + cfile=target_file, + ) + else: + ret = True # Auto-accept code changes + + if ret: + backup(file_name) + if not is_auto_accept_code_change(): + logging.info("User accepted the suggested code changes.") + + # Just in case LLM suggest a new file + self.new_folders = ( + self.new_folders + FileManager.create_folders_for_path(target_file) + ) + if ( + not os.path.exists(target_file) + and target_file not in self.new_folders + ): + self.new_folders.append(target_file) + + logging.debug(f"Applying LLM-suggested code changes to {target_file}") + + with open(target_file, "w") as file: + file.write(repaired_code) + return True + else: + logging.info("User rejected the LLM suggested code changes.") + self.clean_up() + sys.exit(self.return_code) + except Exception as e: + logging.error(f"Failed to repair {target_file} via code alternation: {e}") + traceback.print_exc() + self.clean_up() + return False + + def apply_code_changes(self, target_file, repaired_code): + """Write repaired code back to the file.""" + logging.debug(f"Applying LLM-suggested code changes to {target_file}") + Path(target_file).write_text(repaired_code) + + def check_relevant_files(self, compile_error, code): + """Check for relevant files based on the compile error.""" + error_lines = compile_error.splitlines() + for error in error_lines: + if ":" in error: + file_name = error.split(":", 1)[0] + if file_name.endswith((".cpp", ".c", ".h")) and file_name not in code: + if os.path.exists(file_name): + code[file_name] = Path(file_name).read_text() + logging.debug(f"Also passing {file_name} to the compiler") + return code + + def clean_up(self): + """Clean up temporary files and directories.""" + if self.temp_dir: + if is_development_mode() and self.backup_path: + logging.debug("Restoring buggy file in development mode.") + FileManager.restore_files( + source_paths=self.source_paths, backup_path=self.backup_path + ) + + if not is_development_mode(): + try: + os.remove(LOG_FILE) + except: + pass + + # Remove temporary folders + try: + shutil.rmtree(self.temp_dir) + for folder in self.new_folders: + shutil.rmtree(folder) + except: + pass + + +def main(): + compiler_driver = CompilerDriver() + error_handler = CompilerErrorHandler( + compiler_driver.max_attempts, compiler_driver.compiler, compiler_driver.temp_dir + ) + signal.signal( + signal.SIGINT, lambda sig, frame: exit_gracefully(error_handler=error_handler) + ) + + try: + 
error_handler.compile_and_repair(sys.argv[1:]) + except Exception as e: + logging.error(f"Error occurred: {e}") + finally: + error_handler.clean_up() + if error_handler.return_code: + return error_handler.return_code + else: + return 0 + + +def exit_gracefully(error_handler): + logging.debug("\nExiting gracefully.") + error_handler.clean_up() + if error_handler.return_code: + sys.exit(error_handler.return_code) + else: + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/migration-agent/src/global_config.py b/migration-agent/src/global_config.py new file mode 100644 index 000000000000..53e029b5f274 --- /dev/null +++ b/migration-agent/src/global_config.py @@ -0,0 +1,116 @@ +""" +This defines the global configurations for the compiler driver +""" + +import logging +import os +from termcolor import colored + +# Global Constants +LOG_FILE = "llm4compiler.log" +llm_url = "https://api.siliconflow.cn/v1/chat/completions" + + +# Function to get the api token for siliconflow +def get_llm_api_token() -> str: + """Return the api token for LLM""" + return os.getenv("LLM_API_TOKEN", "") +llm_api_token = get_llm_api_token() + + +# Function to check if automatic acceptance of LLM code changes is enabled +def is_auto_accept_code_change() -> bool: + """Return True if the user opts to automatically accept LLM suggested code changes.""" + return bool(os.getenv("AUTO_ACCEPT", False)) + + +# Function to check if the compiler driver is running in development mode +def is_development_mode() -> bool: + """Return True if running in development mode.""" + return bool(os.getenv("LLM_DEVELOPMENT", True)) # Default to True if not set + + +# Return the LLM model ID, default to a preconfigured model +def get_model_id() -> str: + """Return the LLM model ID, with a fallback if not set in environment.""" + return os.getenv("LLM_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Llama-70B") + + +# Use a smaller model for simple tasks +def get_small_model_id() -> str: + """Return the LLM small model ID, with a fallback if not set in environment.""" + return os.getenv("SMALL_LLM_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B") + + +# Function to get the maximum retry times for LLM interactions +def get_llm_retry_times() -> int: + """Return the number of retry attempts for LLM. 
Default is 5, max 10.""" + try: + retry_times = int(os.getenv("LLM_RETRY_TIMES", 5)) + return min(max(retry_times, 1), 10) # Ensure retry times is between 1 and 10 + except ValueError: + return 5 + + +def try_small_llm_first(): + """Return true to try using a small LLM first for repairing""" + return False + + +def single_source_file(): + """Return true to try using a small LLM first for repairing""" + if os.getenv("SINGLE_SOURCE"): + return True + return False + + +# Set up the logging level based on environment variables +def set_logging_level() -> int: + """Determine the logging level based on environment variables.""" + if os.getenv("LLM_DEBUG"): + return logging.DEBUG + elif os.getenv("LLM_SILENT"): + return logging.ERROR + else: + return logging.INFO + + +# Custom logging formatter with color +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "magenta", + } + + def format(self, record): + levelname = record.levelname + message = super().format(record) + return colored(message, self.COLORS.get(levelname, "white")) + + +# Set up logging with colored output +def configure_logging(): + """Configure logging with colored output and file logging.""" + formatter = ColoredFormatter( + "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + + logger = logging.getLogger() + logger.setLevel(set_logging_level()) + + # Console handler for colored output + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # File handler for logging to file + file_handler = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8") + file_handler.setLevel(logging.DEBUG) # Log everything to the file + logger.addHandler(file_handler) + + +# Initialize logging +configure_logging() diff --git a/migration-agent/src/inference.py b/migration-agent/src/inference.py new file mode 100644 index 000000000000..5ddef41c8f1e --- /dev/null +++ b/migration-agent/src/inference.py @@ -0,0 +1,192 @@ +""" +This defines the routines for using LLMs for inference +""" + +from pathlib import Path +import json +import demjson3 +import logging +from json import JSONDecoder +from utilities.llms import LLM +from utilities.utilities import pretty_print_code, pretty_print_dialog +from global_config import * +import re + + +class LLMRepair: + + def __init__(self): + self.llm = LLM() + self.repair_prompt_template = ( + "Given the following LLVM/Clang compilation errors, compiler command, and source code, provide a solution to fix the error. " + "If this can be fixed by modifying the code, provide the revised code (as a whole) and reasoning in a JSON object. " + "If this can be fixed by changing the compiler options, place the new compiler command in the 'compiler_options' field along with reasoning. " + 'Organise the code change per file. A json output example could be json {"code_changes":{"path to the source file": "code"}, "compiler_options", "compiler_options", "reasoning", "reasoning"} ' + "Ensure that strings (especially multiline strings) are properly escaped, but do not just escape space. " + "Do not produce text other than the JSON object. Keep the reasoning text concise." 
+ ) + + self.template_code_suffix = "The code from [SOURCE_FILE] is: ``` [CODE] ```, " + self.repair_prompt_template_suffix = ( + "the compiler command is: ``` [COMMAND] ```, " + "the compilation error is: ``` [ERROR] ```" + ) + self.extract_prompt_template = ( + "I want you to act as an expert programmer to help me fix a compilation error. " + "Given the following compilation error message and content of the relevant file, determine if there is additional source file you " + "need to see to diagnose the issue." + "Respond in JSON format with two fields: " + "source_file: The name of the source file you need to see. If the issue can be " + "resolved without viewing any source file, return None. " + "reasoning: A brief explanation of why this file is needed or why no file is required." + "Do not show your reasoning process, only return the JSON response and ensure valid JSON escapes." + ) + self.extract_prompt_template_suffix = ( + "the compiler command is: ``` [COMMAND] ```, " + "the compilation error is: ``` [ERROR] ```, " + ) + + # Get prompt template + def get_repair_prompt_template(self): + return self.repair_prompt_template + + # Return the model id (name) used for inference + def get_model_id(self): + return self.llm.get_model_id() + + def extract_json_objects(self, text, decoder=JSONDecoder()): + """Extracts JSON objects from text.""" + pos = 0 + while True: + match = text.find("{", pos) + if match == -1: + break + try: + result, index = decoder.raw_decode(text[match:]) + yield result + pos = match + index + except ValueError: + pos = match + 1 + + def _extract_code_changes(self, code_changes): + """Extract and log code changes from LLM response.""" + code = {} + if code_changes: + for source, code_text in code_changes.items(): + code[source] = code_text + logging.debug( + pretty_print_code(f"LLM Generated Code for {source}", code_text) + ) + return code + + def _extract_compiler_options(self, data): + """Extract compiler options from LLM response.""" + compiler_options = data.get("compiler_options") + if compiler_options: + logging.debug( + pretty_print_dialog("LLM suggested options", compiler_options) + ) + return compiler_options + + def _extract_reasoning(self, data, reasoning_content): + """Extract reasoning from LLM response.""" + reasoning = data.get("reasoning") + return reasoning if reasoning else reasoning_content + + def _extract_source_target(self, data): + """Extract the source file required from LLM response.""" + return data.get("source_file") + + def process_response(self, message_content, reasoning_content): + # Remove redundant texts before ``json { + def remove_redundant_text(input_string): + # Use regular expression to remove everything before the `json{...}` part + cleaned_string = re.sub( + r".*```json\s*{", "```json {", input_string, flags=re.DOTALL + ) + return cleaned_string + + logging.debug(pretty_print_dialog("LLM inference raw data", message_content)) + # if reasoning_content: + # logging.debug(pretty_print_dialog("CoT raw data", reasoning_content)) + + message_content = remove_redundant_text(message_content) + message_content = message_content.replace("```json", "").replace("```", "") + try: + # logging.debug(f"RAW JSON DATA {message_content}") + data = demjson3.decode(message_content) + except Exception as e: + logging.debug(f"Failed to parse json data {message_content} {e}") + return {} + # data = next(self.extract_json_objects(message_content)) + + res = { + "code": self._extract_code_changes(data.get("code_changes")), + "compiler_options": 
self._extract_compiler_options(data), + "reasoning": self._extract_reasoning(data, reasoning_content), + "source_file": self._extract_source_target(data), + } + + return res + + def populate_prompt_template(self, template: str, command: str, error: str) -> str: + """Populate the prompt template with command and error.""" + return template.replace("[COMMAND]", command).replace("[ERROR]", error) + + def populate_code_to_prompt(self, prompt, code: dict): + if code: + for file, c in code.items(): + prompt += self.template_code_suffix.replace( + "[SOURCE_FILE]", file + ).replace("[CODE]", c) + return prompt + + def query_llm_for_fix( + self, code: dict, command, error, prompt_template=None, model_id=None + ): + """Query LLM for fix based on code, command, and error.""" + prompt = prompt_template or self.repair_prompt_template + prompt = self.populate_code_to_prompt(prompt=prompt, code=code) + prompt += self.repair_prompt_template_suffix + return self.inference( + prompt=self.populate_prompt_template(prompt, command, error), + model_id=model_id, + task_msg="to repair errors.", + ) + + def query_llm_for_source_file( + self, code: dict, command, error, prompt_template=None, model_id=None + ): + """Query LLM for the source file required to fix the issue.""" + prompt = prompt_template or self.extract_prompt_template + prompt += self.extract_prompt_template_suffix + prompt = self.populate_code_to_prompt(prompt=prompt, code=code) + return self.inference( + prompt=self.populate_prompt_template(prompt, command, error), + model_id=model_id, + task_msg="to locate relevant source files.", + ) + + def inference(self, prompt, model_id=None, task_msg=""): + """Perform inference by querying the LLM.""" + prev_model = self.llm.get_model_id() + if model_id: + self.llm.set_model_id(model_id) + else: + model_id = self.llm.get_model_id() + + logging.info(f"Calling {model_id} {task_msg}") + # logging.debug(f"prompt={prompt}") + response = self.llm.inference(prompt) + + if response: + data = json.loads(response.text) + message_content = data.get("choices")[0].get("message").get("content") + reasoning_content = ( + data.get("choices")[0].get("message").get("reasoning_content") + ) + self.llm.set_model_id(prev_model) + return self.process_response(message_content, reasoning_content) + + logging.error("Failed in getting the inference response") + self.llm.set_model_id(prev_model) diff --git a/migration-agent/src/prompt_engineering.py b/migration-agent/src/prompt_engineering.py new file mode 100644 index 000000000000..2a92e4272d0b --- /dev/null +++ b/migration-agent/src/prompt_engineering.py @@ -0,0 +1,60 @@ +""" +Prompt engineering utilities +""" + +import re + + +class PromptEngine: + def __init__(self, prompt_template): + self.prompt_template = prompt_template + self.redefined_identifiers = [] + self.missing_files = [] + + def _extract_identifier(self, text): + """Retrieve the identifier from the compilation message.""" + match = re.search(r"'(.*?)'", text) + return match.group(1) if match else None + + def _check_and_get_identifier(self, pattern, error_message): + """Check if a pattern exists in the error message and extract the identifier.""" + if pattern in error_message: + return self._extract_identifier(error_message) + return None + + def _check_for_redefinition(self, error_message): + """Check if the error message indicates a redefinition.""" + identifier = self._check_and_get_identifier("redefinition of", error_message) + if identifier: + self.redefined_identifiers.append(identifier) + + def 
_check_for_missing_file(self, error_message): + """Check if the error message indicates a missing file.""" + identifier = self._check_and_get_identifier("file not found", error_message) + if identifier: + self.missing_files.append(identifier) + + def _check_new_compilation_error(self, previous_log, current_error): + """Check if a new error is introduced by the LLM.""" + match = re.search(r"error:\s*(.*)", current_error, re.IGNORECASE) + if match: + error_message = match.group(1) + if error_message not in previous_log: + self._check_for_redefinition(error_message) + self._check_for_missing_file(error_message) + + def update_template(self, previous_compile_log, current_compile_log): + """Update the prompt template based on the current compilation log.""" + logs = current_compile_log.splitlines() + for line in logs: + self._check_new_compilation_error(previous_compile_log, line) + + prompt_suffix = "" + if self.redefined_identifiers: + redefine_str = ", ".join(self.redefined_identifiers) + prompt_suffix += f" Don't provide definitions for the following data structures or classes as they have already been defined elsewhere: ```{redefine_str}```" + if self.missing_files: + missing_files_str = ", ".join(self.missing_files) + prompt_suffix += f" Don't refer to the following files, as they don't exist: ```{missing_files_str}```" + + return self.prompt_template + prompt_suffix diff --git a/migration-agent/src/utilities/display.py b/migration-agent/src/utilities/display.py new file mode 100644 index 000000000000..b460b2ae629f --- /dev/null +++ b/migration-agent/src/utilities/display.py @@ -0,0 +1,192 @@ +""" +This defines uitility functions for display and user interaction +""" + +import curses +import difflib +import textwrap +import sys +from pygments import highlight +from pygments.lexers import CppLexer +from pygments.formatters import TerminalFormatter +from global_config import * + + +def diff_strings(old_code, new_code): + # Split the code into lines for better comparison + old_lines = old_code.splitlines() + new_lines = new_code.splitlines() + diff = difflib.unified_diff(old_lines, new_lines) + return list(diff) + + +def show_code_dialog(code_message, code, reason_message, reason): + """ + Displays a dialog with the suggested C/C++ code snippet, the reason for the change and asks the user for a Yes/No response + by typing 'y' for Yes or 'n' for No. The window adjusts its size based on the code length. + It also supports scrolling if the content exceeds the window height. + + :param code_message: The message to display in the first dialog. + :param code: The C/C++ code snippet to display in the dialog. + :param reason_message: The message to display in the second dialog. + :param reason: The reason for the change + :return: True or False based on user input. 
+ """ + + def show_code_win(curses, win_width, begin_y, begin_x, code_lines): + win_height = max(15, len(code_lines) + 6) + code_win = curses.newwin(win_height, win_width, begin_y, begin_x) + code_win.box() + # Create a scrollable window by enabling scrolling + code_win.scrollok(True) + + # Add the code message and print it with background color + code_win.attron(curses.color_pair(1)) # Apply background color pair + code_win.addstr(0, 2, code_message) + code_win.attroff(curses.color_pair(1)) # Turn off background color + + # Add diff conent to the code window + for i, line in enumerate(code_lines): + line = textwrap.fill(line, win_width - 2, placeholder="...") + y = 2 + i + x = 1 + if line.startswith("-"): # Lines removed from the original text + code_win.addstr( + y, x, line, curses.color_pair(3) + ) # Red for removed lines + elif line.startswith("+"): # Lines added in the modified text + code_win.addstr( + y, x, line, curses.color_pair(4) + ) # Red for removed lines + else: + code_win.addstr(y, x, line) # Normal text for unchanged lines + + code_win.refresh() + + return code_win, win_height + + def show_reason_win(curses, win_width, begin_y, begin_x, reason): + # Create the second window below the first one + reason_lines = reason.splitlines() + win_height = max(5, min(15, len(reason_lines) + 6)) + reason_win = curses.newwin(win_height, win_width, begin_y, begin_x) + reason_win.box() + + # Add the reasoning message and print it with background color + reason_win.attron(curses.color_pair(1)) # Apply background color pair + reason_win.addstr(0, 2, reason_message) + reason_win.attroff(curses.color_pair(1)) # Turn off background color + + # Print the second message inside the second window + for i, line in enumerate(reason_lines): + reason_win.addstr( + 1 + i, 2, line.rstrip() + ) # Print starting from row 2 to leave space for the box + + reason_win.refresh() + + return reason_win, win_height + + def show_action_win(curses, win_width, begin_y, begin_x, prompt): + # Create the third window below the reaon window + action_win_height = 5 + action_win = curses.newwin(action_win_height, win_width, begin_y, begin_x) + action_win.box() + + # Add the action message and print it with background color + action_win.attron(curses.color_pair(1)) # Apply background color pair + action_win.addstr(0, 2, "Action") + action_win.attroff(curses.color_pair(1)) # Turn off background color + + # Ask if the user wants to accept the change + action_win.attron( + curses.color_pair(3) | curses.A_BOLD + ) # Apply red color and bold to the prompt + action_win.addstr(2, 2, prompt) + action_win.attroff(curses.color_pair(3) | curses.A_BOLD) # Turn off red color + action_win.addstr(2, 2 + len(prompt) + 1, "_", curses.A_BLINK) + action_win.refresh() + + while True: + # Wait for user input + key = action_win.getch() + + # Handle Yes (y) or No (n) response + if key == ord("y"): # 'y' key for Yes + return True + elif key == ord("n"): # 'n' key for No + return False + + def init_curses(): + # Initialize curses color functionality + curses.start_color() + curses.init_pair( + 1, curses.COLOR_WHITE, curses.COLOR_BLUE + ) # White text on Blue background + curses.init_pair( + 2, curses.COLOR_YELLOW, curses.COLOR_BLACK + ) # Yellow text on Black background + curses.init_pair( + 3, curses.COLOR_RED, curses.COLOR_BLACK + ) # Red text on Black background for the prompt + curses.init_pair( + 4, curses.COLOR_GREEN, curses.COLOR_BLACK + ) # Green for added lines + return curses + + def show_dialog(stdscr): + curses = init_curses() + prompt 
= "Would you accept the LLM suggested code? Changes will be rolled back if we can't fix it (y/n):" + + # Clear screen + # stdscr.clear() + + # Get the height and width of the terminal + h, w = stdscr.getmaxyx() + + if isinstance(code, str): + code_lines = code.splitlines() + else: + code_lines = code + + # Truncate the code lines + if len(code_lines) > h - 2: + code_lines = code_lines[0 : h - 2] + + # Set up the window height and width + max_line_length = max(max(len(line) for line in code_lines), len(prompt)) + win_width = ( + max(50, max_line_length + 10) + 4 + ) # Ensure a minimum width of 50 columns + + # Wrap the code if needed + if win_width > w: + win_width = w + + begin_y = 4 + begin_x = (w - win_width) // 2 + _, code_win_height = show_code_win( + curses, win_width, begin_y=begin_y, begin_x=begin_x, code_lines=code_lines + ) + begin_y = begin_y + code_win_height + _, reason_win_height = show_reason_win( + curses, win_width, begin_y=begin_y, begin_x=begin_x, reason=reason + ) + begin_y = begin_y + reason_win_height + return show_action_win( + curses, win_width, begin_y=begin_y, begin_x=begin_x, prompt=prompt + ) + + return curses.wrapper(show_dialog) + + +# Show a diaglog to ask if the user want to accept the changes. +def code_dialog(old_code: str, new_code: str, reason: str, cfile: str): + code_diff = diff_strings(old_code=old_code, new_code=new_code) + if code_diff: + return show_code_dialog( + code_message="Suggested code changes for " + cfile, + code=code_diff, + reason_message="Reasons", + reason=reason, + ) diff --git a/migration-agent/src/utilities/filemanager.py b/migration-agent/src/utilities/filemanager.py new file mode 100644 index 000000000000..5d393b73980a --- /dev/null +++ b/migration-agent/src/utilities/filemanager.py @@ -0,0 +1,40 @@ +import logging +import shutil +import os + + +class FileManager: + @staticmethod + def backup_file(source, dest): + """Backup a file to the destination.""" + try: + logging.debug(f"Backing up {source} to {dest}") + shutil.copy(source, dest) + return True + except Exception as e: + logging.error(f"Failed to backup {source} to {dest}: {e}") + return False + + @staticmethod + def restore_files(source_paths, backup_path): + """Restore files from the backup.""" + try: + for file_name, path in source_paths.items(): + backup_file_path = os.path.join(backup_path, file_name) + shutil.copy(backup_file_path, path) + logging.debug(f"Restored {backup_file_path} to {path}") + return True + except Exception as e: + logging.error(f"Failed to restore files: {e}") + return False + + @staticmethod + def create_folders_for_path(file_path): + """Create necessary folders for a given file path.""" + dir_path = os.path.dirname(file_path) + new_folders = [] + if dir_path and not os.path.exists(dir_path): + logging.debug(f"Creating missing directories: {dir_path}") + os.makedirs(dir_path) + new_folders.append(dir_path) + return new_folders diff --git a/migration-agent/src/utilities/llms.py b/migration-agent/src/utilities/llms.py new file mode 100644 index 000000000000..0f512e26e448 --- /dev/null +++ b/migration-agent/src/utilities/llms.py @@ -0,0 +1,57 @@ +""" This defines LLM related utility functions """ + +import requests +import logging +from global_config import * + +# Suppress all warnings for cleaner output +import warnings + +warnings.filterwarnings("ignore") + + +class LLM: + def __init__(self): + """Initialize LLM instance with the model ID and API headers.""" + self.headers = { + "Authorization": self.get_api_token(), + "Content-Type": 
"application/json", + } + self.model_id = get_model_id() + + def get_api_token(self) -> str: + """Return the API token required for authorization.""" + return f"Bearer {llm_api_token}" + + def get_model_id(self): + """Return the current LLM model ID.""" + return self.model_id + + def set_model_id(self, model_id: str): + """Set the model ID for the LLM.""" + self.model_id = model_id + + def inference(self, prompt): + payload = { + "model": self.model_id, + "messages": [ + { + "role": "user", + "content": prompt, + } + ], + "temperature": 0.3, + } + + response = requests.request( + "POST", + llm_url, + json=payload, + headers=self.headers, + verify=False, + ) + # Successfully getting the inference response + if response.status_code == 200: + return response + else: + return None diff --git a/migration-agent/src/utilities/utilities.py b/migration-agent/src/utilities/utilities.py new file mode 100644 index 000000000000..2fdc39f3edc8 --- /dev/null +++ b/migration-agent/src/utilities/utilities.py @@ -0,0 +1,49 @@ +""" +This defines the utility functions for the compiler driver +""" + +import shutil +import pygments +from pygments import highlight +from pygments.lexers import CppLexer +from pygments.formatters import TerminalFormatter +from colorama import init, Back, Fore, Style +import textwrap + + + + +def pretty_print_dialog(title: str, message: str, auto_wrap=True) -> str: + # Get the current terminal size + columns, _ = shutil.get_terminal_size() + if auto_wrap: + message = textwrap.fill(message, width=columns - 6) + border = "+" * min( + max(len(message), len(title) + 6), columns + ) # Make border length based on the message length + dialog = f"\n{border}\n" + dialog += f"+ {title.center(len(title))} +\n" # Title centered + dialog += f"+ {message.center(len(message))} +\n" # Message centered + dialog += f"{border}" + return dialog + + +# Initialize colorama (needed for Windows) +init(autoreset=True) + + +def pretty_print_code(title, code): + columns, _ = shutil.get_terminal_size() + border = "+" * columns + highlighted_code = highlight(code, CppLexer(), TerminalFormatter()) + return ( + f"\n{border}" + + Back.GREEN + + Fore.BLACK + + Style.BRIGHT + + highlighted_code + + f"{border}\n" + ) + + + -- Gitee