diff --git a/pyporter/utils.py b/pyporter/utils.py index 1c8e3d4a833836b30f49f8ac7ab0671160cd7f93..932e2974b19636da23512049097b5b3b06a482fb 100644 --- a/pyporter/utils.py +++ b/pyporter/utils.py @@ -3,17 +3,22 @@ import re # TODO: this should be more compatible for https://peps.python.org/pep-0508/ def transform_module_name(input_str): - module_name = re.match(r"([a-zA-Z0-9_-]+)", input_str).group(1).strip() - version_names = input_str[len(module_name):].strip().strip("()") - version_constraint = version_names.split(",") - package_name = "python3-" + module_name - if len(version_constraint) > 1: - constraints_string = " with ".join([ - f"{package_name}{constraint}" for constraint in version_constraint - ]) - result_string = f"({constraints_string})" + match = re.match(r"([a-zA-Z0-9_-]+)", input_str) + if match: + module_name = match.group(1).strip() + + version_names = input_str[len(module_name):].strip().strip("()") + version_constraint = version_names.split(",") + package_name = "python3-" + module_name + if len(version_constraint) > 1: + constraints_string = " with ".join([ + f"{package_name}{constraint}" for constraint in version_constraint + ]) + result_string = f"({constraints_string})" + else: + result_string = f"({package_name}{version_constraint[0]})" else: - result_string = f"({package_name}{version_constraint[0]})" + result_string = "Invalid input format" return result_string diff --git a/tests/test_transform_module_name.py b/tests/test_transform_module_name.py index c292b7d90fadf27593d879e9b4ef4a172ba3a33a..de53d7bb0a499e883420b874c9906757de7d08c0 100644 --- a/tests/test_transform_module_name.py +++ b/tests/test_transform_module_name.py @@ -16,9 +16,14 @@ class TestTransofrmModuleName(unittest.TestCase): def test_transform_module_name_strip_whitespace(self): input_str = "pyjwkest (>=1.3.6)" - expected_output = "(python3-pyjwkest(>=1.3.6))" + expected_output = "(python3-pyjwkest>=1.3.6)" self.assertEqual(transform_module_name(input_str), expected_output) + def test_invalid_input(self): + input_str = "!invalid_module123" + expected_output = "Invalid input format" + self.assertEqual(transform_module_name(input_str), expected_output) + if __name__ == '__main__': unittest.main()