diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index ab8e3fd6a24b5b9d84f14446f6fb490c6c220066..df973bdaad8377e7a9ed9c432ed879b0a88e02f5 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -74,7 +74,7 @@ def to_cpu(data): elif isinstance(value, torch.Tensor): copy_data[key] = value.cpu() elif isinstance(value, nn.Module): - if torch_npu._C.is_npu(next(value.parameters())): + if torch_npu._C.is_npu(next(data.parameters())): setattr(value, "mark_npu", True) copy_data[key] = value.cpu() else: