diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index c0349a692f8368b7dda88c08b2a1e39015d14811..761608cef4b13326480352d830aed12812e0eacb 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -321,6 +321,7 @@ class FileCheckConst: CSV_SUFFIX: MAX_CSV_SIZE, YAML_SUFFIX: MAX_YAML_SIZE } + CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]' class OverflowConst: diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py index 16095bf2261f33b818f418c63b171bdc09cdb9e1..85c6e35d2d02ec0ea6ecd0f7fdcf310ab47928e2 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_utils.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -22,6 +22,7 @@ import re import shutil import yaml import numpy as np +import pandas as pd from msprobe.core.common.log import logger from msprobe.core.common.exceptions import FileCheckException @@ -445,7 +446,25 @@ def save_workbook(workbook, file_path): change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) -def write_csv(data, filepath, mode="a+"): +def write_csv(data, filepath, mode="a+", malicious_check=False): + def csv_value_is_valid(value: str) -> bool: + if not isinstance(value, str): + return True + try: + # -1.00 or +1.00 should be considered as digit numbers + float(value) + except ValueError: + # otherwise, they will be considered as formula injections + return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) + return True + + if malicious_check: + for row in data: + for cell in row: + if not csv_value_is_valid(cell): + raise RuntimeError(f"Malicious value [{cell}] is not allowed " \ + f"to be written into the csv: {filepath}.") + file_path = os.path.realpath(filepath) check_path_before_create(filepath) try: @@ -458,6 +477,16 @@ def write_csv(data, filepath, mode="a+"): change_mode(filepath, 
FileCheckConst.DATA_FILE_AUTHORITY) +def read_csv(filepath): + check_file_or_directory_path(filepath) + try: + csv_data = pd.read_csv(filepath) + except Exception as e: + logger.error(f"The csv file failed to load. Please check the path: {filepath}.") + raise RuntimeError(f"Read csv file {filepath} failed.") from e + return csv_data + + def remove_path(path): if not os.path.exists(path): return diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py index 0a7b89e88dc3d34b1fccb8af9559924144a8b1f9..5ccc9e5ae1e022540c089486fc8a98712154d6c0 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py @@ -2,12 +2,11 @@ import os from typing import List from tqdm import tqdm -import pandas as pd import matplotlib.pyplot as plt from msprobe.core.common.file_utils import create_directory, check_path_before_create, check_file_or_directory_path from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import remove_path, load_npy, write_csv +from msprobe.core.common.file_utils import remove_path, load_npy, write_csv, read_csv from msprobe.core.grad_probe.constant import GradConst from msprobe.core.grad_probe.utils import plt_savefig @@ -21,7 +20,7 @@ class GradComparator: continue if not os.path.exists(os.path.join(path2, summary_file)): continue - summary_csv = pd.read_csv(os.path.join(path1, summary_file)) + summary_csv = read_csv(os.path.join(path1, summary_file)) return summary_csv["param_name"] raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration")