from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import io
import math
import os

import numpy as np
import torch
from PIL import Image
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import snapshot_download

# Import custom modules (make sure these are in the same directory)
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.dwpose import DWposeDetector
from src.pose_guider import PoseGuider
from src.pipeline_stable_diffusion_3_tryon import StableDiffusion3TryOnPipeline
from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_Garm
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
from src.utils_mask import get_mask_location

app = FastAPI()

# Initialize models and configurations
weight_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download model weights from the Hugging Face Hub
fitdit_repo = "BoyuanJiang/FitDiT"
repo_path = snapshot_download(repo_id=fitdit_repo)

# Load the garment and try-on transformers
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(
    os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype
)
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(
    os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype
)

# Load the pose guider
pose_guider = PoseGuider(
    conditioning_embedding_channels=1536,
    conditioning_channels=3,
    block_out_channels=(32, 64, 256, 512),
)
pose_guider.load_state_dict(
    torch.load(
        os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin"),
        map_location="cpu",
    )
)

# Load image encoders
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=weight_dtype
)
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype
)

# Move models to device
pose_guider.to(device=device, dtype=weight_dtype)
image_encoder_large.to(device=device)
image_encoder_bigG.to(device=device)

# Initialize pipeline
pipeline = StableDiffusion3TryOnPipeline.from_pretrained(
    repo_path,
    torch_dtype=weight_dtype,
    transformer_garm=transformer_garm,
    transformer_vton=transformer_vton,
    pose_guider=pose_guider,
    image_encoder_large=image_encoder_large,
    image_encoder_bigG=image_encoder_bigG,
)
pipeline.to(device)

# Initialize pose and human-parsing processors
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
parsing_model = Parsing(model_root=repo_path, device=device)


class TryOnRequest(BaseModel):
    model_image: str  # base64-encoded image of the person
    garment_image: str  # base64-encoded image of the garment
    category: str  # garment category, e.g. "Upper-body", "Lower-body", "Dresses"
    resolution: str = "768x1024"
    n_steps: int = 20
    image_scale: float = 2.0
    num_images: int = 1
    seed: int = -1  # -1 picks a random seed
    offset_top: int = 0
    offset_bottom: int = 0
    offset_left: int = 0
    offset_right: int = 0


def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string to a PIL Image."""
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data))
        # Normalize to RGB so downstream BGR slicing and compositing work
        return image.convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")


def image_to_base64(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
def pad_and_resize(im: Image.Image, new_width: int = 768, new_height: int = 1024,
                   pad_color: tuple = (255, 255, 255), mode=Image.LANCZOS):
    """Resize an image to fit the target dimensions while keeping its aspect
    ratio, then pad it to exactly new_width x new_height."""
    old_width, old_height = im.size
    ratio_w = new_width / old_width
    ratio_h = new_height / old_height
    # Scale by the smaller ratio so the whole image fits inside the target box
    if ratio_w < ratio_h:
        new_size = (new_width, round(old_height * ratio_w))
    else:
        new_size = (round(old_width * ratio_h), new_height)
    im_resized = im.resize(new_size, mode)
    pad_w = math.ceil((new_width - im_resized.width) / 2)
    pad_h = math.ceil((new_height - im_resized.height) / 2)
    new_im = Image.new("RGB", (new_width, new_height), pad_color)
    new_im.paste(im_resized, (pad_w, pad_h))
    return new_im, pad_w, pad_h


def unpad_and_resize(padded_im: Image.Image, pad_w: int, pad_h: int,
                     original_width: int, original_height: int) -> Image.Image:
    """Remove padding and resize the image back to its original dimensions."""
    width, height = padded_im.size
    left = pad_w
    top = pad_h
    right = width - pad_w
    bottom = height - pad_h
    cropped_im = padded_im.crop((left, top, right, bottom))
    resized_im = cropped_im.resize((original_width, original_height), Image.LANCZOS)
    return resized_im
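
# A quick worked example of the padding arithmetic above: a 512x512 input
# targeted at 768x1024 gives ratio_w = 1.5 and ratio_h = 2.0, so the image is
# scaled by the smaller ratio to 768x768, then centered on the canvas with
# pad_w = 0 and pad_h = ceil((1024 - 768) / 2) = 128. unpad_and_resize later
# crops (0, 128, 768, 896) out of the result and resizes it back to 512x512.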
def resize_image(img: Image.Image, target_size: int = 768) -> Image.Image:
    """Resize an image so its shorter side equals target_size, keeping aspect ratio."""
    width, height = img.size
    if width < height:
        scale = target_size / width
    else:
        scale = target_size / height
    new_width = int(round(width * scale))
    new_height = int(round(height * scale))
    return img.resize((new_width, new_height), Image.LANCZOS)


def generate_mask(vton_img: Image.Image, category: str, offset_top: int,
                  offset_bottom: int, offset_left: int, offset_right: int):
    """Generate the inpainting mask and pose image for the model image."""
    with torch.inference_mode():
        vton_img_det = resize_image(vton_img)
        # DWpose expects BGR input, so reverse the channel order
        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:, :, ::-1])
        candidate[candidate < 0] = 0
        # Keep the first detected person and rescale keypoints to pixel coordinates
        candidate = candidate[0]
        candidate[:, 0] *= vton_img_det.width
        candidate[:, 1] *= vton_img_det.height

        # Convert the pose rendering back from BGR to RGB
        pose_image = pose_image[:, :, ::-1]
        pose_image = Image.fromarray(pose_image)

        model_parse, _ = parsing_model(vton_img_det)
        mask, mask_gray = get_mask_location(
            category, model_parse, candidate,
            model_parse.width, model_parse.height,
            offset_top, offset_bottom, offset_left, offset_right,
        )
        mask = mask.resize(vton_img.size).convert("L")
        mask_gray = mask_gray.resize(vton_img.size).convert("L")
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        return {
            "mask": mask,
            "pose_image": pose_image,
            "masked_image": masked_vton_img,
        }


@app.post("/try-on")
async def try_on(request: TryOnRequest):
    try:
        # Convert base64 payloads to images
        model_image = base64_to_image(request.model_image)
        garment_image = base64_to_image(request.garment_image)

        # Validate resolution
        if request.resolution not in ["768x1024", "1152x1536", "1536x2048"]:
            raise HTTPException(status_code=400, detail="Invalid resolution")
        new_width, new_height = map(int, request.resolution.split("x"))

        # Generate mask and pose
        mask_result = generate_mask(
            model_image, request.category,
            request.offset_top, request.offset_bottom,
            request.offset_left, request.offset_right,
        )

        # Process images for try-on
        with torch.inference_mode():
            # Pad inputs to the target resolution
            model_image_size = model_image.size
            garment_image_resized, _, _ = pad_and_resize(garment_image, new_width, new_height)
            model_image_resized, pad_w, pad_h = pad_and_resize(model_image, new_width, new_height)

            # Pad the mask and pose with black so the padding is not treated
            # as a region to inpaint
            mask_resized = pad_and_resize(
                mask_result["mask"], new_width, new_height, pad_color=(0, 0, 0)
            )[0].convert("L")
            pose_image_resized = pad_and_resize(
                mask_result["pose_image"], new_width, new_height, pad_color=(0, 0, 0)
            )[0]

            # Generate try-on images
            seed = request.seed if request.seed != -1 else torch.randint(0, 2147483647, (1,)).item()
            result_images = pipeline(
                height=new_height,
                width=new_width,
                guidance_scale=request.image_scale,
                num_inference_steps=request.n_steps,
                generator=torch.Generator("cpu").manual_seed(seed),
                cloth_image=garment_image_resized,
                model_image=model_image_resized,
                mask=mask_resized,
                pose_image=pose_image_resized,
                num_images_per_prompt=request.num_images,
            ).images

        # Convert results to base64, cropping back to the input size
        result_base64 = []
        for img in result_images:
            img_resized = unpad_and_resize(img, pad_w, pad_h, model_image_size[0], model_image_size[1])
            result_base64.append(image_to_base64(img_resized))

        return {
            "status": "success",
            "generated_images": result_base64,
            "seed": seed,
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400s above) pass through
        # instead of being re-wrapped as a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
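
# ---------------------------------------------------------------------------
# Example client (an illustrative sketch, not part of the server). It assumes
# the server is running locally on port 8001, that the `requests` package is
# installed, and that "model.jpg" / "garment.jpg" are placeholder file names;
# adjust the paths and category to your data.
#
#   import base64
#   import requests
#
#   def encode(path: str) -> str:
#       with open(path, "rb") as f:
#           return base64.b64encode(f.read()).decode()
#
#   payload = {
#       "model_image": encode("model.jpg"),
#       "garment_image": encode("garment.jpg"),
#       "category": "Upper-body",
#       "resolution": "768x1024",
#       "n_steps": 20,
#       "image_scale": 2.0,
#       "seed": 42,
#   }
#   resp = requests.post("http://localhost:8001/try-on", json=payload, timeout=600)
#   resp.raise_for_status()
#   for i, img_b64 in enumerate(resp.json()["generated_images"]):
#       with open(f"result_{i}.png", "wb") as f:
#           f.write(base64.b64decode(img_b64))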