# VTON-API/vton-api/server.py

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import io
from PIL import Image
import torch
import os
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import snapshot_download
import numpy as np
import math
# Import custom modules (make sure these are in the same directory)
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.dwpose import DWposeDetector
from src.pose_guider import PoseGuider
from src.pipeline_stable_diffusion_3_tryon import StableDiffusion3TryOnPipeline
from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_Garm
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
from src.utils_mask import get_mask_location
app = FastAPI()
# Initialize models and configurations
weight_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Download model weights
fitdit_repo = "BoyuanJiang/FitDiT"
repo_path = snapshot_download(repo_id=fitdit_repo)
# Load models
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(
os.path.join(repo_path, "transformer_garm"),
torch_dtype=weight_dtype
)
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(
os.path.join(repo_path, "transformer_vton"),
torch_dtype=weight_dtype
)
pose_guider = PoseGuider(
conditioning_embedding_channels=1536,
conditioning_channels=3,
block_out_channels=(32, 64, 256, 512)
)
pose_guider.load_state_dict(
    torch.load(
        os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin"),
        map_location="cpu",  # load to CPU first; the module is moved to `device` below
    )
)
# Load image encoders
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained(
"openai/clip-vit-large-patch14",
torch_dtype=weight_dtype
)
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
torch_dtype=weight_dtype
)
# Move models to device
pose_guider.to(device=device, dtype=weight_dtype)
image_encoder_large.to(device=device)
image_encoder_bigG.to(device=device)
# Initialize pipeline
pipeline = StableDiffusion3TryOnPipeline.from_pretrained(
repo_path,
torch_dtype=weight_dtype,
transformer_garm=transformer_garm,
transformer_vton=transformer_vton,
pose_guider=pose_guider,
image_encoder_large=image_encoder_large,
image_encoder_bigG=image_encoder_bigG
)
pipeline.to(device)
# Initialize processors
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
parsing_model = Parsing(model_root=repo_path, device=device)
class TryOnRequest(BaseModel):
    model_image: str        # base64-encoded person image
    garment_image: str      # base64-encoded garment image
    category: str           # garment category, e.g. "Upper-body", "Lower-body", "Dresses"
    resolution: str = "768x1024"  # one of "768x1024", "1152x1536", "1536x2048"
    n_steps: int = 20             # diffusion denoising steps
    image_scale: float = 2.0      # classifier-free guidance scale
    num_images: int = 1           # images generated per request
    seed: int = -1                # -1 selects a random seed
    offset_top: int = 0           # mask offsets, in pixels
    offset_bottom: int = 0
    offset_left: int = 0
    offset_right: int = 0
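# A minimal example payload for TryOnRequest (illustrative; the category name
# below follows the upstream FitDiT convention and may differ in your checkout):
#
#   {
#       "model_image": "<base64 of the person photo>",
#       "garment_image": "<base64 of the garment photo>",
#       "category": "Upper-body",
#       "resolution": "768x1024",
#       "n_steps": 20,
#       "image_scale": 2.0,
#       "num_images": 1,
#       "seed": -1
#   }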
def base64_to_image(base64_str: str) -> Image.Image:
    """Convert base64 string to PIL Image"""
    try:
        image_data = base64.b64decode(base64_str)
        # Force RGB: RGBA or grayscale uploads would otherwise break the
        # 3-channel BGR conversion in generate_mask
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")
def image_to_base64(image: Image.Image) -> str:
"""Convert PIL Image to base64 string"""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
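# Round-trip sanity check (commented out): PNG encoding is lossless, so decoding
# an encoded image gives back the same pixels:
#
#   img = Image.new("RGB", (64, 64), (200, 30, 30))
#   assert base64_to_image(image_to_base64(img)).size == img.size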
def pad_and_resize(
    im: Image.Image,
    new_width: int = 768,
    new_height: int = 1024,
    pad_color: tuple = (255, 255, 255),
    mode=Image.LANCZOS,
):
    """Pad and resize image to target dimensions while maintaining aspect ratio"""
    old_width, old_height = im.size
    ratio_w = new_width / old_width
    ratio_h = new_height / old_height
    # Scale by the limiting ratio so the image fits entirely inside the target box
    if ratio_w < ratio_h:
        new_size = (new_width, round(old_height * ratio_w))
    else:
        new_size = (round(old_width * ratio_h), new_height)
    im_resized = im.resize(new_size, mode)
    # Center the resized image on a pad_color canvas
    pad_w = math.ceil((new_width - im_resized.width) / 2)
    pad_h = math.ceil((new_height - im_resized.height) / 2)
    new_im = Image.new('RGB', (new_width, new_height), pad_color)
    new_im.paste(im_resized, (pad_w, pad_h))
    return new_im, pad_w, pad_h
def unpad_and_resize(
    padded_im: Image.Image,
    pad_w: int,
    pad_h: int,
    original_width: int,
    original_height: int,
) -> Image.Image:
"""Remove padding and resize image back to original dimensions"""
width, height = padded_im.size
left = pad_w
top = pad_h
right = width - pad_w
bottom = height - pad_h
cropped_im = padded_im.crop((left, top, right, bottom))
resized_im = cropped_im.resize((original_width, original_height), Image.LANCZOS)
return resized_im
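# Worked example (commented out): a 600x900 image scaled into a 768x1024 box is
# limited by height (1024/900 < 768/600), giving a 683x1024 resize with 43 px of
# horizontal padding; unpad_and_resize inverts the transform:
#
#   im = Image.new("RGB", (600, 900))
#   padded, pad_w, pad_h = pad_and_resize(im, 768, 1024)  # pad_w == 43, pad_h == 0
#   restored = unpad_and_resize(padded, pad_w, pad_h, 600, 900)
#   assert restored.size == (600, 900)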
def resize_image(img: Image.Image, target_size: int = 768) -> Image.Image:
"""Resize image maintaining aspect ratio"""
width, height = img.size
if width < height:
scale = target_size / width
else:
scale = target_size / height
new_width = int(round(width * scale))
new_height = int(round(height * scale))
return img.resize((new_width, new_height), Image.LANCZOS)
def generate_mask(
vton_img: Image.Image,
category: str,
offset_top: int,
offset_bottom: int,
offset_left: int,
offset_right: int
):
"""Generate mask for the model image"""
    with torch.inference_mode():
        vton_img_det = resize_image(vton_img)
        # DWpose expects BGR input, so flip the RGB channels
        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:, :, ::-1])
        # Clamp undetected (negative) keypoints and denormalize to pixel coordinates
        candidate[candidate < 0] = 0
        candidate = candidate[0]
        candidate[:, 0] *= vton_img_det.width
        candidate[:, 1] *= vton_img_det.height
        # Flip the pose visualization back to RGB
        pose_image = Image.fromarray(pose_image[:, :, ::-1])
        model_parse, _ = parsing_model(vton_img_det)
mask, mask_gray = get_mask_location(
category,
model_parse,
candidate,
model_parse.width,
model_parse.height,
offset_top,
offset_bottom,
offset_left,
offset_right
)
        # Bring the masks back to the original resolution and build a gray-out preview
        mask = mask.resize(vton_img.size).convert("L")
        mask_gray = mask_gray.resize(vton_img.size).convert("L")
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)
return {
'mask': mask,
'pose_image': pose_image,
'masked_image': masked_vton_img
}
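# Debugging sketch (commented out): run mask generation standalone, assuming a
# local "person.jpg" exists ("Upper-body" follows the FitDiT category naming):
#
#   person = Image.open("person.jpg").convert("RGB")
#   out = generate_mask(person, "Upper-body", 0, 0, 0, 0)
#   out["mask"].save("mask.png")
#   out["pose_image"].save("pose.png")
#   out["masked_image"].save("masked.png")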
@app.post("/try-on")
async def try_on(request: TryOnRequest):
try:
# Convert base64 to images
model_image = base64_to_image(request.model_image)
garment_image = base64_to_image(request.garment_image)
# Validate resolution
if request.resolution not in ["768x1024", "1152x1536", "1536x2048"]:
raise HTTPException(status_code=400, detail="Invalid resolution")
new_width, new_height = map(int, request.resolution.split("x"))
# Generate mask and pose
mask_result = generate_mask(
model_image,
request.category,
request.offset_top,
request.offset_bottom,
request.offset_left,
request.offset_right
)
# Process images for try-on
with torch.inference_mode():
# Prepare images
model_image_size = model_image.size
garment_image_resized, _, _ = pad_and_resize(garment_image, new_width, new_height)
model_image_resized, pad_w, pad_h = pad_and_resize(model_image, new_width, new_height)
# Prepare mask and pose
mask_resized = pad_and_resize(mask_result['mask'], new_width, new_height)[0]
pose_image_resized = pad_and_resize(mask_result['pose_image'], new_width, new_height)[0]
            # Resolve the seed up front so it can be echoed back in the response
            seed = request.seed if request.seed != -1 else torch.randint(0, 2147483647, (1,)).item()
            # Generate try-on images
result_images = pipeline(
height=new_height,
width=new_width,
guidance_scale=request.image_scale,
num_inference_steps=request.n_steps,
generator=torch.Generator("cpu").manual_seed(seed),
cloth_image=garment_image_resized,
model_image=model_image_resized,
mask=mask_resized,
pose_image=pose_image_resized,
num_images_per_prompt=request.num_images
).images
# Convert results to base64
result_base64 = []
for img in result_images:
img_resized = unpad_and_resize(img, pad_w, pad_h, model_image_size[0], model_image_size[1])
result_base64.append(image_to_base64(img_resized))
return {
"status": "success",
"generated_images": result_base64,
"seed": seed
}
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 responses above) propagate
        # unchanged instead of being converted into 500s below
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)
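# --- Example client (commented out) ---
# A minimal sketch using the third-party `requests` package; the file names are
# placeholders, and "Upper-body" assumes the FitDiT category naming. Assumes the
# server is running locally on port 8001.
#
#   import base64
#   import requests
#
#   def encode(path):
#       with open(path, "rb") as f:
#           return base64.b64encode(f.read()).decode()
#
#   resp = requests.post(
#       "http://localhost:8001/try-on",
#       json={
#           "model_image": encode("person.jpg"),
#           "garment_image": encode("garment.jpg"),
#           "category": "Upper-body",
#       },
#       timeout=600,
#   )
#   resp.raise_for_status()
#   for i, b64 in enumerate(resp.json()["generated_images"]):
#       with open(f"result_{i}.png", "wb") as f:
#           f.write(base64.b64decode(b64))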