From f014d9975a49f8b46fc46b06dcca98eb5fa60a50 Mon Sep 17 00:00:00 2001 From: lzq11122 Date: Sat, 28 Feb 2026 10:55:37 +0800 Subject: [PATCH] Add patch to fix CVE-2026-24747 --- 0002-add-patch-to-fix-CVE-2026-24747.patch | 132 +++++++++++++++++++++ pytorch.spec | 7 +- 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 0002-add-patch-to-fix-CVE-2026-24747.patch diff --git a/0002-add-patch-to-fix-CVE-2026-24747.patch b/0002-add-patch-to-fix-CVE-2026-24747.patch new file mode 100644 index 0000000..2690fb7 --- /dev/null +++ b/0002-add-patch-to-fix-CVE-2026-24747.patch @@ -0,0 +1,132 @@ +From 167ad09be5af5c52666759412a3804068c6955d1 Mon Sep 17 00:00:00 2001 +From: Filip +Date: Wed, 17 Sep 2025 18:17:20 +0000 +Subject: [PATCH] [optim] override SWALR.state_dict and load_state_dict + (#163122) + +Fixes #163105 + +Note that the new `SWALR.load_state_dict` is **not backwards compatible**: +```python +@override +def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + self._set_anneal_func(self._anneal_strategy) +``` + +If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines: +```python +@override +def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + anneal_func = state_dict.pop("anneal_func", None) + strategy = state_dict.get("_anneal_strategy") + self.__dict__.update(state_dict) + + if anneal_func is not None: + state_dict["anneal_func"] = anneal_func + if strategy is None: + if anneal_func == self._linear_anneal: + strategy = "linear" + elif anneal_func == self._cosine_anneal: + strategy = "cos" + + if strategy is None: + strategy = getattr(self, "_anneal_strategy", "cos") + + self._set_anneal_func(strategy) +``` + +But given the fact that loading an `SWALR` state_dict before this PR would have caused an error, this seems okay. A GitHub/Google search for `SWALR.load_state_dict` had no results. Happy to change if not, or add a warning just in case. +Pull Request resolved: https://github.com/pytorch/pytorch/pull/163122 +Approved by: https://github.com/janeyx99 +--- + test/optim/test_lrscheduler.py | 1 + + torch/optim/swa_utils.py | 37 +++++++++++++++++++++++++++++++++---- + 2 files changed, 34 insertions(+), 4 deletions(-) + +diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py +index a6e4481..b537cf0 100644 +--- a/test/optim/test_lrscheduler.py ++++ b/test/optim/test_lrscheduler.py +@@ -2389,6 +2389,7 @@ class TestLRScheduler(TestCase): + partial(CyclicLR, base_lr=0.01, max_lr=0.1), + partial(OneCycleLR, max_lr=0.01, total_steps=10, anneal_strategy="linear"), + partial(CosineAnnealingWarmRestarts, T_0=20), ++ partial(SWALR, swa_lr=0.01), + ], + ) + @parametrize("weights_only", [True, False]) +diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py +index f3a1fd2..be61a94 100644 +--- a/torch/optim/swa_utils.py ++++ b/torch/optim/swa_utils.py +@@ -7,6 +7,7 @@ import warnings + from collections.abc import Iterable + from copy import deepcopy + from typing import Any, Callable, Literal, Optional, Union ++from typing_extensions import override + + import torch + from torch import Tensor +@@ -428,10 +429,7 @@ class SWALR(LRScheduler): + "anneal_strategy must by one of 'cos' or 'linear', " + f"instead got {anneal_strategy}" + ) +- elif anneal_strategy == "cos": +- self.anneal_func = self._cosine_anneal +- elif anneal_strategy == "linear": +- self.anneal_func = self._linear_anneal ++ self._set_anneal_func(anneal_strategy) + if not isinstance(anneal_epochs, int) or anneal_epochs < 0: + raise ValueError( + f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}" +@@ -479,3 +477,34 @@ class SWALR(LRScheduler): + group["swa_lr"] * alpha + lr * (1 - alpha) + for group, lr in zip(self.optimizer.param_groups, prev_lrs) + ] ++ ++ def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]): ++ self._anneal_strategy = anneal_strategy ++ if anneal_strategy == "cos": ++ self.anneal_func = self._cosine_anneal ++ else: ++ self.anneal_func = self._linear_anneal ++ ++ @override ++ def state_dict(self) -> dict[str, Any]: ++ """Return the state of the scheduler as a :class:`dict`. ++ ++ It contains an entry for every variable in self.__dict__ which ++ is not the optimizer or anneal_func. ++ """ ++ return { ++ key: value ++ for key, value in self.__dict__.items() ++ if key not in ("optimizer", "anneal_func") ++ } ++ ++ @override ++ def load_state_dict(self, state_dict: dict[str, Any]) -> None: ++ """Load the scheduler's state. ++ ++ Args: ++ state_dict (dict): scheduler state. Should be an object returned ++ from a call to :meth:`state_dict`. ++ """ ++ self.__dict__.update(state_dict) ++ self._set_anneal_func(self._anneal_strategy) +-- +1.8.3.1 + diff --git a/pytorch.spec b/pytorch.spec index 22202ae..9ed4e36 100644 --- a/pytorch.spec +++ b/pytorch.spec @@ -1,4 +1,4 @@ -%define anolis_release 2 +%define anolis_release 3 %global vcu_maj 12 %global vcu_min 1 @@ -21,6 +21,8 @@ Source0: https://github.com/pytorch/pytorch/releases/download/v%{version} Patch0001: 0001-add-patch-to-fix-CVE-2025-2999.patch Patch0002: 1001-add-loongarch64-support-for-sleef.patch Patch0003: 1002-add-loongarch64-support-for-cpuinfo.patch +# https://github.com/pytorch/pytorch/commit/167ad09be5af5c52666759412a3804068c6955d1 +Patch0004: 0002-add-patch-to-fix-CVE-2026-24747.patch BuildRequires: python3-devel cmake gcc-c++ BuildRequires: python3-typing-extensions python3-pyyaml python3-setuptools @@ -109,6 +111,9 @@ end %{python3_sitearch}/torch/share %changelog +* Sat Feb 28 2026 lzq11122 - 2.8.0-3 +- Add patch to fix CVE-2026-24747 + * Tue Feb 24 2026 Wenlong Zhang - 2.8.0-2 - fix build error on loongarch64 -- Gitee