代码拉取完成,页面将自动刷新
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input, Path
import torch
import numpy as np
import random
import os
import shutil
import subprocess
import time
os.environ["HF_HUB_CACHE"] = "models"
os.environ["HF_HUB_CACHE_OFFLINE"] = "true"
from diffusers.utils import load_image
from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from huggingface_hub import hf_hub_download
from transformers import CLIPImageProcessor
from photomaker import PhotoMakerStableDiffusionXLPipeline
from gradio_demo.style_template import styles
MAX_SEED = np.iinfo(np.int32).max
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Photographic (Default)"
FEATURE_EXTRACTOR = "./feature-extractor"
SAFETY_CACHE = "./models/safety-cache"
SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
BASE_MODEL_URL = "https://weights.replicate.delivery/default/SG161222--RealVisXL_V3.0-11ee564ebf4bd96d90ed5d473cb8e7f2e6450bcf.tar"
BASE_MODEL_PATH = "models/SG161222/RealVisXL_V3.0"
PHOTOMAKER_URL = "https://weights.replicate.delivery/default/TencentARC--PhotoMaker/photomaker-v1.bin"
PHOTOMAKER_PATH = "models/photomaker-v1.bin"
def download_weights(url, dest, extract=True):
start = time.time()
print("downloading url: ", url)
print("downloading to: ", dest)
args = ["pget"]
if extract:
args.append("-x")
subprocess.check_call(args + [url, dest], close_fds=False)
print("downloading took: ", time.time() - start)
# utility function for style templates
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
return p.replace("{prompt}", positive), n + " " + negative
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# download PhotoMaker checkpoint to cache
# if we already have the model, this doesn't do anything
if not os.path.exists(PHOTOMAKER_PATH):
download_weights(PHOTOMAKER_URL, PHOTOMAKER_PATH, extract=False)
if not os.path.exists(BASE_MODEL_PATH):
download_weights(BASE_MODEL_URL, BASE_MODEL_PATH)
print("Loading safety checker...")
if not os.path.exists(SAFETY_CACHE):
download_weights(SAFETY_URL, SAFETY_CACHE)
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_CACHE, torch_dtype=torch.float16
).to("cuda")
self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
self.pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
BASE_MODEL_PATH,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
).to(self.device)
self.pipe.load_photomaker_adapter(
os.path.dirname(PHOTOMAKER_PATH),
subfolder="",
weight_name=os.path.basename(PHOTOMAKER_PATH),
trigger_word="img",
)
self.pipe.id_encoder.to(self.device)
self.pipe.scheduler = EulerDiscreteScheduler.from_config(
self.pipe.scheduler.config
)
self.pipe.fuse_lora()
@torch.inference_mode()
def predict(
self,
input_image: Path = Input(
description="The input image, for example a photo of your face."
),
input_image2: Path = Input(
description="Additional input image (optional)",
default=None
),
input_image3: Path = Input(
description="Additional input image (optional)",
default=None
),
input_image4: Path = Input(
description="Additional input image (optional)",
default=None
),
prompt: str = Input(
description="Prompt. Example: 'a photo of a man/woman img'. The phrase 'img' is the trigger word.",
default="A photo of a person img",
),
style_name: str = Input(
description="Style template. The style template will add a style-specific prompt and negative prompt to the user's prompt.",
choices=STYLE_NAMES,
default=DEFAULT_STYLE_NAME,
),
negative_prompt: str = Input(
description="Negative Prompt. The negative prompt should NOT contain the trigger word.",
default="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
),
num_steps: int = Input(
description="Number of sample steps", default=20, ge=1, le=100
),
style_strength_ratio: float = Input(
description="Style strength (%)", default=20, ge=15, le=50
),
num_outputs: int = Input(
description="Number of output images", default=1, ge=1, le=4
),
guidance_scale: float = Input(
description="Guidance scale. A guidance scale of 1 corresponds to doing no classifier free guidance.", default=5, ge=1, le=10.0
),
seed: int = Input(description="Seed. Leave blank to use a random number", default=None, ge=0, le=MAX_SEED),
disable_safety_checker: bool = Input(
description="Disable safety checker for generated images.",
default=False
)
) -> list[Path]:
"""Run a single prediction on the model"""
# remove old outputs
output_folder = Path('outputs')
if output_folder.exists():
shutil.rmtree(output_folder)
os.makedirs(str(output_folder), exist_ok=False)
# randomize seed if necessary
if seed is None:
seed = random.randint(0, MAX_SEED)
print(f"Using seed {seed}...")
# check the prompt for the trigger word
image_token_id = self.pipe.tokenizer.convert_tokens_to_ids(self.pipe.trigger_word)
input_ids = self.pipe.tokenizer.encode(prompt)
if image_token_id not in input_ids:
raise ValueError(
f"Cannot find the trigger word '{self.pipe.trigger_word}' in text prompt!")
if input_ids.count(image_token_id) > 1:
raise ValueError(
f"Cannot use multiple trigger words '{self.pipe.trigger_word}' in text prompt!"
)
# check the negative prompt for the trigger word
if negative_prompt:
negative_prompt_ids = self.pipe.tokenizer.encode(negative_prompt)
if image_token_id in negative_prompt_ids:
raise ValueError(
f"Cannot use trigger word '{self.pipe.trigger_word}' in negative prompt!"
)
# apply the style template
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
# load the input images
input_id_images = []
for maybe_image in [input_image, input_image2, input_image3, input_image4]:
if maybe_image:
print(f"Loading image {maybe_image}...")
input_id_images.append(load_image(str(maybe_image)))
print(f"Setting seed...")
generator = torch.Generator(device=self.device).manual_seed(seed)
print("Start inference...")
print(f"[Debug] Prompt: {prompt}")
print(f"[Debug] Neg Prompt: {negative_prompt}")
start_merge_step = int(float(style_strength_ratio) / 100 * num_steps)
if start_merge_step > 30:
start_merge_step = 30
print(f"Start merge step: {start_merge_step}")
images = self.pipe(
prompt=prompt,
input_id_images=input_id_images,
negative_prompt=negative_prompt,
num_images_per_prompt=num_outputs,
num_inference_steps=num_steps,
start_merge_step=start_merge_step,
generator=generator,
guidance_scale=guidance_scale,
).images
if not disable_safety_checker:
print(f"Running safety checker...")
_, has_nsfw_content = self.run_safety_checker(images)
# save results to file
print(f"Saving images to file...")
output_paths = []
for i, image in enumerate(images):
if not disable_safety_checker:
if has_nsfw_content[i]:
print(f"NSFW content detected in image {i}")
continue
output_path = output_folder / f"image_{i}.png"
image.save(output_path)
output_paths.append(output_path)
return [Path(p) for p in output_paths]
def run_safety_checker(self, image):
safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
"cuda"
)
np_image = [np.array(val) for val in image]
image, has_nsfw_concept = self.safety_checker(
images=np_image,
clip_input=safety_checker_input.pixel_values.to(torch.float16),
)
return image, has_nsfw_concept
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。