
chhw/step-video-ti2v

run_parallel.py 1.28 KB
xionghuixin committed on 2025-03-15 11:14 +08:00: add output_file_name param
```python
from stepvideo.diffusion.video_pipeline import StepVideoPipeline
import torch.distributed as dist
import torch
from stepvideo.config import parse_args
from stepvideo.parallel import initialize_parall_group, get_parallel_group
from stepvideo.utils import setup_seed


if __name__ == "__main__":
    args = parse_args()

    # Set up the sequence-parallel process groups (Ring and Ulysses degrees)
    # and bind this process to the GPU that matches its local rank.
    initialize_parall_group(ring_degree=args.ring_degree, ulysses_degree=args.ulysses_degree)
    local_rank = get_parallel_group().local_rank
    device = torch.device(f"cuda:{local_rank}")

    setup_seed(args.seed)

    # Load the whole pipeline in bfloat16 on the CPU, then move only the
    # transformer to this rank's GPU.
    pipeline = StepVideoPipeline.from_pretrained(args.model_dir).to(dtype=torch.bfloat16, device="cpu")
    pipeline.transformer = pipeline.transformer.to(device)
    pipeline.setup_pipeline(args)

    prompt = args.prompt
    # Generate a video from the text prompt and the first-frame image.
    videos = pipeline(
        prompt=prompt,
        first_image=args.first_image_path,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        num_inference_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        time_shift=args.time_shift,
        pos_magic=args.pos_magic,
        neg_magic=args.neg_magic,
        # Fall back to the first 50 characters of the prompt if no name is given.
        output_file_name=args.output_file_name or prompt[:50],
        motion_score=args.motion_score,
    )

    dist.destroy_process_group()
```
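The script initializes Ring and Ulysses sequence-parallel groups and then picks its GPU from the local rank. The sketch below is a hypothetical pre-flight check (`check_parallel_layout` and `device_for` are not part of `stepvideo`). It assumes, as in hybrid Ring/Ulysses sequence parallelism with no other parallel dimension (data or tensor parallelism) configured, that `ring_degree * ulysses_degree` must equal the number of launched processes.

```python
def check_parallel_layout(world_size: int, ring_degree: int, ulysses_degree: int) -> int:
    """Return ring_degree * ulysses_degree after checking it covers every process.

    Assumption: no other parallel dimension is configured, so the
    sequence-parallel degree must equal the number of launched processes.
    """
    if ring_degree < 1 or ulysses_degree < 1:
        raise ValueError("ring_degree and ulysses_degree must be >= 1")
    sp_degree = ring_degree * ulysses_degree
    if sp_degree != world_size:
        raise ValueError(
            f"ring_degree ({ring_degree}) x ulysses_degree ({ulysses_degree}) = "
            f"{sp_degree}, but {world_size} processes were launched"
        )
    return sp_degree


def device_for(local_rank: int) -> str:
    # Same convention as run_parallel.py: one CUDA device per local rank.
    return f"cuda:{local_rank}"


if __name__ == "__main__":
    # Example: 8 processes on one node, 2-way Ring x 4-way Ulysses.
    print(check_parallel_layout(8, ring_degree=2, ulysses_degree=4))
    print([device_for(r) for r in range(8)])
```

Under `torchrun`, `world_size` could be read from the `WORLD_SIZE` environment variable, which the launcher sets for every process.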
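The `output_file_name` fallback introduced by the commit noted above can be exercised on its own. In this sketch, `resolve_output_name` mirrors the expression in the script, while `safe_stem` is a hypothetical extra sanitizing step that `run_parallel.py` does not perform; how the pipeline turns the name into an actual file path is not shown in this file.

```python
import re
from typing import Optional


def resolve_output_name(output_file_name: Optional[str], prompt: str, max_len: int = 50) -> str:
    # Mirrors `args.output_file_name or prompt[:50]`: None or "" falls back
    # to the first max_len characters of the prompt.
    return output_file_name or prompt[:max_len]


def safe_stem(name: str) -> str:
    # Hypothetical extra step, not in run_parallel.py: collapse each run of
    # characters other than word characters, "." and "-" into "_".
    stem = re.sub(r"[^\w.-]+", "_", name).strip("_")
    return stem or "output"


if __name__ == "__main__":
    prompt = "A red fox trots through fresh snow at dawn, camera tracking slowly from the left"
    name = resolve_output_name(None, prompt)
    print(name)
    print(safe_stem(name))
```

Because the fallback truncates the raw prompt, a name derived from it can contain spaces or path separators; that is the motivation for the optional `safe_stem` step.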
https://gitee.com/chhw/step-video-ti2v.git
git@gitee.com:chhw/step-video-ti2v.git
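This page does not show how `run_parallel.py` is launched. Since it relies on `torch.distributed` and one GPU per local rank, starting it with a launcher such as `torchrun` (one process per GPU) is a reasonable assumption. The hypothetical helper below builds such a command. The script flags are assumed to match the `args.*` attribute names used in the code, the `/path/to/...` values are placeholders, and `parse_args` may require further arguments that are not visible here.

```python
import shlex
from typing import List


def build_launch_command(model_dir: str, prompt: str, first_image_path: str,
                         ring_degree: int = 1, ulysses_degree: int = 1,
                         script: str = "run_parallel.py") -> List[str]:
    # Assumption: one process per GPU, with ring x ulysses covering all of them.
    nproc = ring_degree * ulysses_degree
    if nproc < 1:
        raise ValueError("ring_degree and ulysses_degree must be >= 1")
    return [
        "torchrun", f"--nproc_per_node={nproc}", script,
        # Flag names below are assumed from the args.* attributes in the script.
        "--model_dir", model_dir,
        "--ring_degree", str(ring_degree),
        "--ulysses_degree", str(ulysses_degree),
        "--prompt", prompt,
        "--first_image_path", first_image_path,
    ]


if __name__ == "__main__":
    cmd = build_launch_command("/path/to/model_dir", "a cat walking on grass",
                               "/path/to/first_frame.png", ulysses_degree=4)
    # Print a shell-quoted command line that could be pasted into a terminal.
    print(shlex.join(cmd))
```

Keeping the command as a list lets it be passed directly to `subprocess.run` without shell quoting issues; `shlex.join` is only used for display.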