From 73187a93711b1c7ce790e1f9e4c0c95d3f2939c9 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Mar 2025 10:27:30 +0800 Subject: [PATCH] [add]: diffusion api --- MindFlow/mindflow/cell/__init__.py | 6 +- MindFlow/mindflow/cell/diffusion.py | 924 ++++++++++++++++++ .../mindflow/cell/diffusion_transformer.py | 264 +++++ tests/st/mindflow/cell/diffusion/ae.py | 156 +++ tests/st/mindflow/cell/diffusion/dataset.py | 129 +++ tests/st/mindflow/cell/diffusion/ddim_gt.py | 430 ++++++++ tests/st/mindflow/cell/diffusion/ddpm_gt.py | 473 +++++++++ .../mindflow/cell/diffusion/test_diffusion.py | 532 ++++++++++ .../cell/diffusion/test_diffusion_train.py | 227 +++++ tests/st/mindflow/cell/diffusion/utils.py | 98 ++ 10 files changed, 3238 insertions(+), 1 deletion(-) create mode 100644 MindFlow/mindflow/cell/diffusion.py create mode 100644 MindFlow/mindflow/cell/diffusion_transformer.py create mode 100644 tests/st/mindflow/cell/diffusion/ae.py create mode 100644 tests/st/mindflow/cell/diffusion/dataset.py create mode 100644 tests/st/mindflow/cell/diffusion/ddim_gt.py create mode 100644 tests/st/mindflow/cell/diffusion/ddpm_gt.py create mode 100644 tests/st/mindflow/cell/diffusion/test_diffusion.py create mode 100644 tests/st/mindflow/cell/diffusion/test_diffusion_train.py create mode 100644 tests/st/mindflow/cell/diffusion/utils.py diff --git a/MindFlow/mindflow/cell/__init__.py b/MindFlow/mindflow/cell/__init__.py index 125cf7652..80bb71339 100644 --- a/MindFlow/mindflow/cell/__init__.py +++ b/MindFlow/mindflow/cell/__init__.py @@ -20,8 +20,12 @@ from .attention import Attention, MultiHeadAttention, AttentionBlock from .vit import ViT from .unet2d import UNet2D from .sno_utils import poly_data, get_poly_transform, interpolate_1d_dataset, interpolate_2d_dataset +from .diffusion import DiffusionScheduler, DiffusionTrainer, DDPMScheduler, DDIMScheduler, DDPMPipeline, DDIMPipeline +from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTransformer __all__ = ["get_activation", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "UNet2D", "PeRCNN", - "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "AttentionBlock", "ViT"] + "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "AttentionBlock", "ViT", "DDPMPipeline", + "DDIMPipeline", "DiffusionTrainer", "DiffusionScheduler", "DDPMScheduler", "DDIMScheduler", + "DiffusionTransformer", "ConditionDiffusionTransformer"] __all__.extend(basic_block.__all__) __all__.extend(sno_utils.__all__) diff --git a/MindFlow/mindflow/cell/diffusion.py b/MindFlow/mindflow/cell/diffusion.py new file mode 100644 index 000000000..408fa5fd1 --- /dev/null +++ b/MindFlow/mindflow/cell/diffusion.py @@ -0,0 +1,924 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Diffusion API"""
+# pylint: disable=C0301
+import math
+
+import numpy as np
+from mindspore import dtype as mstype
+from mindspore import ops, Tensor, jit_class, nn
+
+
+def extract(a, t, x_shape):
+    """gather a[t] and reshape it so that it broadcasts against a tensor of shape x_shape"""
+    b = t.shape[0]
+    out = Tensor(a).gather(t, -1)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    Internally defines a function alpha_bar that takes an argument t and transforms it to the cumulative product of
+    (1-beta) up to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+                     Choose from `cosine` or `exp`.
+
+    Outputs:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+    if alpha_transform_type == "cosine":
+        def alpha_bar_fn(t):
+            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+    elif alpha_transform_type == "exp":
+        def alpha_bar_fn(t):
+            return math.exp(t * -12.0)
+    else:
+        raise ValueError(
+            f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+    return np.array(betas)
+
+
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR, based on (https://arxiv.org/pdf/2305.08891.pdf) (Algorithm 1).
+
+    Args:
+        betas (`np.ndarray`):
+            the betas that the scheduler is being initialized with.
+
+    Outputs:
+        `np.ndarray`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = np.cumprod(alphas, axis=0)
+    alphas_bar_sqrt = np.sqrt(alphas_cumprod)
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
+    alphas_bar_sqrt_t = alphas_bar_sqrt[-1].copy()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_t
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / \
+        (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = np.concatenate([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
+@jit_class
+class DiffusionScheduler:
+    r"""
+    Base class of diffusion schedulers.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"squaredcos_cap_v2"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, `sigmoid` or `squaredcos_cap_v2`.
+ prediction_type (`str`, defaults to `epsilon`): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. `clip_sample=True` when clip_sample_range > 0. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + compute_dtype: the dtype of compute, it can be mstype.float32 or mstype.float16. The default value is mstype.float32. + """ + + def __init__(self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "squaredcos_cap_v2", + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + thresholding: bool = False, + sample_max_value: float = 1.0, + dynamic_thresholding_ratio: float = 0.995, + rescale_betas_zero_snr: bool = False, + timestep_spacing: str = "leading", + compute_dtype=mstype.float32): + + self.num_train_timesteps = num_train_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + self.timestep_spacing = timestep_spacing + self.prediction_type = prediction_type + self.sample_max_value = sample_max_value + if clip_sample_range is not None: + self.clip_sample = True + else: + self.clip_sample = False + self.clip_sample_range = clip_sample_range + self.thresholding = thresholding + + self.betas = self._init_betas(beta_schedule, rescale_betas_zero_snr) + + # sampling related parameters + alphas = 1. - self.betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.pad( + alphas_cumprod[:-1], (1, 0), constant_values=1) + + self.alphas_cumprod = Tensor(alphas_cumprod, dtype=compute_dtype) + self.alphas_cumprod_prev = Tensor( + alphas_cumprod_prev, dtype=compute_dtype) + self.num_timesteps = num_train_timesteps + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = Tensor( + np.sqrt(alphas_cumprod), dtype=compute_dtype) + self.sqrt_one_minus_alphas_cumprod = Tensor( + np.sqrt(1. - alphas_cumprod), dtype=compute_dtype) + + self.sqrt_recip_alphas_cumprod = Tensor( + np.sqrt(1. / alphas_cumprod), dtype=compute_dtype) + self.sqrt_recipm1_alphas_cumprod = Tensor( + np.sqrt(1. 
/ alphas_cumprod - 1), dtype=compute_dtype) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = self.betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + posterior_variance = np.clip(posterior_variance, 1e-20, None) + self.posterior_variance = Tensor( + posterior_variance, dtype=compute_dtype) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = Tensor(np.log( + posterior_variance), dtype=compute_dtype) # Tensor(np.log(posterior_variance)) + # See formula (7) from (https://arxiv.org/pdf/2006.11239.pdf) + self.posterior_mean_coef1 = Tensor( + self.betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod), dtype=compute_dtype) + self.posterior_mean_coef2 = Tensor( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod), dtype=compute_dtype) + self.num_inference_steps = None + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + + def _init_betas(self, beta_schedule="squaredcos_cap_v2", rescale_betas_zero_snr=False): + """init noise beta schedule + Inputs: + beta_schedule (`str`, defaults to `"squaredcos_cap_v2"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. + This enables the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to [`--offset_noise`] + (https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + Outputs: + betas: `ms.Tensor`: noise coefficients beta. + """ + betas = None + if beta_schedule == "linear": + betas = np.linspace( + self.beta_start, self.beta_end, self.num_train_timesteps) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + betas = np.linspace( + self.beta_start**0.5, self.beta_end**0.5, self.num_train_timesteps) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + betas = betas_for_alpha_bar(self.num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = np.linspace(-6, 6, self.num_train_timesteps) + betas = 1. / (1 + np.exp(-betas)) * \ + (self.beta_end - self.beta_start) + self.beta_start + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + betas = rescale_zero_terminal_snr(betas) + return betas + + def set_timesteps(self, num_inference_steps): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Inputs: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:" + f" {self.num_train_timesteps} as the diffusion model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of (https://arxiv.org/abs/2305.08891) + if self.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.num_train_timesteps - + 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int32) + ) + elif self.timestep_spacing == "leading": + step_ratio = self.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * + step_ratio).round()[::-1].astype(np.int32) + elif self.timestep_spacing == "trailing": + step_ratio = self.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round( + np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int32) + timesteps -= 1 + else: + raise ValueError( + f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.num_timesteps = timesteps + + def _threshold_sample(self, sample): + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + see (https://arxiv.org/abs/2205.11487) + """ + batch_size, channels, *remaining_dims = sample.shape + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + s = np.quantile(abs_sample, self.dynamic_thresholding_ratio, axis=1) + + s = ops.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + # (batch_size, 1) because clamp will broadcast along dim=0 + s = s.unsqueeze(1) + # "we threshold xt0 to the range [-s, s] and then divide by s" + sample = ops.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + + return sample + + def _pred_origin_sample(self, model_output: Tensor, sample: Tensor, timestep: Tensor): + """ + Predict x_0 with x_t. + Inputs: + model_output (`ms.Tensor`): + The direct output from learned diffusion model. + sample (`ms.Tensor`): + A current instance of a sample created by the diffusion process. + timestep (`ms.Tensor`): + The current discrete timestep in the diffusion chain. + Outputs: + x_0 (`ms.Tensor`): + The predicted origin sample x_0. + + """ + if self.prediction_type == "epsilon": + pred_original_sample = extract(self.sqrt_recip_alphas_cumprod, timestep, sample.shape)*sample - \ + extract(self.sqrt_recipm1_alphas_cumprod, + timestep, sample.shape)*model_output + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = extract(self.sqrt_alphas_cumprod, timestep, sample.shape)*sample - \ + extract(self.sqrt_one_minus_alphas_cumprod, + timestep, sample.shape)*model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + # 2. 
Clip or threshold "predicted x_0"
+        if self.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.clip_sample:
+            pred_original_sample = pred_original_sample.clamp(-self.clip_sample_range,
+                                                              self.clip_sample_range)
+        return pred_original_sample
+
+    def add_noise(self, original_samples: Tensor, noise: Tensor, timesteps: Tensor):
+        """
+        Add noise to the samples in the forward process.
+        Inputs:
+            original_samples (`ms.Tensor`):
+                The current samples.
+            noise (`ms.Tensor`):
+                Random noise to be added to the samples.
+            timesteps (`ms.Tensor`):
+                The current discrete timestep in the diffusion chain.
+
+        Outputs:
+            x_t (`Tensor`):
+                The samples noised to the given timesteps.
+        """
+        return (extract(self.sqrt_alphas_cumprod, timesteps, original_samples.shape)*original_samples
+                + extract(self.sqrt_one_minus_alphas_cumprod, timesteps, original_samples.shape)*noise)
+
+    def step(self, model_output: Tensor, sample: Tensor, timestep: Tensor):
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Inputs:
+            model_output (`ms.Tensor`):
+                The direct output from the learned diffusion model.
+            sample (`ms.Tensor`):
+                A current instance of a sample created by the diffusion process.
+            timestep (`ms.Tensor`):
+                The current discrete timestep in the diffusion chain.
+        Outputs:
+            x_{t-1} (`Tensor`):
+                The denoised sample.
+        """
+        if not self.num_inference_steps:
+            raise NotImplementedError(
+                f"num_inference_steps is not set for {self.__class__}. Need to set timesteps first.")
+        raise NotImplementedError(
+            f"step function is not implemented for {self.__class__}")
+
+
+class DDPMScheduler(DiffusionScheduler):
+    r"""
+    `DDPMScheduler` is an implementation of the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs).
+    This class inherits from [`DiffusionScheduler`]. Check the superclass documentation for the generic methods
+    the library implements for all schedulers. See (https://arxiv.org/abs/2006.11239) for more information.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"squaredcos_cap_v2"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        prediction_type (`str`, defaults to `epsilon`):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        variance_type (`str`, defaults to `"fixed_small_log"`):
+            Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
+            `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        clip_sample_range (`float`, defaults to 1.0):
+            The maximum magnitude for sample clipping. Sample clipping is enabled when `clip_sample_range` is not `None`.
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+            as Stable Diffusion.
+ sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + compute_dtype: the dtype of compute, it can be mstype.float32 or mstype.float16. The default value is mstype.float32. + """ + + def __init__(self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "squaredcos_cap_v2", + prediction_type: str = "epsilon", + variance_type: str = "fixed_small_log", + clip_sample_range: float = 1.0, + thresholding: bool = False, + sample_max_value: float = 1.0, + dynamic_thresholding_ratio: float = 0.995, + rescale_betas_zero_snr: bool = False, + timestep_spacing: str = "leading", + compute_dtype=mstype.float32): + super().__init__(num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + prediction_type, + clip_sample_range, + thresholding, + sample_max_value, + dynamic_thresholding_ratio, + rescale_betas_zero_snr, + timestep_spacing, + compute_dtype) + self.variance_type = variance_type + + def _get_variance(self, x_t, t, predicted_variance=None): + """get DDPM variance""" + variance = extract(self.posterior_variance, t, x_t.shape) + beta_t = extract(self.betas, t, x_t.shape) + # hacks - were probably added for training stability + if self.variance_type == "fixed_small": + variance = variance + # for rl-diffuser (https://arxiv.org/abs/2205.09991) + elif self.variance_type == "fixed_small_log": + variance = ops.log(variance) + variance = ops.exp(0.5 * variance) + elif self.variance_type == "fixed_large": + variance = beta_t + elif self.variance_type == "fixed_large_log": + # Glide max_log + variance = ops.log(beta_t) + elif self.variance_type == "learned": + return predicted_variance + elif self.variance_type == "learned_range": + min_log = ops.log(variance) + max_log = ops.log(beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def set_timesteps(self, num_inference_steps): + # DDPM num_inference_steps defaults to num_train_timesteps + if num_inference_steps != self.num_train_timesteps: + raise ValueError( + "DDPM num_inference_steps defaults to num_train_timesteps") + super().set_timesteps(self.num_train_timesteps) + + # pylint: disable=W0221 + def step(self, model_output: Tensor, sample: Tensor, timestep: Tensor, predicted_variance: Tensor = None): + # 1. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from (https://arxiv.org/pdf/2006.11239.pdf) + + pred_original_sample = self._pred_origin_sample( + model_output, sample, timestep) + # 2. 
Compute predicted previous sample µ_t + # See formula (7) from (https://arxiv.org/pdf/2006.11239.pdf) + pred_prev_sample = ( + extract(self.posterior_mean_coef1, timestep, sample.shape)*pred_original_sample + + extract(self.posterior_mean_coef2, timestep, sample.shape)*sample + ) + + # 3. Add noise + v = self._get_variance(sample, timestep, predicted_variance) + variance = 0 + if timestep[0] > 0: + variance_noise = ops.randn_like(sample) + if self.variance_type == "fixed_small_log": + variance = v * variance_noise + elif self.variance_type == "learned_range": + variance = ops.exp(0.5 * v) * variance_noise + else: + variance = (v ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample + + +class DDIMScheduler(DiffusionScheduler): + r""" + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`DiffusionScheduler`]. Check the superclass documentation for the generic + methods the library implements for all schedulers.(https://arxiv.org/abs/2010.02502) for more information. + + """ + + def __init__(self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "squaredcos_cap_v2", + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + thresholding: bool = False, + sample_max_value: float = 1.0, + dynamic_thresholding_ratio: float = 0.995, + rescale_betas_zero_snr: bool = False, + timestep_spacing: str = "leading", + compute_dtype=mstype.float32): + super().__init__(num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + prediction_type, + clip_sample_range, + thresholding, + sample_max_value, + dynamic_thresholding_ratio, + rescale_betas_zero_snr, + timestep_spacing, + compute_dtype) + + self.final_alpha_cumprod = Tensor(1.0, dtype=compute_dtype) + self.num_train_timesteps = num_train_timesteps + # standard deviation of the initial noise distribution + + def _get_variance(self, timestep, prev_timestep): + """get DDIM variance""" + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep[0] >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * \ + (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def _pred_epsilon(self, model_output: Tensor, sample: Tensor, timestep: Tensor): + """ + Predict epsilon. + Inputs: + model_output (`ms.Tensor`): + The direct output from learned diffusion model. + sample (`ms.Tensor`): + A current instance of a sample created by the diffusion process. + timestep (`ms.Tensor`): + The current discrete timestep in the diffusion chain. + Outputs: + epsilon (`ms.Tensor`): + The predicted epsilon. 
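+
+        Note:
+            For reference, with alpha_bar_t the cumulative product of (1 - beta): for `epsilon` the model
+            output is returned as-is; for `sample`, eps = (x_t - sqrt(alpha_bar_t) * x_0) / sqrt(1 - alpha_bar_t);
+            for `v_prediction`, eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t.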
+
+        """
+        if self.prediction_type == "epsilon":
+            pred_epsilon = model_output
+        elif self.prediction_type == "sample":
+            tmp = (sample - extract(self.sqrt_alphas_cumprod,
+                                    timestep, sample.shape) * model_output)
+            pred_epsilon = extract(ops.reciprocal(
+                self.sqrt_one_minus_alphas_cumprod), timestep, tmp.shape)*tmp
+        elif self.prediction_type == "v_prediction":
+            pred_epsilon = extract(self.sqrt_alphas_cumprod, timestep, sample.shape) * model_output + extract(
+                self.sqrt_one_minus_alphas_cumprod, timestep, sample.shape) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )
+
+        return pred_epsilon
+
+    # pylint: disable=W0221
+    def step(self,
+             model_output: Tensor,
+             sample: Tensor,
+             timestep: Tensor,
+             eta: float = 0.0,
+             use_clipped_model_output: bool = False,
+             ):
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Inputs:
+            model_output (`Tensor`):
+                The direct output from the learned diffusion model.
+            sample (`Tensor`):
+                A current instance of a sample created by the diffusion process.
+            timestep (`ms.Tensor`):
+                The current discrete timestep in the diffusion chain.
+            eta (`float`):
+                The weight of noise for added noise in the diffusion step. DDIM when eta=0, DDPM when eta=1.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+                because the predicted original sample is clipped to [-1, 1] when `self.clip_sample` is `True`. If no
+                clipping has happened, the "corrected" `model_output` would coincide with the one provided as input
+                and `use_clipped_model_output` has no effect.
+
+        Outputs:
+            x_prev (`Tensor`):
+                The denoised previous sample x_prev.
+
+        """
+        # See formulas (12) and (16) of the DDIM paper (https://arxiv.org/pdf/2010.02502.pdf)
+        # Ideally, read the DDIM paper in detail to understand the steps below.
+
+        # Notation (<variable name> -> <name in paper>)
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_prev_sample -> "x_t-1"
+        assert 0 <= eta <= 1., "eta must be in range [0, 1]"
+        dtype = sample.dtype
+        batch_size = timestep.shape[0]
+        # 1. get previous step value (=t-1)
+        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+        alpha_len = self.alphas_cumprod.shape[0]
+        assert (timestep < alpha_len).all(), "timestep out of bounds"
+        assert (prev_timestep < alpha_len).all(), "prev_timestep out of bounds"
+
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep[0] >= 0 else \
+            self.final_alpha_cumprod.repeat(batch_size)
+        # 2. compute beta_prod_t
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted original sample from predicted noise, also called
+        # "predicted x_0" of formula (12) from (https://arxiv.org/pdf/2010.02502.pdf)
+        pred_original_sample = self._pred_origin_sample(
+            model_output, sample, timestep)
+        pred_epsilon = self._pred_epsilon(model_output, sample, timestep)
+        # 4. compute variance: "sigma_t(η)" -> see formula (16) from (https://arxiv.org/pdf/2010.02502.pdf)
+        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+        variance = self._get_variance(timestep, prev_timestep)
+        std_dev_t = (eta * ops.sqrt(variance)).astype(dtype)
+
+        if use_clipped_model_output:
+            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+            pred_epsilon = (
+                (sample - (alpha_prod_t ** (0.5)).astype(dtype) *
+                 pred_original_sample) / beta_prod_t ** (0.5)
+            ).astype(dtype)
+
+        # 5. compute "direction pointing to x_t" of formula (12) from (https://arxiv.org/pdf/2010.02502.pdf)
+        pred_sample_direction = (
+            (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5)).reshape(batch_size, 1, 1) * pred_epsilon
+
+        # 6. compute x_t-1 without "random noise" of formula (12) from (https://arxiv.org/pdf/2010.02502.pdf)
+        coef = ops.sqrt(alpha_prod_t_prev).reshape(batch_size, 1, 1)
+        prev_sample = coef * pred_original_sample + pred_sample_direction
+        if eta > 0:
+            variance_noise = ops.randn_like(model_output, dtype=dtype)
+            variance = std_dev_t.reshape(batch_size, 1, 1) * variance_noise
+            prev_sample += variance
+        return prev_sample
+
+
+class DiffusionPipeline:
+    r"""
+    Pipeline for diffusion generation.
+
+    Inputs:
+        model ([`nn.Cell`]):
+            A diffusion backbone to denoise the encoded image latents.
+        scheduler ([`DiffusionScheduler`]):
+            A scheduler to be used in combination with `model` to denoise the encoded image. Can be one of
+            [`DDPMScheduler`], or [`DDIMScheduler`].
+        batch_size ([`int`]):
+            The number of images to generate.
+        seq_len ([`int`]):
+            Sequence length of inputs.
+        num_inference_steps ([`int`]):
+            Number of denoising steps.
+        compute_dtype:
+            The dtype of compute; it can be mstype.float32 or mstype.float16. The default value is mstype.float32.
+
+    Example:
+
+        ```py
+        >>> from mindflow.cell.diffusion import DiffusionPipeline
+
+        >>> # init condition
+        >>> cond = Tensor(np.random.rand(4, 5).astype(np.float32))
+
+        >>> # init model and scheduler
+        >>> model = DiffusionModel()
+        >>> scheduler = DDPMScheduler(...)
+
+        >>> # init pipeline
+        >>> pipe = DiffusionPipeline(model, scheduler, batch_size=4, seq_len=64, num_inference_steps=100)
+
+        >>> # run pipeline in inference (sample random noise and denoise)
+        >>> image = pipe(cond)
+        ```
+
+    """
+
+    def __init__(self, model, scheduler, batch_size, seq_len, num_inference_steps=100, compute_dtype=mstype.float32):
+        self.model = model
+        self.scheduler = scheduler
+        self.seq_len = seq_len
+        self.compute_dtype = compute_dtype
+        self.batch_size = batch_size
+        self.scheduler.set_timesteps(num_inference_steps)
+
+    def _pred_noise(self, sample, condition, timesteps):
+        if condition is not None:
+            inputs = (sample, timesteps, condition)
+        else:
+            inputs = (sample, timesteps)
+        model_output = self.model(*inputs)
+        return model_output
+
+    def _sample_step(self, sample, condition, timesteps):
+        model_output = self._pred_noise(sample, condition, timesteps)
+        sample = self.scheduler.step(
+            model_output=model_output, sample=sample, timestep=timesteps)
+        return sample
+
+    def __call__(self, condition=None):
+        r"""
+        The call function to the pipeline for generation.
+
+        Inputs:
+            condition ([`ms.Tensor`]):
+                Condition for the diffusion generation process.
+        Outputs:
+            sample ([`ms.Tensor`]):
+                Predicted original samples.
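+
+        Note:
+            Sampling starts from standard Gaussian noise of shape (batch_size, seq_len, model.in_channels),
+            so the returned sample has that shape.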
+
+        """
+        sample = Tensor(np.random.randn(self.batch_size, self.seq_len,
+                                        self.model.in_channels), dtype=self.compute_dtype)
+        if condition is not None:
+            condition = condition.reshape(self.batch_size, -1)
+
+        for t in self.scheduler.num_timesteps:
+            batched_times = ops.ones((self.batch_size,), mstype.int32) * int(t)
+            sample = self._sample_step(sample, condition, batched_times)
+        return sample
+
+
+class DDPMPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for DDPM generation.
+    This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    the library implements for all pipelines.
+    """
+
+
+class DDIMPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for DDIM generation.
+    This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    the library implements for all pipelines.
+    """
+    # pylint: disable=W0235
+    def __init__(self, model, scheduler, batch_size, seq_len, num_inference_steps=100, compute_dtype=mstype.float32):
+        super().__init__(model, scheduler, batch_size,
+                         seq_len, num_inference_steps, compute_dtype)
+
+    # pylint: disable=W0221
+    def _sample_step(self, sample, condition, timesteps, eta, use_clipped_model_output):
+        model_output = self._pred_noise(sample, condition, timesteps)
+        sample = self.scheduler.step(model_output=model_output, sample=sample, timestep=timesteps,
+                                     eta=eta, use_clipped_model_output=use_clipped_model_output)
+        return sample
+
+    def __call__(self, condition=None, eta=0., use_clipped_model_output=False):
+        r"""
+        The call function to the pipeline for generation.
+
+        Inputs:
+            condition ([`ms.Tensor`]):
+                Condition for the diffusion generation process.
+            eta (`float`):
+                The weight of noise for added noise in the diffusion step.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+                because the predicted original sample is clipped to [-1, 1] when `self.scheduler.clip_sample` is
+                `True`. If no clipping has happened, the "corrected" `model_output` would coincide with the one
+                provided as input and `use_clipped_model_output` has no effect.
+        Outputs:
+            sample ([`ms.Tensor`]):
+                Predicted original samples.
+
+        """
+        sample = Tensor(np.random.randn(self.batch_size, self.seq_len,
+                                        self.model.in_channels), dtype=self.compute_dtype)
+        if condition is not None:
+            condition = condition.reshape(self.batch_size, -1)
+
+        for t in self.scheduler.num_timesteps:
+            batched_times = ops.ones((self.batch_size,), mstype.int32) * int(t)
+            sample = self._sample_step(
+                sample, condition, batched_times, eta, use_clipped_model_output)
+
+        return sample
+
+
+@jit_class
+class DiffusionTrainer:
+    r"""
+    Diffusion Trainer base class.
+
+    Args:
+        model (`nn.Cell`):
+            The diffusion backbone model.
+        scheduler (DiffusionScheduler):
+            A DDPM or DDIM scheduler.
+        objective (`str`, defaults to `pred_noise`):
+            The training objective; can be `pred_noise` (predicts the noise of the diffusion process),
+            `pred_x0` (predicts the original sample) or `pred_v` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        p2_loss_weight_gamma (`float`, defaults to `0`):
+            p2 loss weight gamma, from (https://arxiv.org/abs/2204.00227) - 0 is equivalent to weight of 1 across time.
+        p2_loss_weight_k (`float`, defaults to `1`):
+            p2 loss weight k, from (https://arxiv.org/abs/2204.00227).
+        loss_type (`str`, defaults to `l1`):
+            the type of loss; it can be `l1` or `l2`.
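+
+    Example (a minimal training-step sketch; `model`, `x0` and `cond` are assumed to be a user-provided
+    diffusion backbone and training data):
+
+        ```py
+        >>> import numpy as np
+        >>> from mindspore import ops, Tensor
+        >>> from mindspore import dtype as mstype
+        >>> scheduler = DDPMScheduler(num_train_timesteps=1000)
+        >>> trainer = DiffusionTrainer(model, scheduler, objective='pred_noise', loss_type='l2')
+        >>> noise = ops.randn_like(x0)
+        >>> timesteps = Tensor(np.random.randint(0, 1000, (x0.shape[0],)), mstype.int32)
+        >>> loss = trainer.get_loss(x0, noise, timesteps, cond)
+        ```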
+ """ + + def __init__(self, + model, + scheduler, + objective='pred_noise', + p2_loss_weight_gamma=0., + p2_loss_weight_k=1, + loss_type='l1', + ): + + self.model = model + self.scheduler = scheduler + p2_loss_weight = (p2_loss_weight_k + scheduler.alphas_cumprod / + (1 - scheduler.alphas_cumprod)) ** -p2_loss_weight_gamma + self.p2_loss_weight = Tensor(p2_loss_weight, mstype.float32) + self.objective = objective + self.loss_type = loss_type + if self.loss_type == 'l1': + self.loss_fn = nn.L1Loss('none') + elif self.loss_type == 'l2': + self.loss_fn = nn.MSELoss('none') + else: + raise ValueError(f'invalid loss type {self.loss_type}') + + def get_loss(self, original_samples: Tensor, noise: Tensor, timesteps: Tensor, condition: Tensor = None): + r""" + Inputs: + original_samples (`Tensor`): + The direct output from learned diffusion model. + noise (`Tensor`): + A current instance of a noise sample created by the diffusion process. + timesteps (`float`): + The current discrete timestep in the diffusion chain. + condition (`Tensor`, defaults to `None`): + The condition for desired outputs. + outputs: + loss (`Tensor`): + The model forward loss. + """ + + noised_sample = self.scheduler.add_noise( + original_samples, noise, timesteps) + if condition is None: + inputs = (noised_sample, timesteps) + else: + inputs = (noised_sample, timesteps, condition) + + if self.objective == 'pred_noise': + target = noise + elif self.objective == 'pred_x0': + target = original_samples + elif self.objective == 'pred_v': + target = (extract(self.scheduler.sqrt_alphas_cumprod, timesteps, original_samples.shape)*noise - + extract(self.scheduler.sqrt_one_minus_alphas_cumprod, timesteps, + original_samples.shape)*original_samples) + else: + target = noise + + model_out = self.model(*inputs) + loss = self.loss_fn(model_out, target) + loss = loss.reshape(loss.shape[0], -1) + loss = loss * extract(self.p2_loss_weight, timesteps, loss.shape) + return loss.mean() diff --git a/MindFlow/mindflow/cell/diffusion_transformer.py b/MindFlow/mindflow/cell/diffusion_transformer.py new file mode 100644 index 000000000..b8207552a --- /dev/null +++ b/MindFlow/mindflow/cell/diffusion_transformer.py @@ -0,0 +1,264 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Diffusion transformer api""" + +import math + +import numpy as np +from mindspore import nn, ops, Tensor +from mindspore import dtype as mstype +from mindflow.cell import AttentionBlock + + +class Mlp(nn.Cell): + """MLP""" + + def __init__(self, in_channels, out_channels, dropout=0., compute_dtype=mstype.float32): + super().__init__() + self.fc1 = nn.Dense( + in_channels, 4*in_channels).to_float(compute_dtype) + self.act = nn.GELU() + self.fc2 = nn.Dense( + 4*in_channels, out_channels).to_float(compute_dtype) + self.drop = nn.Dropout(p=dropout) + + def construct(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SinusoidalPosEmb(nn.Cell): + """sinusoidal embedding model""" + + def __init__(self, dim, max_period=10000, compute_dtype=mstype.float32): + super().__init__() + half_dim = dim // 2 + self.concat_zero = (dim % 2 == 1) + freqs = np.exp(-math.log(max_period) * + np.arange(start=0, stop=half_dim) / half_dim) + self.freqs = Tensor(freqs, compute_dtype) + + def construct(self, x): + emb = x[:, None] * self.freqs[None, :] + emb = ops.concat((ops.cos(emb), ops.sin(emb)), axis=-1) + if self.concat_zero: + emb = ops.concat([emb, ops.zeros_like(emb[:, :1])], axis=-1) + return emb + + +class Transformer(nn.Cell): + """Transformer backbone model""" + + def __init__(self, hidden_channels, layers, heads, compute_dtype=mstype.float32): + super().__init__() + self.hidden_channels = hidden_channels + self.layers = layers + self.blocks = nn.CellList([ + AttentionBlock( + in_channels=hidden_channels, + num_heads=heads, + drop_mode="dropout", + dropout_rate=0.0, + compute_dtype=compute_dtype, + ) + for _ in range(layers)]) + + def construct(self, x): + for block in self.blocks: + x = block(x) + return x + + +class DiffusionTransformer(nn.Cell): + r""" + Diffusion model with Transformer backbone implementation. + + Args: + in_channels (`int`): + The number of input channel. + out_channels (`int`): + The number of output channel. + hidden_channels (`int`): + The number of hidden channel. + layers (`int`): + The number of transformer block layers. + heads (`int`): + The number of transformer heads. + time_token_cond (`bool`, defaults to `True`): + Whether to use timestep as condition token. + compute_dtype (`ms.dtype`, defaults to `"mstype.float32"`): + compute_dtype: the dtype of compute, it can be mstype.float32 or mstype.float16. + The default value is mstype.float32. 
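+
+    Example (an illustrative sketch; the sizes below are arbitrary):
+
+        ```py
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> model = DiffusionTransformer(in_channels=16, out_channels=16,
+        ...                              hidden_channels=256, layers=4, heads=4)
+        >>> x = Tensor(np.random.randn(4, 64, 16), mstype.float32)
+        >>> t = Tensor(np.random.randint(0, 1000, (4,)), mstype.int32)
+        >>> print(model(x, t).shape)
+        (4, 64, 16)
+        ```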
+ + """ + + def __init__(self, + in_channels, + out_channels, + hidden_channels, + layers, + heads, + time_token_cond=True, + compute_dtype=mstype.float32): + super().__init__() + self.time_token_cond = time_token_cond + self.in_channels = in_channels + self.timestep_emb = SinusoidalPosEmb( + hidden_channels, compute_dtype=compute_dtype) + self.time_embed = Mlp( + in_channels=hidden_channels, + out_channels=hidden_channels, + dropout=0., + compute_dtype=compute_dtype + ) + + self.ln_pre = nn.LayerNorm( + (hidden_channels,), epsilon=1e-5).to_float(mstype.float32) + self.backbone = Transformer( + hidden_channels=hidden_channels, + layers=layers, + heads=heads, + compute_dtype=compute_dtype + ) + self.ln_post = nn.LayerNorm( + (hidden_channels,), epsilon=1e-5).to_float(mstype.float32) + self.input_proj = nn.Dense( + in_channels, hidden_channels).to_float(compute_dtype) + self.output_proj = nn.Dense( + hidden_channels, out_channels, weight_init='zeros', bias_init='zeros').to_float(compute_dtype) + + def construct(self, x, timestep): + """forward network with timestep input + Inputs: + x (`ms.Tensor`, a [B*N*C] shape tensor): + The input has a shape of [B*N*C]. B: batch_size N: sequence length C: input feature channels. + timestep (`ms.Tensor`, a [B] shape tensor): + Timestep inputs. + Outputs: + output: (`ms.Tensor`, a [B*N*C'] shape tensor): + The output has a shape of [B*N*C']. B: batch_size N: sequence length C': output feature channels. + """ + t_embed = self.time_embed(self.timestep_emb(timestep)) + return self._forward_with_cond(x, [(t_embed, self.time_token_cond)]) + + def _forward_with_cond(self, x, cond_token_list): + """forward network with condition input + Inputs: + x (`ms.Tensor`, a [B*N*C] shape tensor): + The input has a shape of [B*N*C]. B: batch_size N: sequence length C: input feature channels. + cond_token_list (List[Tuple[Tensor, bool]]): + The condition and whether use condition as tokens. + Outputs: + output: (`ms.Tensor`, a [B*N*C'] shape tensor): + The output has a shape of [B*N*C']. B: batch_size N: sequence length C: output feature channels. + """ + + h = self.input_proj(x) + extra_tokens = [] + for tokens, as_token in cond_token_list: + if as_token: + if len(tokens.shape) == 2: + extra_tokens.append(tokens.unsqueeze(1)) + else: + extra_tokens.append(tokens) + else: + h = h + tokens.unsqueeze(1) + + if extra_tokens: + h = ops.concat(extra_tokens + [h], axis=1) + + h = self.ln_pre(h) + h = self.backbone(h) + h = self.ln_post(h) + if extra_tokens: + # keep sequence length unchanged + h = h[:, sum(token.shape[1] for token in extra_tokens):] + h = self.output_proj(h) + return h + + +class ConditionDiffusionTransformer(DiffusionTransformer): + r""" + Conditioned Diffusion Transformer implementation. + + Args: + in_channels (`int`): + The number of input channel. + out_channels (`int`): + The number of output channel. + hidden_channels (`int`): + The number of hidden channel. + cond_channels (`int`): + The number of condition channel. + layers (`int`): + The number of transformer block layers. + heads (`int`): + The number of transformer heads. + time_token_cond (`bool`, defaults to `True`): + Whether to use timestep as condition token. + cond_as_token (`bool`, defaults to `True`): + Whether to use condition as token. + compute_dtype (`ms.dtype`, defaults to `"mstype.float32"`): + compute_dtype: the dtype of compute, it can be mstype.float32 or mstype.float16. + The default value is mstype.float32. 
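+
+    Example (an illustrative sketch; an 8-channel condition vector per sample is assumed):
+
+        ```py
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> model = ConditionDiffusionTransformer(in_channels=16, out_channels=16, cond_channels=8,
+        ...                                       hidden_channels=256, layers=4, heads=4)
+        >>> x = Tensor(np.random.randn(4, 64, 16), mstype.float32)
+        >>> t = Tensor(np.random.randint(0, 1000, (4,)), mstype.int32)
+        >>> cond = Tensor(np.random.randn(4, 8), mstype.float32)
+        >>> print(model(x, t, cond).shape)
+        (4, 64, 16)
+        ```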
+ + """ + + def __init__(self, in_channels, + out_channels, + cond_channels, + hidden_channels, + layers, + heads, + time_token_cond=True, + cond_as_token=True, + compute_dtype=mstype.float32): + super().__init__(in_channels, + out_channels, + hidden_channels, + layers, + heads, + time_token_cond, + compute_dtype) + self.cond_as_token = cond_as_token + self.cond_embed = nn.Dense( + cond_channels, hidden_channels).to_float(compute_dtype) + + # pylint: disable=W0221 + def construct(self, x, timestep, cond=None): + """forward network with timestep and condition input + Inputs: + x (`ms.Tensor`, a [B*N*C] shape tensor): + The input has a shape of [B*N*C]. B: batch_size N: sequence length C: input feature channels. + timestep (`ms.Tensor`, a [B] shape tensor): + Timestep inputs. + cond (`ms.Tensor`, a [B*C_c] shape tensor, defaults to `None`): + Condition input has a shape of [B*N*C_c]. B: batch_size N: sequence length + C_c: condition feature channels. + + Outputs: + output: (`ms.Tensor`, a [B*N*C'] shape tensor): + The output has a shape of [B*N*C']. B: batch_size N: sequence length C': output feature channels. + """ + t_embed = self.time_embed(self.timestep_emb(timestep)) + full_cond = [(t_embed, self.time_token_cond)] + if cond is not None: + cond_emb = self.cond_embed(cond) + full_cond.append((cond_emb, self.cond_as_token)) + return self._forward_with_cond(x, full_cond) diff --git a/tests/st/mindflow/cell/diffusion/ae.py b/tests/st/mindflow/cell/diffusion/ae.py new file mode 100644 index 000000000..8e7641eb8 --- /dev/null +++ b/tests/st/mindflow/cell/diffusion/ae.py @@ -0,0 +1,156 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""autoencoder model for MNIST dataset""" +import time +import os + +import matplotlib.pyplot as plt +import numpy as np +from mindspore import Tensor, ops, amp, load_checkpoint, load_param_into_net, nn, jit, save_checkpoint + +from dataset import get_dataset, SampleScaler, CKPT_PATH, load_mnist + +LATENT_DIM = 16 +HIDDEN_DIM = 128 +INPUT_DIM = OUTPUT_DIM = 28*28 + + +class Encoder(nn.Cell): + """encoder""" + + def __init__(self, input_size, hidden_size, latent_size): + super().__init__() + + self.layers = nn.SequentialCell( + nn.Dense(input_size, hidden_size), + nn.Tanh(), + nn.Dense(hidden_size, hidden_size), + nn.Tanh(), + nn.Dense(hidden_size, latent_size) + ) + + def construct(self, x): # x: bs,input_size + x = self.layers(x) + return x + + +class Decoder(nn.Cell): + """decoder""" + + def __init__(self, latent_size, hidden_size, output_size): + super().__init__() + self.layers = nn.SequentialCell( + nn.Dense(latent_size, hidden_size), + nn.Tanh(), + nn.Dense(hidden_size, hidden_size), + nn.Tanh(), + nn.Dense(hidden_size, output_size), + nn.Sigmoid() + ) + + def construct(self, x): # x:bs,latent_size + x = self.layers(x) + return x + + +class AE(nn.Cell): + """AutoEncoder model""" + + def __init__(self, input_size, output_size, latent_size, hidden_size): + super(AE, self).__init__() + self.encoder = Encoder(input_size, hidden_size, latent_size) + self.decoder = Decoder(latent_size, hidden_size, output_size) + + def construct(self, x): # x: bs,input_size + feat = self.encoder(x) # feat: bs,latent_size + re_x = self.decoder(feat) # re_x: bs, output_size + return re_x + + +def train_ae(): + """train autoencoder on MNIST dataset""" + criterion = nn.MSELoss() + model = AE(INPUT_DIM, OUTPUT_DIM, LATENT_DIM, HIDDEN_DIM) + + def forward_fn(data): + pred = model(data) + loss = criterion(data, pred) + return loss + + lr = 0.001 + optimizer = nn.Adam(model.trainable_params(), lr, 0.9, 0.99) + grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters) + + @jit + def train_step(data): + loss, grads = grad_fn(data) + is_finite = amp.all_finite(grads) + if is_finite: + loss = ops.depend(loss, optimizer(grads)) + return loss + + batch_size = 256 + train_dataset = get_dataset(batch_size=batch_size) + epochs = 200 + for epoch in range(1, epochs+1): + time_beg = time.time() + for samples in train_dataset: + loss = train_step(samples) + time_end = time.time() + print(f'epoch: {epoch} loss: {loss} time: {time_end - time_beg}') + x = [] + model.set_train(False) + for samples in train_dataset: + x.append(samples) + x = Tensor(np.concatenate(x, axis=0)) + save_checkpoint(model, CKPT_PATH) + latent = model.encoder(x).numpy() + np.save('latent.npy', latent) + + +def generate_image(latent, diffusion_type='ddpm', save_dir='./images'): + """generate MNIST images from latent""" + latent = Tensor(latent) + model = AE(INPUT_DIM, OUTPUT_DIM, LATENT_DIM, HIDDEN_DIM) + param_dict = load_checkpoint(CKPT_PATH) + load_param_into_net(model, param_dict) + recon_img = model.decoder(latent).numpy() + recon_img = SampleScaler.unscale(recon_img) + recon_img = recon_img.reshape(-1, OUTPUT_DIM) + print(recon_img.shape) + save_dir = os.path.join(save_dir, f'{diffusion_type}') + os.makedirs(save_dir, exist_ok=True) + for i in range(recon_img.shape[0]): + plt.imsave(os.path.join(save_dir, f'{i}.png'), recon_img[i].reshape( + 28, 28), cmap="gray") + + +def export_latent(): + """export MNIST images latent""" + images, labels = load_mnist() + 
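+    # images: (N, 784) float32 array with values in [0, 255]; labels: (N,) uint8 array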
print(images.shape, labels.shape) + images /= 255. + images = images.reshape(images.shape[0], -1) + model = AE(INPUT_DIM, OUTPUT_DIM, LATENT_DIM, HIDDEN_DIM) + param_dict = load_checkpoint(CKPT_PATH) + load_param_into_net(model, param_dict) + latent = model.encoder(Tensor(images)).numpy() + np.save('latent.npy', latent) + labels = labels.reshape(labels.shape[0], 1) + np.save('labels.npy', labels) + + +if __name__ == '__main__': + export_latent() diff --git a/tests/st/mindflow/cell/diffusion/dataset.py b/tests/st/mindflow/cell/diffusion/dataset.py new file mode 100644 index 000000000..4fc6f0f6d --- /dev/null +++ b/tests/st/mindflow/cell/diffusion/dataset.py @@ -0,0 +1,129 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""dataset api""" +import os +import struct + +import numpy as np +from mindspore import jit_class +import mindspore.dataset as ds + +FILE_DIR = '/home/workspace/mindspore_dataset/mindscience/mindflow/diffusion' +CKPT_PATH = os.path.join(FILE_DIR, 'ae.ckpt') + + +def load_data(filename): + """load data from `FILE_DIR`""" + return np.load(os.path.join(FILE_DIR, filename)) + + +def load_mnist(kind='train'): + """Load MNIST data from `path`""" + path = os.path.join(FILE_DIR, 'MNIST') + labels_path = os.path.join(path, f'{kind}-labels.idx1-ubyte') + images_path = os.path.join(path, f'{kind}-images.idx3-ubyte') + with open(labels_path, 'rb') as lbpath: + _, _ = struct.unpack('>II', lbpath.read(8)) + labels = np.fromfile(lbpath, dtype=np.uint8) + + with open(images_path, 'rb') as imgpath: + _, _, _, _ = struct.unpack('>IIII', imgpath.read(16)) + images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) + + return images.astype(np.float32), labels + + +class TrainData: + """MNIST dataset class""" + + def __init__(self, x): + self.length = x.shape[0] + self.x = x + + def __getitem__(self, index): + return self.x[index] + + def __len__(self): + return self.length + + +class LatentData: + """MNIST latent dataset class""" + + def __init__(self, x, y): + self.length = x.shape[0] + self.x = x + self.y = y + + def __getitem__(self, index): + return self.x[index], self.y[index] + + def __len__(self): + return self.length + + +@jit_class +class SampleScaler: + """MNIST scale class""" + + def __init__(self): + pass + + @staticmethod + def scale(x): + """scale image""" + return x/255. + + @staticmethod + def unscale(x): + """unscale image""" + return x*255. 
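+
+# A quick round-trip sanity check for `SampleScaler` (illustrative; not executed by the tests):
+#
+#     x = np.ones((2, 784), np.float32) * 128.
+#     assert np.allclose(SampleScaler.unscale(SampleScaler.scale(x)), x)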
+ + +def get_dataset(batch_size): + """build MNIST dataset""" + x, _ = load_mnist() + # scale MNIST dataset + x = SampleScaler.scale(x) + x = x.reshape(x.shape[0], x.shape[1]) + + train_itr = TrainData(x) + train_dataset = ds.GeneratorDataset(train_itr, + column_names=['images'], + shuffle=True, + num_parallel_workers=1) + + train_dataset = train_dataset.batch( + batch_size=batch_size, drop_remainder=True) + print('train_dataset:', len(train_dataset)) + return train_dataset + + +def get_latent_dataset(latent_path, cond_path, batch_size): + """build MNIST latent dataset""" + x = load_data(latent_path) + x = x.reshape(x.shape[0], 1, x.shape[1]) + cond = load_data(cond_path) + cond = cond.reshape(cond.shape[0], 1).astype(np.float32) + train_itr = LatentData(x, cond) + train_dataset = ds.GeneratorDataset(train_itr, + column_names=['images', 'cond'], + shuffle=True, + num_parallel_workers=1) + + train_dataset = train_dataset.batch( + batch_size=batch_size, drop_remainder=True) + print('train_dataset:', len(train_dataset)) + return train_dataset diff --git a/tests/st/mindflow/cell/diffusion/ddim_gt.py b/tests/st/mindflow/cell/diffusion/ddim_gt.py new file mode 100644 index 000000000..3e8803ac0 --- /dev/null +++ b/tests/st/mindflow/cell/diffusion/ddim_gt.py @@ -0,0 +1,430 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""mo-ddim api""" +from typing import List, Optional, Union + +import numpy as np + +import mindspore as ms +from mindspore import ops, Tensor, set_seed + +from utils import rescale_zero_terminal_snr, betas_for_alpha_bar + +set_seed(0) +np.random.seed(0) + + +class DDIMScheduler: + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. 
+        prediction_type (`str`, defaults to `epsilon`, *optional*):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+            as Stable Diffusion.
+        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+        sample_max_value (`float`, defaults to 1.0):
+            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+    """
+
+    def __init__(self,
+                 num_train_timesteps: int = 1000,
+                 beta_start: float = 0.0001,
+                 beta_end: float = 0.02,
+                 beta_schedule: str = "linear",
+                 trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+                 clip_sample: bool = True,
+                 steps_offset: int = 0,
+                 prediction_type: str = "epsilon",
+                 thresholding: bool = False,
+                 dynamic_thresholding_ratio: float = 0.995,
+                 clip_sample_range: float = 1.0,
+                 sample_max_value: float = 1.0,
+                 timestep_spacing: str = "leading",
+                 rescale_betas_zero_snr: bool = False,
+                 ):
+        if trained_betas is not None:
+            self.betas = ms.tensor(trained_betas, dtype=ms.float32)
+        elif beta_schedule == "linear":
+            self.betas = ms.tensor(
+                np.linspace(beta_start, beta_end, num_train_timesteps), dtype=ms.float32
+            )
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                ms.tensor(
+                    np.linspace(beta_start**0.5, beta_end **
+                                0.5, num_train_timesteps),
+                    dtype=ms.float32,
+                )
+                ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(
+                f"{beta_schedule} is not implemented for {self.__class__}"
+            )
+        self.betas = Tensor(self.betas)
+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = ops.cumprod(self.alphas, dim=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0,
+        # so we set this parameter simply to one
+        self.final_alpha_cumprod = ms.tensor(1.0)
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # setable values
+        self.num_train_timesteps = num_train_timesteps
+        self.num_inference_steps = None
+        self.timesteps = ms.Tensor(
+            np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
+        )
+        self.timestep_spacing = timestep_spacing
+        self.prediction_type = prediction_type
+        self.sample_max_value = sample_max_value
+        self.thresholding = thresholding
+        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
+        self.steps_offset = steps_offset
+        self.rescale_betas_zero_snr = rescale_betas_zero_snr
+        self.clip_sample = clip_sample
+        self.clip_sample_range = clip_sample_range
+
+    def _get_variance(self, timestep, prev_timestep):
+        """get variance"""
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = (
+            self.alphas_cumprod[prev_timestep]
+            if prev_timestep[0] >= 0
+            else self.final_alpha_cumprod
+        )
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (
+            1 - alpha_prod_t / alpha_prod_t_prev
+        )
+
+        return variance
+
+    def _threshold_sample(self, sample: ms.Tensor) -> ms.Tensor:
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ + (https://arxiv.org/abs/2205.11487) + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (ms.float32, ms.float64): + # upcast for quantile calculation, and clamp not implemented for cpu half + sample = sample.float() + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * + np.prod(remaining_dims).item()) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = ms.Tensor.from_numpy( + np.quantile(abs_sample.asnumpy(), + self.dynamic_thresholding_ratio, axis=1) + ) + s = ops.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + # (batch_size, 1) because clamp will broadcast along dim=0 + s = s.unsqueeze(1) + # "we threshold xt0 to the range [-s, s] and then divide by s" + sample = ops.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of (https://arxiv.org/abs/2305.08891) + if self.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.num_train_timesteps - + 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.timestep_spacing == "leading": + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio) + .round()[::-1] + .copy() + .astype(np.int64) + ) + timesteps += self.steps_offset + elif self.timestep_spacing == "trailing": + step_ratio = self.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round( + np.arange(self.num_train_timesteps, 0, -step_ratio) + ).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = ms.Tensor(timesteps) + + def step(self, + model_output: ms.Tensor, + timestep: int, + sample: ms.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[ms.Tensor] = None, + ): + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`ms.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. 
+            sample (`ms.Tensor`):
+                A current instance of a sample created by the diffusion process.
+            eta (`float`):
+                The weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+                because predicted original sample is clipped to [-1, 1] when `self.clip_sample` is `True`. If no
+                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+                `use_clipped_model_output` has no effect.
+            generator (`np.random.Generator`, *optional*):
+                A random number generator.
+            variance_noise (`ms.Tensor`):
+                Alternative to generating noise with `generator` by directly providing the noise for the variance
+                itself. Useful for methods such as [`CycleDiffusion`].
+
+        Returns:
+            [`ms.Tensor`]:
+                Sample tensor.
+
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        # See formulas (12) and (16) of DDIM paper (https://arxiv.org/pdf/2010.02502.pdf)
+        # Ideally, read the DDIM paper for an in-depth understanding
+
+        # Notation (<variable name> -> <name in paper>)
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_prev_sample -> "x_t-1"
+
+        dtype = sample.dtype
+        batch_size = timestep.shape[0]
+        # 1. get previous step value (=t-1)
+        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep].reshape(-1, 1, 1)
+        alpha_prod_t_prev = (
+            self.alphas_cumprod[prev_timestep]
+            if prev_timestep[0] >= 0
+            else self.final_alpha_cumprod.repeat(batch_size)
+        )
+        alpha_prod_t_prev = alpha_prod_t_prev.reshape(-1, 1, 1)
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from (https://arxiv.org/pdf/2010.02502.pdf)
+        if self.prediction_type == "epsilon":
+            pred_original_sample = (
+                (sample - (beta_prod_t ** (0.5)).to(dtype) * model_output)
+                / alpha_prod_t ** (0.5)
+            ).to(dtype)
+            pred_epsilon = model_output
+        elif self.prediction_type == "sample":
+            pred_original_sample = model_output
+            pred_epsilon = (
+                (sample - (alpha_prod_t ** (0.5)).to(dtype) * pred_original_sample)
+                / beta_prod_t ** (0.5)
+            ).to(dtype)
+        elif self.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5).to(dtype) * sample - (
+                beta_prod_t**0.5
+            ).to(dtype) * model_output
+            pred_epsilon = (alpha_prod_t**0.5).to(dtype) * model_output + (
+                beta_prod_t**0.5
+            ).to(dtype) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )
+        # 4. Clip or threshold "predicted x_0"
+        if self.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.clip_sample:
+            pred_original_sample = pred_original_sample.clamp(
+                -self.clip_sample_range, self.clip_sample_range
+            )
+
+        # 5. 
compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + + if eta == 0: + std_dev_t = 0 + else: + std_dev_t = (eta * ops.sqrt(variance)).to(dtype).reshape(-1, 1, 1) + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = ( + (sample - (alpha_prod_t ** (0.5)).to(dtype) * pred_original_sample) + / beta_prod_t ** (0.5) + ).to(dtype) + + # 6. compute "direction pointing to x_t" of formula (12) from (https://arxiv.org/pdf/2010.02502.pdf) + pred_sample_direction = ((1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5)).to( + dtype + ) * pred_epsilon + # 7. compute x_t without "random noise" of formula (12) from (https://arxiv.org/pdf/2010.02502.pdf) + coef = (alpha_prod_t_prev ** (0.5)).to(dtype) + prev_sample = coef * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = ops.randn(model_output.shape, dtype=dtype) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + return prev_sample + + def add_noise(self, + original_samples: ms.Tensor, + noise: ms.Tensor, + timesteps: ms.Tensor, + ): + """add noise into sample""" + broadcast_shape = original_samples.shape + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + + sqrt_alpha_prod = ops.reshape( + sqrt_alpha_prod, + (timesteps.reshape((-1,)).shape[0],) + + (1,) * (len(broadcast_shape) - 1), + ) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + + sqrt_one_minus_alpha_prod = ops.reshape( + sqrt_one_minus_alpha_prod, + (timesteps.reshape((-1,)).shape[0],) + + (1,) * (len(broadcast_shape) - 1), + ) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples diff --git a/tests/st/mindflow/cell/diffusion/ddpm_gt.py b/tests/st/mindflow/cell/diffusion/ddpm_gt.py new file mode 100644 index 000000000..dc687a0bb --- /dev/null +++ b/tests/st/mindflow/cell/diffusion/ddpm_gt.py @@ -0,0 +1,473 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""mo-ddpm api"""
+from typing import List, Optional, Union
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import ops
+import mindspore.common.dtype as mstype
+
+from utils import rescale_zero_terminal_snr, betas_for_alpha_bar
+
+
+class DDPMScheduler:
+    """
+    `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, *optional*):
+            An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
+        variance_type (`str`, defaults to `"fixed_small"`):
+            Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
+            `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        clip_sample (`bool`, defaults to `True`):
+            Clip the predicted sample for numerical stability.
+        clip_sample_range (`float`, defaults to 1.0):
+            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        prediction_type (`str`, defaults to `epsilon`, *optional*):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+            as Stable Diffusion.
+        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+        sample_max_value (`float`, defaults to 1.0):
+            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        steps_offset (`int`, defaults to 0):
+            An offset added to the inference steps, as required by some model families.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+    """
+
+    def __init__(self,
+                 num_train_timesteps: int = 1000,
+                 beta_start: float = 0.0001,
+                 beta_end: float = 0.02,
+                 beta_schedule: str = "squaredcos_cap_v2",
+                 trained_betas: Optional[Union[np.ndarray,
+                                               List[float]]] = None,
+                 variance_type: str = "fixed_small_log",
+                 clip_sample: bool = True,
+                 prediction_type: str = "epsilon",
+                 thresholding: bool = False,
+                 dynamic_thresholding_ratio: float = 0.995,
+                 clip_sample_range: float = 1.0,
+                 sample_max_value: float = 1.0,
+                 timestep_spacing: str = "leading",
+                 steps_offset: int = 0,
+                 rescale_betas_zero_snr: bool = False,
+                 ):
+        if trained_betas is not None:
+            self.betas = ms.tensor(trained_betas, dtype=ms.float32)
+        elif beta_schedule == "linear":
+            self.betas = ms.tensor(
+                np.linspace(beta_start, beta_end, num_train_timesteps), dtype=ms.float32
+            )
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                ms.tensor(
+                    np.linspace(beta_start**0.5, beta_end **
+                                0.5, num_train_timesteps),
+                    dtype=ms.float32,
+                )
+                ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        elif beta_schedule == "sigmoid":
+            # GeoDiff sigmoid schedule
+            betas = ms.tensor(
+                np.linspace(-6, 6, num_train_timesteps), dtype=ms.float32)
+            self.betas = ops.sigmoid(
+                betas) * (beta_end - beta_start) + beta_start
+        else:
+            raise NotImplementedError(
+                f"{beta_schedule} is not implemented for {self.__class__}"
+            )
+
+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = ops.cumprod(self.alphas, dim=0)
+        self.one = ms.Tensor(1.0)
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # setable values
+        self.custom_timesteps = False
+        self.num_inference_steps = None
+        self.timesteps = ms.Tensor(
+            np.arange(0, num_train_timesteps)[::-1].copy())
+        self.num_train_timesteps = num_train_timesteps
+        self.timestep_spacing = timestep_spacing
+        self.variance_type = variance_type
+        self.prediction_type = prediction_type
+        self.clip_sample = clip_sample
+        self.clip_sample_range = clip_sample_range
+        self.sample_max_value = sample_max_value
+        self.thresholding = thresholding
+        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
+        self.steps_offset = steps_offset
+        self.rescale_betas_zero_snr = rescale_betas_zero_snr
+
+    def set_timesteps(self,
+                      num_inference_steps: Optional[int] = None,
+                      timesteps: Optional[List[int]] = None,
+                      ):
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+        Args:
+            num_inference_steps (`int`):
+                The number of diffusion steps used when generating samples with a pre-trained model. If used,
+                `timesteps` must be `None`.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
+                `num_inference_steps` must be `None`.
+
+        """
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError(
+                "Can only pass one of `num_inference_steps` or `custom_timesteps`."
+ ) + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError( + "`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.train_timesteps`:" + f" {self.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of (https://arxiv.org/abs/2305.08891) + if self.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.num_train_timesteps - + 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.timestep_spacing == "leading": + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio) + .round()[::-1] + .copy() + .astype(np.int64) + ) + timesteps += self.steps_offset + elif self.timestep_spacing == "trailing": + step_ratio = self.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round( + np.arange(self.num_train_timesteps, 0, -step_ratio) + ).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.timestep_spacing} is not supported. Please make sure to choose" + "one of 'linspace', 'leading' or 'trailing'." 
+ ) + + self.timesteps = ms.Tensor(timesteps) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + """get variance""" + prev_t = self.previous_timestep(t) + alpha_prod_t = self.alphas_cumprod[t] + prev_t = ops.where(prev_t >= 0, prev_t, self.one).astype(mstype.int32) + alpha_prod_t_prev = self.alphas_cumprod[prev_t] + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from (https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / \ + (1 - alpha_prod_t) * current_beta_t + + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = ops.clamp(variance, min=1e-20) + + if variance_type is None: + variance_type = self.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = variance + # for rl-diffuser (https://arxiv.org/abs/2205.09991) + elif variance_type == "fixed_small_log": + variance = ops.log(variance) + variance = ops.exp(0.5 * variance) + elif variance_type == "fixed_large": + variance = current_beta_t + elif variance_type == "fixed_large_log": + # Glide max_log + variance = ops.log(current_beta_t) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = ops.log(variance) + max_log = ops.log(current_beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def _threshold_sample(self, sample: ms.Tensor) -> ms.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + see (https://arxiv.org/abs/2205.11487) + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (ms.float32, ms.float64): + # upcast for quantile calculation, and clamp not implemented for cpu half + sample = sample.float() + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * + np.prod(remaining_dims).item()) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = ms.Tensor.from_numpy( + np.quantile(abs_sample.asnumpy(), + self.dynamic_thresholding_ratio, axis=1) + ) + s = ops.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + # (batch_size, 1) because clamp will broadcast along dim=0 + s = s.unsqueeze(1) + # "we threshold xt0 to the range [-s, s] and then divide by s" + sample = ops.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def step(self, + model_output: ms.Tensor, + timestep: int, + sample: ms.Tensor, + ): + """ + Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`ms.Tensor`):
+                The direct output from learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`ms.Tensor`):
+                A current instance of a sample created by the diffusion process.
+
+        Returns:
+            [`ms.Tensor`]:
+                The predicted previous sample.
+
+        """
+        t = timestep
+        dtype = sample.dtype
+
+        prev_t = self.previous_timestep(t)
+
+        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned",
+                                                                                   "learned_range",
+                                                                                   ]:
+            model_output, predicted_variance = ops.split(
+                model_output, sample.shape[1], axis=1)
+        else:
+            predicted_variance = None
+
+        # 1. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[t].reshape(-1, 1, 1)
+        prev_t = ops.where(prev_t >= 0, prev_t, self.one).astype(mstype.int32)
+
+        alpha_prod_t_prev = self.alphas_cumprod[prev_t].reshape(-1, 1, 1)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+        current_beta_t = 1 - current_alpha_t
+        # 2. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (15) from (https://arxiv.org/pdf/2006.11239.pdf)
+        if self.prediction_type == "epsilon":
+            pred_original_sample = (
+                (sample - (beta_prod_t ** (0.5) * model_output).to(dtype))
+                / alpha_prod_t ** (0.5)
+            ).to(dtype)
+        elif self.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5).to(dtype) * sample - (
+                beta_prod_t**0.5
+            ).to(dtype) * model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or"
+                " `v_prediction` for the DDPMScheduler."
+            )
+
+        # 3. Clip or threshold "predicted x_0"
+        if self.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.clip_sample:
+            pred_original_sample = pred_original_sample.clamp(
+                -self.clip_sample_range, self.clip_sample_range
+            )
+
+        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+        # See formula (7) from (https://arxiv.org/pdf/2006.11239.pdf)
+        pred_original_sample_coeff = (
+            alpha_prod_t_prev ** (0.5) * current_beta_t
+        ) / beta_prod_t
+        current_sample_coeff = current_alpha_t ** (
+            0.5) * beta_prod_t_prev / beta_prod_t
+
+        # 5. Compute predicted previous sample µ_t
+        # See formula (7) from (https://arxiv.org/pdf/2006.11239.pdf)
+        pred_prev_sample = (
+            pred_original_sample_coeff.to(dtype) * pred_original_sample
+            + current_sample_coeff.to(dtype) * sample
+        )
+
+        # 6. Add noise
+        variance = 0
+        if t[0] > 0:
+            v = self._get_variance(t, predicted_variance=predicted_variance).reshape(-1, 1, 1)
+            # set pseudo noise for testcase
+            # variance_noise = ops.randn(
+            #     model_output.shape, dtype=model_output.dtype)
+            variance_noise = model_output
+            if self.variance_type == "fixed_small_log":
+                variance = v.to(dtype) * variance_noise
+            elif self.variance_type == "learned_range":
+                variance = ops.exp(0.5 * v).to(dtype) * variance_noise
+            else:
+                variance = ops.sqrt(v).to(dtype) * variance_noise
+
+        pred_prev_sample = pred_prev_sample + variance
+
+        return pred_prev_sample
+
+    def add_noise(self,
+                  original_samples: ms.Tensor,
+                  noise: ms.Tensor,
+                  timesteps: ms.Tensor,  # ms.int32
+                  ) -> ms.Tensor:
+        """add noise to sample"""
+        broadcast_shape = original_samples.shape
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        sqrt_alpha_prod = ops.reshape(
+            sqrt_alpha_prod, (timesteps.shape[0],)
+            + (1,) * (len(broadcast_shape) - 1)
+        )
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        sqrt_one_minus_alpha_prod = ops.reshape(
+            sqrt_one_minus_alpha_prod,
+            (timesteps.shape[0],) + (1,) * (len(broadcast_shape) - 1),
+        )
+
+        noisy_samples = (
+            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        )
+        return noisy_samples
+
+    def previous_timestep(self, timestep):
+        """get previous timestep"""
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero()[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = ms.Tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps
+                if self.num_inference_steps
+                else self.num_train_timesteps
+            )
+            prev_t = timestep - self.num_train_timesteps // num_inference_steps
+
+        return prev_t
diff --git a/tests/st/mindflow/cell/diffusion/test_diffusion.py b/tests/st/mindflow/cell/diffusion/test_diffusion.py
new file mode 100644
index 000000000..45c82b65f
--- /dev/null
+++ b/tests/st/mindflow/cell/diffusion/test_diffusion.py
@@ -0,0 +1,532 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+# ============================================================================ +"""diffusion api testcase""" +# pylint: disable=C0413 +import os +import sys +import pytest + +import numpy as np +from mindspore import Tensor, ops +from mindspore import dtype as mstype + +from mindflow.cell import DiffusionScheduler, DDPMPipeline, DDIMPipeline, DDPMScheduler, DDIMScheduler, \ + DiffusionTransformer, ConditionDiffusionTransformer + +PROJECT_ROOT = os.path.abspath(os.path.join( + os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) + +from common.cell import FP32_RTOL, FP32_ATOL +from common.cell import compare_output + +from ddim_gt import DDIMScheduler as DDIMSchedulerGt +from ddpm_gt import DDPMScheduler as DDPMSchedulerGt +from dataset import load_data + +BATCH_SIZE, SEQ_LEN, IN_CHANNELS, OUT_CHANNELS, HIDDEN_CHANNELS, COND_CHANNELS = 8, 256, 16, 16, 64, 4 +LAYERS, HEADS, TRAIN_TIMESTEPS = 3, 4, 100 + + +def extract(a, t, x_shape): + """calculate a[timestep]""" + b = t.shape[0] + out = Tensor(a).gather(t, -1) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class DDPMSchedulerMock(DDPMScheduler): + """modify `DDPMScheduler` step function""" + + def step(self, model_output: Tensor, sample: Tensor, timestep: Tensor, predicted_variance: Tensor = None): + """denoise function""" + pred_original_sample = self._pred_origin_sample( + model_output, sample, timestep) + + pred_prev_sample = ( + extract(self.posterior_mean_coef1, timestep, sample.shape)*pred_original_sample + + extract(self.posterior_mean_coef2, timestep, sample.shape)*sample + ) + + # 3. Add noise + v = self._get_variance(sample, timestep, predicted_variance) + variance = 0 + if timestep[0] > 0: + variance_noise = model_output + if self.variance_type == "fixed_small_log": + variance = v * variance_noise + elif self.variance_type == "learned_range": + variance = ops.exp(0.5 * v) * variance_noise + else: + variance = (v ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample + + +def generate_inputs(): + """generate npy files""" + original_samples = np.random.rand(BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + noise = np.random.rand(BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + condition = np.random.rand(BATCH_SIZE, COND_CHANNELS) + timesteps = np.random.randint(30, 100, (BATCH_SIZE,)) + np.save('noise.npy', noise) + np.save('original_samples.npy', original_samples) + np.save('condition.npy', condition) + np.save('timesteps.npy', timesteps) + + +def load_inputs(dtype=mstype.float32): + """load npy""" + sample = load_data('original_samples.npy') + noise = load_data('noise.npy') + cond = load_data('condition.npy') + t = load_data('timesteps.npy') + + original_samples = Tensor(sample).astype(dtype) + noise = Tensor(noise).astype(dtype) + condition = Tensor(cond).astype(dtype) + timesteps = Tensor(t).astype(mstype.int32) + return original_samples, noise, condition, timesteps + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddim_step(): + """ + Feature: DDIM step + Description: test DDIM step function + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDIMScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=None, + thresholding=False, + prediction_type="epsilon", + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + scheduler_gt = 
DDIMSchedulerGt(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + trained_betas=None, + thresholding=False, + clip_sample=False, + sample_max_value=1.0, + prediction_type="epsilon", + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading") + scheduler.set_timesteps(20) + scheduler_gt.set_timesteps(20) + original_samples, noise, _, _ = load_inputs() + timesteps = Tensor(np.array([60]*BATCH_SIZE), dtype=mstype.int32) + x_prev = scheduler.step(noise, original_samples, timesteps).numpy() + x_prev_gt = scheduler_gt.step(noise, timesteps, original_samples).numpy() + validate_ans = compare_output(x_prev, x_prev_gt, rtol=FP32_RTOL, atol=FP32_ATOL) + assert validate_ans, f"test_ddim_step failed: {validate_ans}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddpm_step(): + """ + Feature: DDPM step + Description: test DDPM step function + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDPMSchedulerMock(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=None, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + scheduler_gt = DDPMSchedulerGt(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + trained_betas=None, + thresholding=False, + clip_sample=None, + sample_max_value=1.0, + prediction_type="epsilon", + variance_type="fixed_small_log", + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading") + original_samples, noise, _, _ = load_inputs() + scheduler.set_timesteps(TRAIN_TIMESTEPS) + scheduler_gt.set_timesteps(TRAIN_TIMESTEPS) + timesteps = Tensor(np.array([60]*BATCH_SIZE), dtype=mstype.int32) + x_prev = scheduler.step(noise, original_samples, timesteps).numpy() + x_prev_gt = scheduler_gt.step(noise, timesteps, original_samples).numpy() + validate_ans = compare_output(x_prev, x_prev_gt, rtol=FP32_RTOL, atol=FP32_ATOL) + assert validate_ans, f'test_ddpm_step failed: {validate_ans}' + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_set_timestep(): + """ + Feature: diffusion set inference timesteps + Description: test diffusion set timesteps function + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DiffusionScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + + try: + scheduler.set_timesteps(1000) + except ValueError: + pass + else: + raise Exception("inference steps > train steps. 
Expected ValueError") + + +# pylint: disable=W0212 +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_diffusion_pred_origin_sample(): + """ + Feature: diffusion add noise + Description: test diffusion pred origin sample + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DiffusionScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=None, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + + original_samples, noise, _, _ = load_inputs() + scheduler.set_timesteps(TRAIN_TIMESTEPS) + timesteps = Tensor(np.array([60]*BATCH_SIZE), dtype=mstype.int32) + pred_original_sample = scheduler._pred_origin_sample( + noise, original_samples, timesteps).numpy() + pred_original_sample_gt = load_data('ddpm_pred_original_sample.npy') + validate_ans = compare_output(pred_original_sample, pred_original_sample_gt, rtol=FP32_RTOL, atol=FP32_ATOL) + assert validate_ans, f'test_diffusion_pred_origin_sample failed: {validate_ans}' + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddpm_set_timestep(): + """ + Feature: DDPM set timesteps + Description: test DDPM set timesteps function + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDPMScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + + try: + scheduler.set_timesteps(50) + except ValueError as e: + print(e) + else: + raise Exception( + "DDPM num_inference_steps defaults to num_train_timesteps. 
Expected ValueError") + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddpm_pipe_no_cond(): + """ + Feature: DDPM pipeline generation + Description: test DDPM pipeline API with no condition input + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDPMScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + net = DiffusionTransformer(in_channels=IN_CHANNELS, + out_channels=IN_CHANNELS, + hidden_channels=HIDDEN_CHANNELS, + layers=LAYERS, + heads=HEADS, + time_token_cond=True, + compute_dtype=compute_dtype) + pipe = DDPMPipeline(model=net, scheduler=scheduler, + batch_size=BATCH_SIZE, seq_len=SEQ_LEN, num_inference_steps=TRAIN_TIMESTEPS) + + generated_samples = pipe() + assert generated_samples.shape == (BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + assert generated_samples.dtype == compute_dtype + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddpm_pipe_cond(): + """ + Feature: DDPM pipeline generation + Description: test DDPM pipeline API with condition input + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDPMScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + _, _, cond, _ = load_inputs() + net = ConditionDiffusionTransformer(in_channels=IN_CHANNELS, + out_channels=IN_CHANNELS, + cond_channels=COND_CHANNELS, + hidden_channels=HIDDEN_CHANNELS, + layers=LAYERS, + heads=HEADS, + time_token_cond=True, + compute_dtype=compute_dtype) + pipe = DDPMPipeline(model=net, scheduler=scheduler, + batch_size=BATCH_SIZE, seq_len=SEQ_LEN, num_inference_steps=TRAIN_TIMESTEPS) + + generated_samples = pipe(cond) + assert generated_samples.shape == (BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + assert generated_samples.dtype == compute_dtype + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddim_pipe_no_cond(): + """ + Feature: DDIM pipeline generation + Description: test DDIM pipeline API with no condition input + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDIMScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + net = DiffusionTransformer(in_channels=IN_CHANNELS, + out_channels=IN_CHANNELS, + hidden_channels=HIDDEN_CHANNELS, + layers=LAYERS, + heads=HEADS, + time_token_cond=True, + compute_dtype=compute_dtype) + pipe = DDIMPipeline(model=net, scheduler=scheduler, + batch_size=BATCH_SIZE, seq_len=SEQ_LEN, num_inference_steps=50) + + generated_samples = pipe() + assert generated_samples.shape == (BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + assert generated_samples.dtype == compute_dtype + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_ddim_pipe_cond(): + """ + Feature: DDIM pipeline generation + 
Description: test DDIM pipeline API with condition input + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DDIMScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + _, _, cond, _ = load_inputs() + net = ConditionDiffusionTransformer(in_channels=IN_CHANNELS, + out_channels=IN_CHANNELS, + cond_channels=COND_CHANNELS, + hidden_channels=HIDDEN_CHANNELS, + layers=LAYERS, + heads=HEADS, + time_token_cond=True, + compute_dtype=compute_dtype) + pipe = DDIMPipeline(model=net, scheduler=scheduler, + batch_size=BATCH_SIZE, seq_len=SEQ_LEN, num_inference_steps=50) + + generated_samples = pipe(cond) + assert generated_samples.shape == (BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + assert generated_samples.dtype == compute_dtype + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_diffusion_addnoise_fp32(): + """ + Feature: diffusion add noise + Description: test diffusion-fp32 add noise + Expectation: success + """ + compute_dtype = mstype.float32 + scheduler = DiffusionScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + prediction_type='epsilon', + clip_sample_range=1.0, + thresholding=False, + sample_max_value=1., + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + + original_samples, noise, _, timesteps = load_inputs() + noised_sample = scheduler.add_noise(original_samples, noise, timesteps) + noised_sample_gt = load_data('ddpm_noised_sample.npy') + assert noised_sample.dtype == compute_dtype + assert noised_sample.shape == (BATCH_SIZE, SEQ_LEN, OUT_CHANNELS) + assert np.allclose(noised_sample.numpy(), noised_sample_gt) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_diffusion_addnoise_fp16(): + """ + Feature: diffusion add noise + Description: test diffusion-fp16 add noise + Expectation: success + """ + compute_dtype = mstype.float16 + scheduler = DiffusionScheduler(num_train_timesteps=TRAIN_TIMESTEPS, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + + original_samples, noise, _, timesteps = load_inputs(mstype.float16) + noised_sample = scheduler.add_noise(original_samples, noise, timesteps) + assert noised_sample.dtype == compute_dtype + assert noised_sample.shape == (BATCH_SIZE, SEQ_LEN, OUT_CHANNELS) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('compute_dtype', [mstype.float16, mstype.float32]) +def test_diffusion_transfomer(compute_dtype): + """ + Feature: Diffusion transformer + Description: test diffusion transformer dtype and shape + Expectation: success + """ + net = DiffusionTransformer(in_channels=IN_CHANNELS, + out_channels=OUT_CHANNELS, + hidden_channels=HIDDEN_CHANNELS, + layers=LAYERS, + heads=HEADS, + time_token_cond=True, + compute_dtype=compute_dtype) + x, _, _, timesteps = load_inputs() + out = net(x, timesteps) + assert out.shape == (BATCH_SIZE, SEQ_LEN, OUT_CHANNELS) + assert out.dtype 
== compute_dtype + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('compute_dtype', [mstype.float16, mstype.float32]) +def test_cond_diffusion_transfomer(compute_dtype): + """ + Feature: Condition Diffusion transformer + Description: test diffusion transformer dtype and shape + Expectation: success + """ + net = ConditionDiffusionTransformer(in_channels=IN_CHANNELS, + out_channels=OUT_CHANNELS, + cond_channels=COND_CHANNELS, + hidden_channels=HIDDEN_CHANNELS, + layers=LAYERS, + heads=HEADS, + cond_as_token=False, + time_token_cond=True, + compute_dtype=compute_dtype) + x, _, condition, timesteps = load_inputs() + out = net(x, timesteps, condition) + assert out.shape == (BATCH_SIZE, SEQ_LEN, OUT_CHANNELS) + assert out.dtype == compute_dtype diff --git a/tests/st/mindflow/cell/diffusion/test_diffusion_train.py b/tests/st/mindflow/cell/diffusion/test_diffusion_train.py new file mode 100644 index 000000000..97c4e4580 --- /dev/null +++ b/tests/st/mindflow/cell/diffusion/test_diffusion_train.py @@ -0,0 +1,227 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""diffusion ST testcase""" +import time +import pytest + +import numpy as np +from mindspore import Tensor, ops, amp, nn, jit +from mindspore import dtype as mstype + +from mindflow.cell import DiffusionTransformer, DiffusionTrainer, DDPMScheduler, DDIMScheduler, DDPMPipeline, \ + DDIMPipeline, ConditionDiffusionTransformer +from dataset import get_latent_dataset +from ae import LATENT_DIM, generate_image + + +def count_params(model): + """count_params""" + count = 0 + for p in model.trainable_params(): + t = 1 + for i in range(len(p.shape)): + t = t * p.shape[i] + count += t + print(model) + print(f'model params: {count}') + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_train_diffusion(diffusion_type='ddpm', latent_path='latent.npy', cond_path='labels.npy'): + """ + Feature: diffusion model train + Description: test DDPM/DDIM model generate MNIST image latent + Expectation: success + """ + compute_dtype = mstype.float32 + in_dim = LATENT_DIM + + model = DiffusionTransformer(in_channels=in_dim, + out_channels=in_dim, + hidden_channels=256, + layers=1, + heads=8, + time_token_cond=True, + compute_dtype=compute_dtype) + count_params(model) + num_train_timesteps = 500 + batch_size = 512 + infer_bs = 10 + if diffusion_type == 'ddpm': + scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + pipe = DDPMPipeline(model=model, scheduler=scheduler, batch_size=infer_bs, + seq_len=1, num_inference_steps=num_train_timesteps) + else: + 
scheduler = DDIMScheduler(num_train_timesteps=num_train_timesteps, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + num_inference_steps = 50 + pipe = DDIMPipeline(model=model, scheduler=scheduler, batch_size=infer_bs, + seq_len=1, num_inference_steps=num_inference_steps) + + trainer = DiffusionTrainer(model, + scheduler, + objective='pred_noise', + p2_loss_weight_gamma=0, + p2_loss_weight_k=1, + loss_type='l2') + + def forward_fn(data, t, noise): + loss = trainer.get_loss(data, noise, t) + return loss + lr = 0.0002 + optimizer = nn.Adam(model.trainable_params(), lr, 0.9, 0.99) + grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters) + + @jit + def train_step(data, t, noise): + loss, grads = grad_fn(data, t, noise) + is_finite = amp.all_finite(grads) + if is_finite: + loss = ops.depend(loss, optimizer(grads)) + return loss + + train_dataset = get_latent_dataset( + latent_path=latent_path, cond_path=cond_path, batch_size=batch_size) + epochs = 20 + print(f'scheduler.num_timesteps: {scheduler.num_timesteps}') + + for epoch in range(1, epochs+1): + time_beg = time.time() + for samples, _ in train_dataset: + timesteps = Tensor(np.random.randint( + 0, num_train_timesteps, (samples.shape[0],)).astype(np.int32)) + noise = Tensor(np.random.randn(*samples.shape), mstype.float32) + step_train_loss = train_step(samples, timesteps, noise) + print( + f"epoch: {epoch} train loss: {step_train_loss.asnumpy()} epoch time: {time.time() - time_beg:5.3f}s") + if epoch % 10 == 0: + result = pipe() + generate_image(result, diffusion_type) + print(f'save generated images') + assert step_train_loss < 0.4 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_train_diffusion_cond(diffusion_type='ddpm', latent_path='latent.npy', cond_path='labels.npy'): + """ + Feature: conditional diffusion model train + Description: test DDPM/DDIM model generate MNIST image latent + Expectation: success + """ + compute_dtype = mstype.float32 + in_dim = LATENT_DIM + model = ConditionDiffusionTransformer(in_channels=in_dim, + out_channels=in_dim, + cond_channels=1, + hidden_channels=256, + layers=1, + heads=8, + time_token_cond=True, + compute_dtype=compute_dtype) + + count_params(model) + num_train_timesteps = 500 + batch_size = 512 + infer_bs = 10 + if diffusion_type == 'ddpm': + scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + pipe = DDPMPipeline(model=model, scheduler=scheduler, batch_size=infer_bs, + seq_len=1, num_inference_steps=num_train_timesteps) + else: + scheduler = DDIMScheduler(num_train_timesteps=num_train_timesteps, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="squaredcos_cap_v2", + clip_sample_range=1.0, + thresholding=False, + dynamic_thresholding_ratio=None, + rescale_betas_zero_snr=False, + timestep_spacing="leading", + compute_dtype=compute_dtype) + num_inference_steps = 50 + pipe = DDIMPipeline(model=model, scheduler=scheduler, batch_size=infer_bs, + seq_len=1, num_inference_steps=num_inference_steps) + + trainer = DiffusionTrainer(model, + scheduler, + 
objective='pred_noise', + p2_loss_weight_gamma=0, + p2_loss_weight_k=1, + loss_type='l2') + + def forward_fn(data, t, noise, cond): + loss = trainer.get_loss(data, noise, t, cond) + return loss + lr = 0.0002 + optimizer = nn.Adam(model.trainable_params(), lr, 0.9, 0.99) + grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters) + + @jit + def train_step(data, t, noise, cond): + loss, grads = grad_fn(data, t, noise, cond) + is_finite = amp.all_finite(grads) + if is_finite: + loss = ops.depend(loss, optimizer(grads)) + return loss + + train_dataset = get_latent_dataset( + latent_path=latent_path, cond_path=cond_path, batch_size=batch_size) + epochs = 20 + print(f'scheduler.num_timesteps: {scheduler.num_timesteps}') + infer_cond = ops.arange(0, 10, dtype=mstype.float32).reshape(10, 1) + for epoch in range(1, epochs+1): + time_beg = time.time() + for samples, cond in train_dataset: + timesteps = Tensor(np.random.randint( + 0, num_train_timesteps, (samples.shape[0],)).astype(np.int32)) + noise = Tensor(np.random.randn(*samples.shape), mstype.float32) + step_train_loss = train_step(samples, timesteps, noise, cond) + print( + f"epoch: {epoch} train loss: {step_train_loss.asnumpy()} epoch time: {time.time() - time_beg:5.3f}s") + if epoch % 10 == 0: + result = pipe(infer_cond) + generate_image(result, diffusion_type) + print(f'save generated conditional images') + assert step_train_loss < 0.4 diff --git a/tests/st/mindflow/cell/diffusion/utils.py b/tests/st/mindflow/cell/diffusion/utils.py new file mode 100644 index 000000000..723d4999a --- /dev/null +++ b/tests/st/mindflow/cell/diffusion/utils.py @@ -0,0 +1,98 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""function""" +import math + +import mindspore as ms +from mindspore import ops + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+            Choose from `cosine` or `exp`
+
+    Returns:
+        betas (`ms.Tensor`): the betas used by the scheduler to step the model outputs
+    """
+    if alpha_transform_type == "cosine":
+
+        def alpha_bar_fn(t):
+            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    elif alpha_transform_type == "exp":
+
+        def alpha_bar_fn(t):
+            return math.exp(t * -12.0)
+
+    else:
+        raise ValueError(
+            f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+    return ms.tensor(betas, dtype=ms.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on (https://arxiv.org/pdf/2305.08891.pdf) (Algorithm 1)
+
+
+    Args:
+        betas (`ms.Tensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `ms.Tensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = ops.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
+    alphas_bar_sqrt_t = alphas_bar_sqrt[-1].copy()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_t
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / \
+        (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = ops.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
--
Gitee
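
Note on timestep spacing (illustrative, not part of the diff): the set_timesteps implementations in this patch share three timestep_spacing strategies. The NumPy sketch below mirrors that logic for num_train_timesteps=1000 and num_inference_steps=10, so the difference between the strategies is concrete:

import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 0

# "linspace": spread evenly over the whole training range, then reverse
linspace = np.linspace(0, num_train_timesteps - 1,
                       num_inference_steps).round()[::-1].astype(np.int64)

# "leading": multiples of the integer step ratio, reversed, plus steps_offset
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) *
           step_ratio)[::-1].astype(np.int64) + steps_offset

# "trailing": count down from num_train_timesteps with a float ratio, minus one
trailing = np.round(np.arange(num_train_timesteps, 0,
                              -num_train_timesteps / num_inference_steps)).astype(np.int64) - 1

print(linspace)  # [999 888 777 666 555 444 333 222 111   0]
print(leading)   # [900 800 700 600 500 400 300 200 100   0]
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]

With "leading", the reverse chain starts at t=900 rather than t=999, which is the flaw that Table 2 of the referenced paper (https://arxiv.org/abs/2305.08891) attributes to that spacing; "linspace" and "trailing" both start at the highest training timestep.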
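
End-to-end usage sketch (illustrative, not part of the diff): the snippet below is assembled from the constructor arguments and call patterns used in test_diffusion.py and test_diffusion_train.py above; the shapes and hyperparameter values are the ones the tests happen to use, not recommendations.

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindflow.cell import (DDPMScheduler, DDPMPipeline,
                           DiffusionTransformer, DiffusionTrainer)

scheduler = DDPMScheduler(num_train_timesteps=100,
                          beta_start=0.0001,
                          beta_end=0.02,
                          beta_schedule="squaredcos_cap_v2",
                          clip_sample_range=1.0,
                          thresholding=False,
                          dynamic_thresholding_ratio=None,
                          rescale_betas_zero_snr=False,
                          timestep_spacing="leading",
                          compute_dtype=mstype.float32)
model = DiffusionTransformer(in_channels=16, out_channels=16,
                             hidden_channels=64, layers=3, heads=4,
                             time_token_cond=True, compute_dtype=mstype.float32)

# training: perturb clean samples with noise at random timesteps and
# regress the model against the injected noise
trainer = DiffusionTrainer(model, scheduler, objective='pred_noise',
                           p2_loss_weight_gamma=0, p2_loss_weight_k=1,
                           loss_type='l2')
samples = Tensor(np.random.rand(8, 256, 16), mstype.float32)
noise = Tensor(np.random.randn(8, 256, 16), mstype.float32)
timesteps = Tensor(np.random.randint(0, 100, (8,)), mstype.int32)
loss = trainer.get_loss(samples, noise, timesteps)

# inference: iterate the reverse process from pure noise
pipe = DDPMPipeline(model=model, scheduler=scheduler, batch_size=8,
                    seq_len=256, num_inference_steps=100)
generated = pipe()  # shape (8, 256, 16)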