from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import io
from PIL import Image
import torch
import os
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import snapshot_download
import numpy as np
import math

# Import custom modules (make sure these are in the same directory)
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.dwpose import DWposeDetector
from src.pose_guider import PoseGuider
from src.pipeline_stable_diffusion_3_tryon import StableDiffusion3TryOnPipeline
from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_Garm
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
from src.utils_mask import get_mask_location

app = FastAPI()

# Initialize models and configurations
weight_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download model weights
fitdit_repo = "BoyuanJiang/FitDiT"
repo_path = snapshot_download(repo_id=fitdit_repo)

# Load models
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(
    os.path.join(repo_path, "transformer_garm"),
    torch_dtype=weight_dtype
)
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(
    os.path.join(repo_path, "transformer_vton"),
    torch_dtype=weight_dtype
)
pose_guider = PoseGuider(
    conditioning_embedding_channels=1536,
    conditioning_channels=3,
    block_out_channels=(32, 64, 256, 512)
)
pose_guider.load_state_dict(torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin")))

# Load image encoders
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=weight_dtype
)
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    torch_dtype=weight_dtype
)

# Move models to device
pose_guider.to(device=device, dtype=weight_dtype)
image_encoder_large.to(device=device)
image_encoder_bigG.to(device=device)

# Initialize pipeline
pipeline = StableDiffusion3TryOnPipeline.from_pretrained(
    repo_path,
    torch_dtype=weight_dtype,
    transformer_garm=transformer_garm,
    transformer_vton=transformer_vton,
    pose_guider=pose_guider,
    image_encoder_large=image_encoder_large,
    image_encoder_bigG=image_encoder_bigG
)
pipeline.to(device)

# Initialize processors
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
parsing_model = Parsing(model_root=repo_path, device=device)

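# Note: everything above is assumed to fit on a single GPU in bfloat16. If VRAM is tight,
# diffusers pipelines generally offer pipeline.enable_model_cpu_offload() as an alternative
# to pipeline.to(device); check whether FitDiT's StableDiffusion3TryOnPipeline supports it
# before relying on that.
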
class TryOnRequest(BaseModel):
    model_image: str  # base64 encoded image
    garment_image: str  # base64 encoded image
    category: str
    resolution: str = "768x1024"
    n_steps: int = 20
    image_scale: float = 2.0
    num_images: int = 1
    seed: int = -1
    offset_top: int = 0
    offset_bottom: int = 0
    offset_left: int = 0
    offset_right: int = 0

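# Example request body (illustrative only; the image fields carry plain base64 strings, and the
# category value must be one accepted by get_mask_location, e.g. "Upper-body" in the FitDiT demo):
# {
#     "model_image": "<base64 of the person photo>",
#     "garment_image": "<base64 of the garment photo>",
#     "category": "Upper-body",
#     "resolution": "768x1024",
#     "n_steps": 20,
#     "image_scale": 2.0,
#     "num_images": 1,
#     "seed": -1
# }
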
def base64_to_image(base64_str: str) -> Image.Image:
    """Convert base64 string to PIL Image"""
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")


def image_to_base64(image: Image.Image) -> str:
    """Convert PIL Image to base64 string"""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

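# Client-side sketch of how the image fields are expected to be produced (an assumption based on
# base64_to_image above: raw file bytes, base64-encoded, with no "data:image/..." prefix):
#
#     with open("person.jpg", "rb") as f:
#         model_image_b64 = base64.b64encode(f.read()).decode()
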
def pad_and_resize(im: Image.Image, new_width: int = 768, new_height: int = 1024, pad_color: tuple = (255, 255, 255), mode: int = Image.LANCZOS):
    """Pad and resize image to target dimensions while maintaining aspect ratio"""
    old_width, old_height = im.size

    ratio_w = new_width / old_width
    ratio_h = new_height / old_height
    if ratio_w < ratio_h:
        new_size = (new_width, round(old_height * ratio_w))
    else:
        new_size = (round(old_width * ratio_h), new_height)

    im_resized = im.resize(new_size, mode)

    pad_w = math.ceil((new_width - im_resized.width) / 2)
    pad_h = math.ceil((new_height - im_resized.height) / 2)

    new_im = Image.new('RGB', (new_width, new_height), pad_color)
    new_im.paste(im_resized, (pad_w, pad_h))

    return new_im, pad_w, pad_h

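# Worked example of the letterboxing above (hypothetical 600x900 input, default 768x1024 target):
# the limiting ratio is 1024/900, so the image is resized to 683x1024 and centered with
# pad_w = ceil((768 - 683) / 2) = 43 and pad_h = 0; the pad_color fills the remaining columns.
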
def unpad_and_resize(padded_im: Image.Image, pad_w: int, pad_h: int, original_width: int, original_height: int) -> Image.Image:
    """Remove padding and resize image back to original dimensions"""
    width, height = padded_im.size

    left = pad_w
    top = pad_h
    right = width - pad_w
    bottom = height - pad_h

    cropped_im = padded_im.crop((left, top, right, bottom))
    resized_im = cropped_im.resize((original_width, original_height), Image.LANCZOS)

    return resized_im

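# unpad_and_resize is the inverse of pad_and_resize: it crops away the pad_w/pad_h border added
# during letterboxing and scales the crop back to the caller-supplied original dimensions.
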
def resize_image(img: Image.Image, target_size: int = 768) -> Image.Image:
    """Resize image maintaining aspect ratio"""
    width, height = img.size
    if width < height:
        scale = target_size / width
    else:
        scale = target_size / height

    new_width = int(round(width * scale))
    new_height = int(round(height * scale))
    return img.resize((new_width, new_height), Image.LANCZOS)

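# resize_image scales the shorter side to target_size (768 px by default); generate_mask uses it to
# bring the person photo to a fixed working resolution before pose detection and human parsing.
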
def generate_mask(
    vton_img: Image.Image,
    category: str,
    offset_top: int,
    offset_bottom: int,
    offset_left: int,
    offset_right: int
):
    """Generate mask for the model image"""
    with torch.inference_mode():
        vton_img_det = resize_image(vton_img)
        # Flip the RGB PIL array to BGR channel order for the DWpose detector
        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:, :, ::-1])
        candidate[candidate < 0] = 0
        candidate = candidate[0]

        # Scale normalized keypoint coordinates to pixel coordinates
        candidate[:, 0] *= vton_img_det.width
        candidate[:, 1] *= vton_img_det.height

        # Flip the rendered pose map back to RGB before wrapping it in a PIL image
        pose_image = pose_image[:, :, ::-1]
        pose_image = Image.fromarray(pose_image)
        model_parse, _ = parsing_model(vton_img_det)

        mask, mask_gray = get_mask_location(
            category,
            model_parse,
            candidate,
            model_parse.width,
            model_parse.height,
            offset_top,
            offset_bottom,
            offset_left,
            offset_right
        )

        # Bring the mask back to the original image size and build a grayed-out preview
        mask = mask.resize(vton_img.size)
        mask_gray = mask_gray.resize(vton_img.size)
        mask = mask.convert("L")
        mask_gray = mask_gray.convert("L")
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        return {
            'mask': mask,
            'pose_image': pose_image,
            'masked_image': masked_vton_img
        }

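# The /try-on endpoint ties everything together: decode the base64 inputs, build the garment mask
# and pose map, letterbox everything to the requested resolution, run the try-on pipeline, then
# crop each result back to the original photo size and return it as base64.
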
@app.post("/try-on")
async def try_on(request: TryOnRequest):
    try:
        # Convert base64 to images
        model_image = base64_to_image(request.model_image)
        garment_image = base64_to_image(request.garment_image)

        # Validate resolution
        if request.resolution not in ["768x1024", "1152x1536", "1536x2048"]:
            raise HTTPException(status_code=400, detail="Invalid resolution")

        new_width, new_height = map(int, request.resolution.split("x"))

        # Generate mask and pose
        mask_result = generate_mask(
            model_image,
            request.category,
            request.offset_top,
            request.offset_bottom,
            request.offset_left,
            request.offset_right
        )

        # Process images for try-on
        with torch.inference_mode():
            # Prepare images
            model_image_size = model_image.size
            garment_image_resized, _, _ = pad_and_resize(garment_image, new_width, new_height)
            model_image_resized, pad_w, pad_h = pad_and_resize(model_image, new_width, new_height)

            # Prepare mask and pose
            mask_resized = pad_and_resize(mask_result['mask'], new_width, new_height)[0]
            pose_image_resized = pad_and_resize(mask_result['pose_image'], new_width, new_height)[0]

            # Generate try-on images
            seed = request.seed if request.seed != -1 else torch.randint(0, 2147483647, (1,)).item()
            result_images = pipeline(
                height=new_height,
                width=new_width,
                guidance_scale=request.image_scale,
                num_inference_steps=request.n_steps,
                generator=torch.Generator("cpu").manual_seed(seed),
                cloth_image=garment_image_resized,
                model_image=model_image_resized,
                mask=mask_resized,
                pose_image=pose_image_resized,
                num_images_per_prompt=request.num_images
            ).images

        # Convert results to base64
        result_base64 = []
        for img in result_images:
            img_resized = unpad_and_resize(img, pad_w, pad_h, model_image_size[0], model_image_size[1])
            result_base64.append(image_to_base64(img_resized))

        return {
            "status": "success",
            "generated_images": result_base64,
            "seed": seed
        }

    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 for an invalid resolution) unchanged
        # instead of letting the generic handler below rewrap them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
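
# Example client call (illustrative sketch; assumes the server is running locally on port 8001,
# that "person.jpg" and "garment.jpg" exist on the client machine, and that "Upper-body" is a
# category accepted by get_mask_location):
#
#     import base64, requests
#
#     def b64(path):
#         with open(path, "rb") as f:
#             return base64.b64encode(f.read()).decode()
#
#     resp = requests.post(
#         "http://localhost:8001/try-on",
#         json={
#             "model_image": b64("person.jpg"),
#             "garment_image": b64("garment.jpg"),
#             "category": "Upper-body",
#         },
#     )
#     first_image = base64.b64decode(resp.json()["generated_images"][0])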