VTON-API/vton-api/server.py

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import io
from PIL import Image
import torch
import os
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import snapshot_download
import numpy as np
import math
# Import custom modules (make sure these are in the same directory)
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.dwpose import DWposeDetector
from src.pose_guider import PoseGuider
from src.pipeline_stable_diffusion_3_tryon import StableDiffusion3TryOnPipeline
from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_Garm
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
from src.utils_mask import get_mask_location
app = FastAPI()
# Initialize models and configurations
weight_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Download model weights
fitdit_repo = "BoyuanJiang/FitDiT"
repo_path = snapshot_download(repo_id=fitdit_repo)
# Load models
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(
    os.path.join(repo_path, "transformer_garm"),
    torch_dtype=weight_dtype
)
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(
    os.path.join(repo_path, "transformer_vton"),
    torch_dtype=weight_dtype
)
pose_guider = PoseGuider(
    conditioning_embedding_channels=1536,
    conditioning_channels=3,
    block_out_channels=(32, 64, 256, 512)
)
pose_guider.load_state_dict(
    torch.load(
        os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin"),
        map_location="cpu"  # load on CPU first; the module is moved to `device` below
    )
)
# Load image encoders
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=weight_dtype
)
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    torch_dtype=weight_dtype
)
# Move models to device
pose_guider.to(device=device, dtype=weight_dtype)
image_encoder_large.to(device=device)
image_encoder_bigG.to(device=device)
# Initialize pipeline
pipeline = StableDiffusion3TryOnPipeline.from_pretrained(
    repo_path,
    torch_dtype=weight_dtype,
    transformer_garm=transformer_garm,
    transformer_vton=transformer_vton,
    pose_guider=pose_guider,
    image_encoder_large=image_encoder_large,
    image_encoder_bigG=image_encoder_bigG
)
pipeline.to(device)
# Initialize processors
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
parsing_model = Parsing(model_root=repo_path, device=device)


class TryOnRequest(BaseModel):
    model_image: str    # base64-encoded image
    garment_image: str  # base64-encoded image
    category: str
    resolution: str = "768x1024"
    n_steps: int = 20
    image_scale: float = 2.0
    num_images: int = 1
    seed: int = -1
    offset_top: int = 0
    offset_bottom: int = 0
    offset_left: int = 0
    offset_right: int = 0
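
# Illustrative request body (hypothetical values; the exact category strings
# are whatever get_mask_location expects; the FitDiT demo uses "Upper-body",
# "Lower-body", and "Dresses"):
# {
#     "model_image": "<base64-encoded person photo>",
#     "garment_image": "<base64-encoded garment photo>",
#     "category": "Upper-body",
#     "resolution": "768x1024",
#     "n_steps": 20,
#     "image_scale": 2.0,
#     "num_images": 1,
#     "seed": -1
# }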


def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string to a PIL Image."""
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")


def image_to_base64(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def pad_and_resize(im: Image.Image, new_width: int = 768, new_height: int = 1024,
                   pad_color: tuple = (255, 255, 255), mode: int = Image.LANCZOS):
    """Pad and resize an image to the target dimensions while maintaining aspect ratio."""
    old_width, old_height = im.size
    ratio_w = new_width / old_width
    ratio_h = new_height / old_height
    if ratio_w < ratio_h:
        new_size = (new_width, round(old_height * ratio_w))
    else:
        new_size = (round(old_width * ratio_h), new_height)
    im_resized = im.resize(new_size, mode)
    pad_w = math.ceil((new_width - im_resized.width) / 2)
    pad_h = math.ceil((new_height - im_resized.height) / 2)
    new_im = Image.new('RGB', (new_width, new_height), pad_color)
    new_im.paste(im_resized, (pad_w, pad_h))
    return new_im, pad_w, pad_h
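
# Worked example: a 512x512 input padded to 768x1024 gives ratio_w = 1.5 and
# ratio_h = 2.0; the smaller ratio wins, so the image is resized to 768x768
# and centered with pad_w = 0, pad_h = 128 (white bars above and below).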


def unpad_and_resize(padded_im: Image.Image, pad_w: int, pad_h: int,
                     original_width: int, original_height: int) -> Image.Image:
    """Remove padding and resize the image back to its original dimensions."""
    width, height = padded_im.size
    left = pad_w
    top = pad_h
    right = width - pad_w
    bottom = height - pad_h
    cropped_im = padded_im.crop((left, top, right, bottom))
    resized_im = cropped_im.resize((original_width, original_height), Image.LANCZOS)
    return resized_im


def resize_image(img: Image.Image, target_size: int = 768) -> Image.Image:
    """Resize an image so its shorter side equals target_size, preserving aspect ratio."""
    width, height = img.size
    if width < height:
        scale = target_size / width
    else:
        scale = target_size / height
    new_width = int(round(width * scale))
    new_height = int(round(height * scale))
    return img.resize((new_width, new_height), Image.LANCZOS)
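
# Example: a 1024x2048 portrait has its shorter side (1024) scaled down to 768,
# yielding a 768x1536 image for the pose and parsing models below.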


def generate_mask(
    vton_img: Image.Image,
    category: str,
    offset_top: int,
    offset_bottom: int,
    offset_left: int,
    offset_right: int
):
    """Run pose detection and human parsing on the model photo, then build the
    try-on inpainting mask for the requested garment category."""
    with torch.inference_mode():
        vton_img_det = resize_image(vton_img)
        # np.array() yields RGB; reverse the channels to BGR for the pose detector
        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:, :, ::-1])
        candidate[candidate < 0] = 0
        candidate = candidate[0]
        # Scale normalized keypoint coordinates to pixel coordinates
        candidate[:, 0] *= vton_img_det.width
        candidate[:, 1] *= vton_img_det.height
        # Convert the BGR pose rendering back to RGB
        pose_image = pose_image[:, :, ::-1]
        pose_image = Image.fromarray(pose_image)
        model_parse, _ = parsing_model(vton_img_det)
        mask, mask_gray = get_mask_location(
            category,
            model_parse,
            candidate,
            model_parse.width,
            model_parse.height,
            offset_top,
            offset_bottom,
            offset_left,
            offset_right
        )
        mask = mask.resize(vton_img.size).convert("L")
        mask_gray = mask_gray.resize(vton_img.size).convert("L")
        # Gray out the masked region of the original photo for preview
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)
    return {
        'mask': mask,
        'pose_image': pose_image,
        'masked_image': masked_vton_img
    }
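
# Usage sketch (hypothetical file names; category string as discussed above):
#   person = Image.open("person.jpg").convert("RGB")
#   preview = generate_mask(person, "Upper-body", 0, 0, 0, 0)
#   preview["masked_image"].save("mask_preview.png")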


@app.post("/try-on")
async def try_on(request: TryOnRequest):
    try:
        # Convert base64 payloads to images
        model_image = base64_to_image(request.model_image)
        garment_image = base64_to_image(request.garment_image)

        # Validate resolution
        if request.resolution not in ["768x1024", "1152x1536", "1536x2048"]:
            raise HTTPException(status_code=400, detail="Invalid resolution")
        new_width, new_height = map(int, request.resolution.split("x"))

        # Generate mask and pose
        mask_result = generate_mask(
            model_image,
            request.category,
            request.offset_top,
            request.offset_bottom,
            request.offset_left,
            request.offset_right
        )

        # Process images for try-on
        with torch.inference_mode():
            # Prepare images
            model_image_size = model_image.size
            garment_image_resized, _, _ = pad_and_resize(garment_image, new_width, new_height)
            model_image_resized, pad_w, pad_h = pad_and_resize(model_image, new_width, new_height)

            # Prepare mask and pose
            mask_resized = pad_and_resize(mask_result['mask'], new_width, new_height)[0]
            pose_image_resized = pad_and_resize(mask_result['pose_image'], new_width, new_height)[0]

            # Generate try-on images
            seed = request.seed if request.seed != -1 else torch.randint(0, 2147483647, (1,)).item()
            result_images = pipeline(
                height=new_height,
                width=new_width,
                guidance_scale=request.image_scale,
                num_inference_steps=request.n_steps,
                generator=torch.Generator("cpu").manual_seed(seed),
                cloth_image=garment_image_resized,
                model_image=model_image_resized,
                mask=mask_resized,
                pose_image=pose_image_resized,
                num_images_per_prompt=request.num_images
            ).images

        # Convert results to base64, cropping away the padding first
        result_base64 = []
        for img in result_images:
            img_resized = unpad_and_resize(img, pad_w, pad_h, model_image_size[0], model_image_size[1])
            result_base64.append(image_to_base64(img_resized))

        return {
            "status": "success",
            "generated_images": result_base64,
            "seed": seed
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400s above) unchanged
        # instead of converting them into a generic 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
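
# Example client call, as a sketch: assumes the server is reachable on
# localhost:8001 and that "person.jpg" / "shirt.jpg" are stand-in file names.
#
#   import base64
#   import requests
#
#   def encode(path):
#       with open(path, "rb") as f:
#           return base64.b64encode(f.read()).decode()
#
#   resp = requests.post("http://localhost:8001/try-on", json={
#       "model_image": encode("person.jpg"),
#       "garment_image": encode("shirt.jpg"),
#       "category": "Upper-body",
#       "resolution": "768x1024",
#   })
#   resp.raise_for_status()
#   for i, b64 in enumerate(resp.json()["generated_images"]):
#       with open(f"result_{i}.png", "wb") as f:
#           f.write(base64.b64decode(b64))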