main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import math
  4import re
  5from typing import Literal
  6
  7from ai.utils import clean_cmd_prefix
  8
  9
 10def extract_aspect_ratio(text: str) -> tuple[str, str]:
 11    """Infer aspect ratio from text.
 12
 13    If prompt startswith "width:height", set aspect_ratio to "width:height"
 14    and remove "width:height" from prompt.
 15
 16    Match:
 17    aspect_ratio = 9:16
 18    "aspect_ratio" : "16:9"
 19    aspect_ratio 9:16
 20    portrait
 21    landscape
 22    square
 23    16:9 prompt...
 24
 25    Args:
 26        text (str): Text.
 27
 28    Returns:
 29        tuple[str, str]: 1. Aspect ratio. (width:height)
 30                         2. Prompt.
 31    """
 32    # ruff: noqa: RUF001
 33    text = clean_cmd_prefix(text)
 34    # text startswith "width:height"
 35    if match := re.match(r"(\d+\s*[::]\s*\d+)", text):
 36        return match.group(1), text.removeprefix(match.group(1)).lstrip()
 37
 38    r"""
 39    (?i): 表示不区分大小写匹配。这意味着 Aspect_Ratio 也能被匹配到。
 40    (aspect_ratio|aspect ratio|ar|宽高比) : 匹配 "aspect_ratio", "aspect ratio", "ar""宽高比"
 41    [\s=::\"]*                  : 匹配空格 or 等号 `=` or 冒号 `:` or 全角冒号 `:` or 双引号 `"` (0 或多次)
 42    (\d+:\d+)             : 捕获组, 匹配一个或多个数字, 接着一个冒号, 再接着一个或多个数字 (例如 "5:4", "16:9")
 43    """
 44    pattern = r"(?i)(aspect_ratio|aspect ratio|ar|宽高比)[\s=::\"]*(\d+:\d+)"
 45    if match := re.search(pattern, text):
 46        return match.group(2), text
 47
 48    if "portrait" in text.lower():
 49        return "9:16", text
 50    if "landscape" in text.lower():
 51        return "16:9", text
 52    if "square" in text.lower():
 53        return "1:1", text
 54    return "", text  # default
 55
 56
 57def aspect_ratio_to_size(
 58    aspect_ratio: str,
 59    resolution: Literal["1K", "2K", "4K"] = "1K",
 60    max_width: int = int(1e16),
 61    max_height: int = int(1e16),
 62    max_size: int = int(1e32),
 63) -> tuple[int, int]:
 64    """Convert aspect ratio to image size (width, height)."""
 65    width, height = 1024, 1024
 66    if resolution.upper() == "1K":
 67        match aspect_ratio:
 68            case "1:1":
 69                width, height = 1024, 1024
 70            case "2:3":
 71                width, height = 832, 1248
 72            case "3:2":
 73                width, height = 1248, 832
 74            case "3:4":
 75                width, height = 864, 1152
 76            case "4:3":
 77                width, height = 1152, 864
 78            case "4:5":
 79                width, height = 928, 1152
 80            case "5:4":
 81                width, height = 1152, 928
 82            case "9:16":
 83                width, height = 720, 1280
 84            case "16:9":
 85                width, height = 1280, 720
 86            case "21:9":
 87                width, height = 1512, 648
 88            case _:
 89                width, height = 1024, 1024
 90    elif resolution.upper() == "2K":
 91        match aspect_ratio:
 92            case "1:1":
 93                width, height = 2048, 2048
 94            case "2:3":
 95                width, height = 1664, 2496
 96            case "3:2":
 97                width, height = 2496, 1664
 98            case "3:4":
 99                width, height = 1728, 2304
100            case "4:3":
101                width, height = 2304, 1728
102            case "4:5":
103                width, height = 1856, 2304
104            case "5:4":
105                width, height = 2304, 1856
106            case "9:16":
107                width, height = 1440, 2560
108            case "16:9":
109                width, height = 2560, 1440
110            case "21:9":
111                width, height = 3024, 1296
112            case _:
113                width, height = 2048, 2048
114    elif resolution.upper() == "4K":
115        match aspect_ratio:
116            case "1:1":
117                width, height = 4096, 4096
118            case "2:3":
119                width, height = 3328, 4992
120            case "3:2":
121                width, height = 4992, 3328
122            case "3:4":
123                width, height = 3456, 4608
124            case "4:3":
125                width, height = 4608, 3456
126            case "4:5":
127                width, height = 3648, 4560
128            case "5:4":
129                width, height = 4560, 3648
130            case "9:16":
131                width, height = 2880, 5120
132            case "16:9":
133                width, height = 5120, 2880
134            case "21:9":
135                width, height = 6048, 2592
136            case _:
137                width, height = 4096, 4096
138    return adjust_size(width, height, max_width, max_height, max_size)
139
140
def adjust_size(
    width: int,
    height: int,
    max_width: int = int(1e16),
    max_height: int = int(1e16),
    max_size: int = int(1e32),
) -> tuple[int, int]:
    """Shrink an image size proportionally to fit within the given limits.

    The size is only ever scaled down, never up, and the aspect ratio is kept
    (up to flooring to whole pixels). Each side is kept at least 1 pixel; for
    extremely elongated sizes this floor takes precedence over the limits.

    Args:
        width (int): Image width. Must be positive.
        height (int): Image height. Must be positive.
        max_width (int, optional): Max width. Defaults to int(1E16).
        max_height (int, optional): Max height. Defaults to int(1E16).
        max_size (int, optional): Max pixel count (width * height).
            Defaults to int(1E32).

    Returns:
        tuple[int, int]: Adjusted width, height.

    Raises:
        ValueError: If ``width`` or ``height`` is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"width and height must be positive, got {width}x{height}")

    # 1. Fit the pixel count. max_size bounds the AREA; scaling both sides by
    #    s scales the area by s**2, so the per-side factor is a square root.
    area = width * height
    if area > max_size:
        scale = math.sqrt(max_size / area)
        # max(1, ...) avoids a zero side, which would divide by zero below.
        width = max(1, math.floor(width * scale))
        height = max(1, math.floor(height * scale))

    # 2. Fit max_width / max_height. This only shrinks, so the area bound
    #    from step 1 still holds.
    scale = min(max_width / width, max_height / height, 1.0)
    width = max(1, math.floor(width * scale))
    height = max(1, math.floor(height * scale))

    return width, height
164    return width, height