# OpenFakeDemo / app / screenshot.py — revision 0e2f6dd ("text heuristic fix", vicliv)
"""Screenshot preprocessing pipeline.
Given an input image, decides whether it is a screenshot containing an
embedded photograph/video that should be cropped out before running the
detector. Returns a `PreprocessResult` describing the decision:
- status="full": not a screenshot, feed the original image through
- status="cropped": one or more embedded media regions were extracted
- status="text_only": screenshot is essentially text (tweet, doc, ...)
Text region detection uses the EAST scene-text detector via OpenCV's
`cv2.dnn`. The model file (`frozen_east_text_detection.pb`) lives next to
this module; if it's missing, text detection degrades gracefully to "no
text found" (status flips toward `cropped`/`full` rather than text_only).
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
from PIL import Image, ImageOps
# ──────────────────────────────────────────────────────────────
# Result
# ──────────────────────────────────────────────────────────────
@dataclass
class PreprocessResult:
    """Decision returned by `preprocess` for one input image."""

    # The image(s) to feed to the detector: the original image for
    # status="full", a single crop or a list of crops for status="cropped",
    # and None for status="text_only".
    image: Optional[Image.Image | list[Image.Image]]
    # One of "full", "cropped" or "text_only".
    status: str
    # (x, y, w, h) of the crop, or a list of such tuples when several regions
    # were extracted; None when nothing was cropped.
    crop_box: Optional[tuple | list[tuple]]
    # Fraction of the whole image covered by detected text boxes (0.0 when the
    # image was rejected before text detection ran).
    text_fraction: float
    # Free-form diagnostics (tier reached, aspect ratio, reason, ...).
    debug: dict
# ──────────────────────────────────────────────────────────────
# Tuning parameters
# ──────────────────────────────────────────────────────────────
# Whole-image text coverage above which a candidate screenshot with no
# embedded media is reported as "text_only" by `preprocess`.
TEXT_ONLY_FRACTION = 0.10
# Minimum area (fraction of the image) a final embedded-media crop must reach;
# passed to `_find_embedded_image` as `min_area_ratio`.
EMBEDDED_MIN_AREA = 0.12
# NOTE(review): the two SECOND_PASS_* values are not referenced anywhere in
# the visible source — confirm whether they are still needed.
SECOND_PASS_MIN_AREA = 0.20
SECOND_PASS_MIN_SHRINK = 0.02
# ──────────────────────────────────────────────────────────────
# Text region detection (EAST scene-text detector via cv2.dnn)
# ──────────────────────────────────────────────────────────────
EAST_MIN_SIZE = 320  # input dim for small images (multiple of 32)
EAST_MAX_SIZE = 1024  # cap for very large images
EAST_SCALE_DIVISOR = 3  # native_dim / EAST_SCALE_DIVISOR → target dim,
# then rounded down to a multiple of 32
EAST_SCORE_THRESHOLD = 0.5  # minimum per-cell confidence kept by the decoder
EAST_NMS_THRESHOLD = 0.4  # overlap threshold passed to NMSBoxesRotated
EAST_MODEL_FILENAME = "frozen_east_text_detection.pb"
# Network outputs requested from `net.forward`, in (scores, geometry) order.
EAST_OUTPUT_LAYERS = (
    "feature_fusion/Conv_7/Sigmoid",
    "feature_fusion/concat_3",
)
# Lazily populated by `_get_east_net`; the flag ensures loading is attempted
# only once, even when it fails.
_east_net = None
_east_load_attempted = False
def _get_east_net():
    """Load and cache the EAST text detector. Returns None if unavailable.

    The model file is searched for next to this module, in the parent
    directory, and in ``/code/app``. Loading is attempted at most once per
    process: later calls return the cached network (or None) without touching
    the filesystem again.
    """
    global _east_net, _east_load_attempted
    if _east_load_attempted:
        return _east_net
    _east_load_attempted = True

    module_dir = Path(__file__).parent
    candidates = [
        module_dir / EAST_MODEL_FILENAME,
        module_dir.parent / EAST_MODEL_FILENAME,
        Path("/code/app") / EAST_MODEL_FILENAME,
    ]
    found_any = False
    for path in candidates:
        if not path.exists():
            continue
        found_any = True
        try:
            net = cv2.dnn.readNet(str(path))
        except Exception as exc:
            # Best effort: a corrupt file at one location must not prevent
            # trying the remaining candidates.
            print(f"[screenshot] EAST load failed at {path}: {exc}")
            continue
        _east_net = net
        print(f"[screenshot] EAST text detector loaded from {path}")
        return _east_net

    if found_any:
        # A file existed but none could be parsed; saying "not found" here
        # would send the operator looking for the wrong problem.
        print(
            "[screenshot] EAST model could not be loaded — text detection "
            "disabled."
        )
    else:
        print(
            "[screenshot] EAST model not found — text detection disabled. "
            "Download frozen_east_text_detection.pb and place it next to app/."
        )
    return None
def _decode_east(scores: np.ndarray, geometry: np.ndarray,
score_threshold: float) -> tuple[list, list]:
"""Decode raw EAST outputs into (rotated-rect, confidence) pairs."""
num_rows, num_cols = scores.shape[2:4]
rects: list = []
confidences: list = []
for y in range(num_rows):
scores_row = scores[0, 0, y]
x0 = geometry[0, 0, y]
x1 = geometry[0, 1, y]
x2 = geometry[0, 2, y]
x3 = geometry[0, 3, y]
angles = geometry[0, 4, y]
for x in range(num_cols):
score = float(scores_row[x])
if score < score_threshold:
continue
offset_x = x * 4.0
offset_y = y * 4.0
angle = float(angles[x])
cos_a = math.cos(angle)
sin_a = math.sin(angle)
h_box = float(x0[x] + x2[x])
w_box = float(x1[x] + x3[x])
end_x = offset_x + cos_a * float(x1[x]) + sin_a * float(x2[x])
end_y = offset_y - sin_a * float(x1[x]) + cos_a * float(x2[x])
start_x = end_x - w_box
start_y = end_y - h_box
cx = (start_x + end_x) / 2.0
cy = (start_y + end_y) / 2.0
rects.append(((float(cx), float(cy)),
(w_box, h_box),
-float(angle) * 180.0 / math.pi))
confidences.append(score)
return rects, confidences
def detect_text_boxes(image: np.ndarray) -> list[tuple]:
    """Find text-region bounding boxes via the EAST scene-text detector.

    Args:
        image: RGB ``(H, W, 3)`` or grayscale ``(H, W)`` array.

    Returns:
        Axis-aligned ``(x, y, w, h)`` tuples in original image coordinates.
        The image is resampled to a square EAST input whose side is about
        ``longest_side / EAST_SCALE_DIVISOR``, clamped to
        ``[EAST_MIN_SIZE, EAST_MAX_SIZE]`` and rounded down to a multiple of
        32. The resulting rotated rectangles are projected back to the
        original frame as their axis-aligned bounding boxes (good enough for
        masking and density math).

        Returns ``[]`` when the EAST model is unavailable, the image is
        smaller than 4 px on a side, the forward pass fails, or nothing is
        detected.
    """
    net = _get_east_net()
    if net is None:
        return []
    h, w = image.shape[:2]
    if h < 4 or w < 4:
        return []
    img = image
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    # Pick an EAST input size that keeps small UI text legible without paying
    # for inference at native resolution: ~1/3 of the longer dimension,
    # clamped to [EAST_MIN_SIZE, EAST_MAX_SIZE] and rounded down to a multiple
    # of 32 (EAST requires that).
    longest = max(h, w)
    target = max(EAST_MIN_SIZE,
                 min(EAST_MAX_SIZE, longest // EAST_SCALE_DIVISOR))
    target = (target // 32) * 32
    if target < 32:
        target = 32
    # Per-axis factors mapping network-input pixels back to the source image.
    ratio_w = w / float(target)
    ratio_h = h / float(target)
    resized = cv2.resize(img, (target, target))
    blob = cv2.dnn.blobFromImage(
        resized,
        scalefactor=1.0,
        size=(target, target),
        mean=(123.68, 116.78, 103.94),
        # The array is already RGB here and the mean is given in (R, G, B)
        # order, so the channels must not be swapped. swapRB=True is only
        # right for BGR frames (e.g. from cv2.imread) and would feed the
        # network BGR data.
        swapRB=False,
        crop=False,
    )
    net.setInput(blob)
    try:
        scores, geometry = net.forward(list(EAST_OUTPUT_LAYERS))
    except cv2.error as exc:
        print(f"[screenshot] EAST forward failed: {exc}")
        return []

    rects, confidences = _decode_east(scores, geometry, EAST_SCORE_THRESHOLD)
    if not rects:
        return []
    indices = cv2.dnn.NMSBoxesRotated(
        rects, confidences, EAST_SCORE_THRESHOLD, EAST_NMS_THRESHOLD
    )
    if indices is None or len(indices) == 0:
        return []
    # Flatten so the loop below does not depend on the shape NMS returns.
    indices = np.asarray(indices).flatten()

    boxes: list[tuple] = []
    for i in indices:
        rect = rects[int(i)]
        pts = cv2.boxPoints(rect)
        xs = pts[:, 0] * ratio_w
        ys = pts[:, 1] * ratio_h
        # Round outward, then clamp to the image.
        x0 = int(max(0, math.floor(xs.min())))
        y0 = int(max(0, math.floor(ys.min())))
        x1 = int(min(w, math.ceil(xs.max())))
        y1 = int(min(h, math.ceil(ys.max())))
        if x1 > x0 and y1 > y0:
            boxes.append((x0, y0, x1 - x0, y1 - y0))
    return boxes
def _box_union_fraction(boxes: list[tuple], h: int, w: int) -> float:
"""Fraction of image area covered by the *union* of boxes.
Sum-of-areas would overcount any time boxes overlap. Rasterizing into a
mask and averaging gives the correct geometric coverage.
"""
if not boxes or h <= 0 or w <= 0:
return 0.0
mask = np.zeros((h, w), dtype=np.uint8)
for (bx, by, bw, bh) in boxes:
x0 = max(0, bx); y0 = max(0, by)
x1 = min(w, bx + bw); y1 = min(h, by + bh)
if x1 > x0 and y1 > y0:
mask[y0:y1, x0:x1] = 1
return float(mask.mean())
# ──────────────────────────────────────────────────────────────
# Tier 1: cheap screenshot signals
# ──────────────────────────────────────────────────────────────
def _border_uniformity(gray: np.ndarray) -> float:
h, w = gray.shape
strip = max(8, min(h, w) // 50)
top = gray[:strip, :].std()
bottom = gray[-strip:, :].std()
left = gray[:, :strip].std()
right = gray[:, -strip:].std()
return float(min(top, bottom, left, right))
def _is_candidate_screenshot(image: np.ndarray) -> dict:
    """Cheap first-tier test: could this image be a screenshot?

    Returns a dict with ``aspect_ratio`` (height / width, rounded),
    ``border_std`` (rounded), ``is_candidate`` and a human-readable
    ``reason``.
    """
    rows, cols = image.shape[:2]
    aspect = rows / cols
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image
    border_std = _border_uniformity(gray)

    if aspect > 1.9:
        # Tall phone screenshots (19.5:9, 20:9). A 16:9 portrait photo (1.78)
        # is deliberately below this cut-off so it is judged on its border
        # instead of its shape.
        reason = f"tall aspect ratio ({aspect:.2f} > 1.9)"
    elif aspect < 0.45:
        reason = f"wide aspect ratio ({aspect:.2f} < 0.45)"
    elif 0.5 <= aspect <= 0.8:
        # Desktop shapes (16:9, 16:10, ...). Their edges carry menu bars,
        # docks and tabs, so border flatness says nothing; the second tier
        # makes the call.
        reason = f"desktop aspect ratio ({aspect:.2f})"
    elif border_std < 3.0:
        reason = f"uniform border (std={border_std:.2f} < 3.0)"
    else:
        reason = None

    return {
        "aspect_ratio": round(aspect, 3),
        "border_std": round(border_std, 2),
        "is_candidate": reason is not None,
        "reason": (reason if reason is not None
                   else "natural photo (no screenshot signals)"),
    }
# ──────────────────────────────────────────────────────────────
# Crop refinement: trim / expand
# ──────────────────────────────────────────────────────────────
def _refine_crop(gray: np.ndarray, x: int, y: int, bw: int, bh: int,
strip: int = 8, var_threshold: float = 8.0) -> tuple:
"""Tighten a crop box by trimming uniform (low-variance) strips from edges."""
img_h, img_w = gray.shape
while bh > strip * 3:
row = gray[y:y + strip, x:x + bw]
if row.std() < var_threshold:
y += strip
bh -= strip
else:
break
while bh > strip * 3:
row = gray[y + bh - strip:y + bh, x:x + bw]
if row.std() < var_threshold:
bh -= strip
else:
break
while bw > strip * 3:
col = gray[y:y + bh, x:x + strip]
if col.std() < var_threshold:
x += strip
bw -= strip
else:
break
while bw > strip * 3:
col = gray[y:y + bh, x + bw - strip:x + bw]
if col.std() < var_threshold:
bw -= strip
else:
break
return (x, y, bw, bh)
def _ui_chrome_color(arr_rgb: np.ndarray) -> Optional[tuple]:
"""Estimate the screenshot's dominant UI chrome color from corner pixels."""
h, w = arr_rgb.shape[:2]
p = max(20, min(h, w) // 30)
corners = [
arr_rgb[:p, :p],
arr_rgb[:p, -p:],
arr_rgb[-p:, :p],
arr_rgb[-p:, -p:],
]
means = np.array([c.reshape(-1, 3).mean(axis=0) for c in corners])
centroid = means.mean(axis=0)
if float(np.max(np.linalg.norm(means - centroid, axis=1))) > 40.0:
return None
if all(c < 30 for c in centroid) or all(c > 225 for c in centroid):
return None
return tuple(float(c) for c in centroid)
def _expand_crop(arr_rgb: np.ndarray, sat: np.ndarray, val: np.ndarray,
                 text_mask: np.ndarray,
                 x: int, y: int, bw: int, bh: int,
                 ui_dark_max: int = 25,
                 ui_bright_min: int = 235,
                 ui_sat_max: int = 20,
                 chrome_color_tol: float = 35.0,
                 chrome_match_ratio: float = 0.6,
                 text_threshold: float = 0.30,
                 max_growth_ratio: float = 4.0) -> tuple:
    """Grow a crop box strip by strip until it runs into UI-looking pixels.

    The box grows upward, then downward, left and right. Growth along a
    direction stops at the image edge, when the next strip looks like UI
    (mostly text, flat dark/bright with low saturation, or matching the
    estimated chrome colour), or once the box area reaches
    ``max_growth_ratio`` times its starting area. Returns ``(x, y, w, h)``.
    """
    img_h, img_w = val.shape
    step = max(4, min(img_h, img_w) // 200)
    area_cap = max_growth_ratio * (bw * bh)
    chrome = _ui_chrome_color(arr_rgb)

    def blocked(rows: slice, cols: slice) -> bool:
        """True when the strip at [rows, cols] must not be absorbed."""
        v_strip = val[rows, cols]
        if v_strip.size == 0:
            return True
        if float(text_mask[rows, cols].mean()) > text_threshold:
            return True
        mean_v = float(v_strip.mean())
        mean_s = float(sat[rows, cols].mean())
        flat_extreme = mean_v < ui_dark_max or mean_v > ui_bright_min
        if mean_s < ui_sat_max and flat_extreme:
            return True
        if chrome is None:
            return False
        offsets = (arr_rgb[rows, cols].astype(np.float32)
                   - np.array(chrome, dtype=np.float32))
        distances = np.linalg.norm(offsets, axis=-1)
        return float((distances < chrome_color_tol).mean()) > chrome_match_ratio

    # Upward.
    while y > 0 and bw * bh < area_cap:
        new_top = max(0, y - step)
        if new_top == y or blocked(slice(new_top, y), slice(x, x + bw)):
            break
        bh += y - new_top
        y = new_top

    # Downward.
    while y + bh < img_h and bw * bh < area_cap:
        bottom = y + bh
        new_bottom = min(img_h, bottom + step)
        if new_bottom == bottom or blocked(slice(bottom, new_bottom),
                                           slice(x, x + bw)):
            break
        bh = new_bottom - y

    # Leftward.
    while x > 0 and bw * bh < area_cap:
        new_left = max(0, x - step)
        if new_left == x or blocked(slice(y, y + bh), slice(new_left, x)):
            break
        bw += x - new_left
        x = new_left

    # Rightward.
    while x + bw < img_w and bw * bh < area_cap:
        right = x + bw
        new_right = min(img_w, right + step)
        if new_right == right or blocked(slice(y, y + bh),
                                         slice(right, new_right)):
            break
        bw = new_right - x

    return (x, y, bw, bh)
def _is_repeating_pattern(gray: np.ndarray) -> bool:
"""Detect repeating background patterns (e.g. WhatsApp doodle wallpaper)."""
h, w = gray.shape
if h < 200 or w < 200:
return False
sample_w = w // 3
col = gray[:, :sample_w].astype(np.float32)
profile = col.mean(axis=1)
n = len(profile)
mean_p = profile.mean()
denom = np.sum((profile - mean_p) ** 2)
if denom < 1e-6:
return False
for lag in range(100, min(301, n // 3)):
corr = np.sum((profile[:n-lag] - mean_p) * (profile[lag:] - mean_p))
r = corr / denom
if r > 0.7:
return True
return False
# ──────────────────────────────────────────────────────────────
# Candidate generation: texture + contour
# ──────────────────────────────────────────────────────────────
def _texture_candidates(
    gray: np.ndarray,
    text_mask: np.ndarray,
    min_area_ratio: float,
    min_side_px: int,
) -> list[tuple]:
    """Propose ``(x, y, w, h)`` boxes around textured, non-text regions.

    Pixels whose 15x15 local variance exceeds 60 and that are not covered by
    ``text_mask`` are closed morphologically and grouped into connected
    components. A component's bounding box is kept when both sides reach
    ``min_side_px``, its area reaches ``min_area_ratio`` of the image, its
    aspect ratio is at most 6:1, and the component fills at least 20% of it.
    """
    rows, cols = gray.shape
    gray_f = gray.astype(np.float32)
    window = (15, 15)
    mean = cv2.boxFilter(gray_f, -1, window)
    mean_sq = cv2.boxFilter(gray_f * gray_f, -1, window)
    # Var[X] = E[X^2] - E[X]^2 over the sliding window.
    textured = ((mean_sq - mean * mean) > 60.0).astype(np.uint8)
    non_text = (textured & (1 - text_mask)).astype(np.uint8)

    side = max(9, min(rows, cols) // 120)
    closing_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (side, side))
    closed = cv2.morphologyEx(non_text, cv2.MORPH_CLOSE, closing_kernel)

    count, _labels, stats, _centroids = cv2.connectedComponentsWithStats(
        closed, connectivity=8
    )
    if count <= 1:
        return []

    area_floor = min_area_ratio * rows * cols
    boxes = []
    # Label 0 is the background component.
    for component in stats[1:count]:
        left = int(component[cv2.CC_STAT_LEFT])
        top = int(component[cv2.CC_STAT_TOP])
        width = int(component[cv2.CC_STAT_WIDTH])
        height = int(component[cv2.CC_STAT_HEIGHT])
        pixels = int(component[cv2.CC_STAT_AREA])
        box_area = width * height
        if width < min_side_px or height < min_side_px:
            continue
        if box_area < area_floor:
            continue
        if width / height > 6 or height / width > 6:
            continue
        fill = pixels / box_area if box_area > 0 else 0
        if fill < 0.20:
            continue
        boxes.append((left, top, width, height))
    return boxes
def _contour_candidates(
    gray: np.ndarray,
    min_area_ratio: float,
    min_side_px: int,
) -> list[tuple]:
    """Propose ``(x, y, w, h)`` boxes from outer edge contours.

    Edges come from Canny on a bilateral-filtered image, thickened by two
    5x5 dilations. A contour's bounding box is kept when its area reaches
    ``min_area_ratio`` of the image, both sides reach ``min_side_px``, its
    aspect ratio is at most 6:1, and the contour area fills at least 40% of
    the box.
    """
    rows, cols = gray.shape
    smoothed = cv2.bilateralFilter(gray, 9, 75, 75)
    edge_map = cv2.dilate(
        cv2.Canny(smoothed, 40, 120),
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)),
        iterations=2,
    )
    contours, _hierarchy = cv2.findContours(
        edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    area_floor = min_area_ratio * rows * cols
    boxes = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        box_area = width * height
        if box_area < area_floor:
            continue
        if width < min_side_px or height < min_side_px:
            continue
        if width / height > 6 or height / width > 6:
            continue
        fill = cv2.contourArea(contour) / box_area if box_area > 0 else 0
        if fill < 0.40:
            continue
        boxes.append((left, top, width, height))
    return boxes
def _merge_overlapping(rects: list[tuple], iou_thresh: float = 0.3) -> list[tuple]:
if not rects:
return []
rects = sorted(rects, key=lambda r: r[2] * r[3], reverse=True)
keep = []
for rect in rects:
rx, ry, rw, rh = rect
merged = False
for kx, ky, kw, kh in keep:
ix0 = max(rx, kx)
iy0 = max(ry, ky)
ix1 = min(rx + rw, kx + kw)
iy1 = min(ry + rh, ky + kh)
if ix1 > ix0 and iy1 > iy0:
inter = (ix1 - ix0) * (iy1 - iy0)
smaller_area = min(rw * rh, kw * kh)
if inter / smaller_area > iou_thresh:
merged = True
break
if not merged:
keep.append(rect)
return keep
def _merge_close_candidates(rects: list[tuple], img_h: int, img_w: int,
max_gap_ratio: float = 0.06,
min_overlap_ratio: float = 0.35) -> list[tuple]:
if not rects:
return []
max_gap = max_gap_ratio * min(img_h, img_w)
rects = list(rects)
def union(r1, r2):
x1, y1, w1, h1 = r1
x2, y2, w2, h2 = r2
x = min(x1, x2)
y = min(y1, y2)
return (x, y, max(x1 + w1, x2 + w2) - x, max(y1 + h1, y2 + h2) - y)
def should_merge(r1, r2):
x1, y1, w1, h1 = r1
x2, y2, w2, h2 = r2
h_overlap = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
v_overlap = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
v_gap = 0 if v_overlap > 0 else max(y1, y2) - min(y1 + h1, y2 + h2)
h_gap = 0 if h_overlap > 0 else max(x1, x2) - min(x1 + w1, x2 + w2)
if h_overlap > min_overlap_ratio * min(w1, w2) and v_gap < max_gap:
return True
if v_overlap > min_overlap_ratio * min(h1, h2) and h_gap < max_gap:
return True
return False
changed = True
while changed:
changed = False
for i in range(len(rects)):
for j in range(i + 1, len(rects)):
if should_merge(rects[i], rects[j]):
rects[i] = union(rects[i], rects[j])
rects.pop(j)
changed = True
break
if changed:
break
return rects
# ──────────────────────────────────────────────────────────────
# Reels UI detection
# ──────────────────────────────────────────────────────────────
def _find_reels_icons_white(gray: np.ndarray, w_img: int, h_img: int) -> list[dict]:
    """Locate bright icon-sized blobs and return their centroids.

    The image is thresholded at 200. An outer contour is kept when its area
    lies strictly between 50 and 5000, its bounding box is at least 35 px on
    both sides and its width/height ratio lies strictly between 0.4 and 2.5.
    ``w_img`` and ``h_img`` are accepted but not used. Each hit is a dict
    with integer ``cx`` / ``cy``.
    """
    _, bright = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)
    contours, _hierarchy = cv2.findContours(
        bright, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    found = []
    for contour in contours:
        if not 50 < cv2.contourArea(contour) < 5000:
            continue
        _, _, box_w, box_h = cv2.boundingRect(contour)
        if not (0.4 < box_w / box_h < 2.5 and box_w >= 35 and box_h >= 35):
            continue
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            continue
        found.append({
            "cx": int(moments["m10"] / moments["m00"]),
            "cy": int(moments["m01"] / moments["m00"]),
        })
    return found
def _find_reels_icons_edges(gray: np.ndarray, w_img: int, h_img: int) -> list[dict]:
    """Locate icon-sized outlines via edges and return their centroids.

    Canny edges are dilated once with a 3x3 kernel. An outer contour is kept
    when its area lies strictly between 100 and 8000, its bounding box is at
    least 25 px on both sides with a width/height ratio strictly between 0.4
    and 2.5, and it starts to the right of 30% of the strip width. Hits whose
    neighbourhood is more than 70% very bright yet more than 5% dark are
    discarded. ``w_img`` and ``h_img`` are accepted but not used.
    """
    edge_map = cv2.dilate(
        cv2.Canny(gray, 50, 150),
        cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
        iterations=1,
    )
    contours, _hierarchy = cv2.findContours(
        edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    strip_h, strip_w = gray.shape[0], gray.shape[1]
    found = []
    for contour in contours:
        if not 100 < cv2.contourArea(contour) < 8000:
            continue
        left, _top, box_w, box_h = cv2.boundingRect(contour)
        shape_ok = 0.4 < box_w / box_h < 2.5 and box_w >= 25 and box_h >= 25
        if not (shape_ok and left > strip_w * 0.3):
            continue
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            continue
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        # Inspect a square neighbourhood (half-size 20..35 px) around the hit.
        radius = max(20, min(35, max(box_w, box_h)))
        patch = gray[
            max(0, cy - radius):min(strip_h, cy + radius),
            max(0, cx - radius):min(strip_w, cx + radius),
        ]
        if patch.size:
            bright_ratio = float((patch > 220).mean())
            dark_ratio = float((patch < 60).mean())
        else:
            bright_ratio = 0.0
            dark_ratio = 0.0
        if bright_ratio > 0.70 and dark_ratio > 0.05:
            continue
        found.append({"cx": cx, "cy": cy})
    return found
def _check_vertical_alignment(icons: list[dict], w_img: int, h_img: int,
min_icons: int = 3) -> bool:
if len(icons) < min_icons:
return False
icons_sorted = sorted(icons, key=lambda ic: ic["cx"])
for i in range(len(icons_sorted) - min_icons + 1):
group = icons_sorted[i:i + min_icons]
max_cx = max(g["cx"] for g in group)
min_cx = min(g["cx"] for g in group)
if max_cx - min_cx < w_img * 0.025:
min_cy = min(g["cy"] for g in group)
max_cy = max(g["cy"] for g in group)
if max_cy - min_cy > h_img * 0.05:
return True
return False
def _is_reels_ui(image: np.ndarray) -> bool:
    """Detect a vertical-video UI by its column of icons on the right edge.

    Only images with height / width of at least 1.7 are considered. The
    right-most 15% of the width, between 40% and 90% of the height, is
    searched for icons, first with the bright-blob finder and then with the
    edge-based finder; either finder yielding a vertically aligned group
    returns True.
    """
    rows, cols = image.shape[:2]
    if rows / cols < 1.7:
        return False

    strip_width = int(cols * 0.15)
    strip = image[int(rows * 0.4):int(rows * 0.9), cols - strip_width:cols]
    if strip.ndim == 3:
        gray = cv2.cvtColor(strip, cv2.COLOR_RGB2GRAY)
    else:
        gray = strip

    # Alignment tolerances are measured against the strip, not the full image.
    strip_h, strip_w = gray.shape[0], gray.shape[1]
    if _check_vertical_alignment(_find_reels_icons_white(gray, cols, rows),
                                 strip_w, strip_h):
        return True
    return _check_vertical_alignment(_find_reels_icons_edges(gray, cols, rows),
                                     strip_w, strip_h)
# ──────────────────────────────────────────────────────────────
# Card β†’ embedded media refinement
# ──────────────────────────────────────────────────────────────
def _refine_to_saturated_media(
    arr: np.ndarray,
    crop_box: tuple,
    text_boxes: Optional[list[tuple]] = None,
) -> tuple:
    """Tighten broad cards/messages to the embedded photo-like region.

    Args:
        arr: full RGB image (converted with COLOR_RGB2HSV below).
        crop_box: ``(x, y, w, h)`` candidate in ``arr`` coordinates.
        text_boxes: optional ``(x, y, w, h)`` text boxes in ``arr``
            coordinates; used to reject text-heavy components and to
            recognise text bands.

    Returns:
        A tighter ``(x, y, w, h)`` in ``arr`` coordinates, or ``crop_box``
        unchanged when no convincing media region is found or when the part
        that would be trimmed does not look like UI.
    """
    x, y, bw, bh = crop_box
    sub = arr[y:y + bh, x:x + bw]
    # Too small to refine meaningfully.
    if sub.size == 0 or bw < 80 or bh < 80:
        return crop_box
    hsv = cv2.cvtColor(sub, cv2.COLOR_RGB2HSV)
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]
    # Rasterize the (padded) text boxes into crop-local coordinates.
    text_mask = np.zeros((bh, bw), dtype=np.uint8)
    if text_boxes:
        pad = max(4, min(bw, bh) // 200)
        for (tx, ty, tw, th) in text_boxes:
            ix0 = max(x, tx - pad)
            iy0 = max(y, ty - pad)
            ix1 = min(x + bw, tx + tw + pad)
            iy1 = min(y + bh, ty + th + pad)
            if ix1 > ix0 and iy1 > iy0:
                text_mask[iy0 - y:iy1 - y, ix0 - x:ix1 - x] = 1
    k = max(15, min(bw, bh) // 40)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k, k))
    best = None
    # Two notions of "media-like" pixels: colourful and not dark, or bright
    # with limited saturation.
    media_masks = [
        ((sat > 35) & (val > 35)).astype(np.uint8),
        ((val > 175) & (sat < 100)).astype(np.uint8),
    ]
    for raw_mask in media_masks:
        # Skip a mask that covers less than 8% of the crop.
        if float(raw_mask.mean()) < 0.08:
            continue
        # Close gaps inside the media, then open to drop small specks.
        mask = cv2.morphologyEx(raw_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        mask = cv2.morphologyEx(
            mask,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7)),
        )
        num, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
        # Label 0 is the background component.
        for label_id in range(1, num):
            lx = int(stats[label_id, cv2.CC_STAT_LEFT])
            ly = int(stats[label_id, cv2.CC_STAT_TOP])
            lw = int(stats[label_id, cv2.CC_STAT_WIDTH])
            lh = int(stats[label_id, cv2.CC_STAT_HEIGHT])
            area = int(stats[label_id, cv2.CC_STAT_AREA])
            bbox_area = lw * lh
            if bbox_area <= 0:
                continue
            fill = area / bbox_area
            # The component must span most of the width and a fair share of
            # the height of the crop...
            if lw < 0.75 * bw or lh < 0.25 * bh:
                continue
            # ...be reasonably large and solid...
            if area < 0.10 * bw * bh or fill < 0.45:
                continue
            # ...and contain almost no detected text.
            text_density = float(text_mask[ly:ly + lh, lx:lx + lw].mean())
            if text_density > 0.06:
                continue
            # Keep the component with the largest pixel area across both masks.
            if best is None or area > best[-1]:
                best = (lx, ly, lw, lh, area)
    if best is None:
        return crop_box
    lx, ly, lw, lh, _ = best
    # Component hugs the left edge but stops short of the right one: keep the
    # original crop.
    if lx < 0.03 * bw and lx + lw < 0.92 * bw:
        return crop_box
    # Component already fills the crop; nothing worth trimming.
    nearly_full_width = lw > 0.94 * bw and lx < 0.03 * bw
    nearly_full_height = lh > 0.88 * bh and ly < 0.06 * bh
    if nearly_full_width and nearly_full_height:
        return crop_box
    if lw < 80 or lh < 80 or lw * lh < 0.08 * bw * bh:
        return crop_box

    def removed_band_is_ui(s_band: np.ndarray, v_band: np.ndarray, t_band: np.ndarray) -> bool:
        # A band counts as UI when it is text-heavy, flat and dark, or flat,
        # unsaturated and very bright/dark.
        if v_band.size == 0:
            return False
        text_density = float(t_band.mean()) if t_band.size else 0.0
        mean_v = float(v_band.mean())
        mean_s = float(s_band.mean())
        std_v = float(v_band.std())
        if text_density > 0.04:
            return True
        if mean_v < 70.0 and std_v < 20.0:
            return True
        if mean_s < 35.0 and (mean_v > 215.0 or mean_v < 45.0) and std_v < 25.0:
            return True
        return False

    # Only accept the tighter box if at least one band being cut away (top,
    # bottom, left or right of the component) looks like UI.
    removed_ui = False
    if ly > 0.06 * bh:
        removed_ui = removed_ui or removed_band_is_ui(sat[:ly, :], val[:ly, :], text_mask[:ly, :])
    if ly + lh < 0.92 * bh:
        removed_ui = removed_ui or removed_band_is_ui(
            sat[ly + lh:, :], val[ly + lh:, :], text_mask[ly + lh:, :]
        )
    if lx > 0.06 * bw:
        removed_ui = removed_ui or removed_band_is_ui(sat[:, :lx], val[:, :lx], text_mask[:, :lx])
    if lx + lw < 0.94 * bw:
        removed_ui = removed_ui or removed_band_is_ui(
            sat[:, lx + lw:], val[:, lx + lw:], text_mask[:, lx + lw:]
        )
    if not removed_ui:
        return crop_box
    # Translate the crop-local component box back to full-image coordinates.
    return (x + lx, y + ly, lw, lh)
def _trim_full_width_ui_chrome(arr: np.ndarray, crop_box: tuple) -> tuple:
    """Trim app chrome from full-width social post candidates.

    Two proposals are generated inside ``crop_box``: the best vertical span of
    media-like rows, and the largest merged edge-contour rectangle. Each is
    vetted by ``accept_trim``. The larger accepted proposal is returned in
    ``arr`` coordinates; if none is accepted, ``crop_box`` is returned
    unchanged.
    """
    x, y, bw, bh = crop_box
    sub = arr[y:y + bh, x:x + bw]
    if sub.size == 0 or bw < 120 or bh < 120:
        return crop_box
    hsv = cv2.cvtColor(sub, cv2.COLOR_RGB2HSV)
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]
    # Text is re-detected on the crop itself, so boxes are crop-local.
    text_mask = np.zeros((bh, bw), dtype=np.uint8)
    sub_boxes = detect_text_boxes(sub)
    if sub_boxes:
        pad = max(4, min(bw, bh) // 200)
        for (tx, ty, tw, th) in sub_boxes:
            x0 = max(0, tx - pad)
            y0 = max(0, ty - pad)
            x1 = min(bw, tx + tw + pad)
            y1 = min(bh, ty + th + pad)
            text_mask[y0:y1, x0:x1] = 1
    # (pixel mask, per-row coverage threshold) pairs describing media rows.
    masks = [
        (((sat > 35) & (val > 35)).astype(np.float32), 0.45),
        (((val > 175) & (sat < 100)).astype(np.float32), 0.15),
    ]
    trim_candidates = []

    def chrome_band_score(v_band: np.ndarray, t_band: np.ndarray) -> tuple[bool, bool]:
        # Returns (looks_like_chrome, is_flat_dark) for a band about to be cut.
        if v_band.size == 0:
            return False, False
        text_dense = float(t_band.mean()) > 0.04 if t_band.size else False
        flat_dark = float(v_band.mean()) < 70.0 and float(v_band.std()) < 20.0
        return text_dense or flat_dark, flat_dark

    def accept_trim(rx: int, ry: int, rw: int, rh: int) -> bool:
        # Decide whether trimming the crop to (rx, ry, rw, rh) (crop-local)
        # removes chrome rather than image content.
        if rh < 80 or rw < 80:
            return False
        retained_h = rh / float(bh)
        left_inset = rx > 0.025 * bw
        right_inset = rx + rw < 0.975 * bw
        side_inset = left_inset or right_inset
        top_trimmed = ry > 0.06 * bh
        bottom_trimmed = ry + rh < 0.92 * bh
        top_ok, _ = chrome_band_score(val[:ry, :], text_mask[:ry, :]) if top_trimmed else (False, False)
        bottom_ok, _ = chrome_band_score(
            val[ry + rh:, :], text_mask[ry + rh:, :]
        ) if bottom_trimmed else (False, False)
        # Side bands only count when they are flat and dark (text alone is
        # not enough there).
        side_ok = False
        if left_inset:
            _, side_ok = chrome_band_score(val[ry:ry + rh, :rx], text_mask[ry:ry + rh, :rx])
        if right_inset:
            _, right_flat = chrome_band_score(
                val[ry:ry + rh, rx + rw:], text_mask[ry:ry + rh, rx + rw:]
            )
            side_ok = side_ok or right_flat
        # At least one removed band must look like chrome.
        if not (top_ok or bottom_ok or side_ok):
            return False
        top_frac = ry / float(bh)
        bottom_frac = (bh - (ry + rh)) / float(bh)
        large_one_sided_chrome = side_ok and (
            (top_ok and top_frac > 0.08) or (bottom_ok and bottom_frac > 0.18)
        )
        # Cutting away more than a quarter of the height needs stronger
        # evidence: chrome on both ends, or side chrome plus a large band.
        if retained_h < 0.75 and not ((top_ok and bottom_ok) or large_one_sided_chrome):
            return False
        if not side_inset and retained_h < 0.75:
            return False
        return True

    # Proposal 1: the best contiguous run of media-like rows.
    best_span = None
    window = max(9, bh // 80)
    kernel_1d = np.ones(window, dtype=np.float32) / window
    for mask, threshold in masks:
        # Smoothed per-row coverage of the mask.
        row_score = np.convolve(mask.mean(axis=1), kernel_1d, mode="same")
        is_media = row_score > threshold
        start = None
        for idx, flag in enumerate(is_media):
            if flag and start is None:
                start = idx
            # Close the run on the first non-media row or at the last row.
            if start is not None and (not flag or idx == bh - 1):
                end = idx if not flag else idx + 1
                if end - start > 0.20 * bh:
                    # Score = mean coverage x run length.
                    score = float(row_score[start:end].mean()) * (end - start)
                    if best_span is None or score > best_span[2]:
                        best_span = (start, end, score)
                start = None
    if best_span is not None:
        top, bottom, _ = best_span
        pad = max(2, bh // 250)
        top = max(0, top - pad)
        bottom = min(bh, bottom + pad)
        if (top > 0.06 * bh or bottom < 0.92 * bh) and accept_trim(0, top, bw, bottom - top):
            trim_candidates.append((x, y + top, bw, bottom - top))

    # Proposal 2: the largest rectangle built from edge contours.
    gray = cv2.cvtColor(sub, cv2.COLOR_RGB2GRAY)
    blurred = cv2.bilateralFilter(gray, 9, 75, 75)
    edges = cv2.Canny(blurred, 40, 120)
    edges = cv2.dilate(edges, cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)), iterations=2)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    rects = []
    for cnt in contours:
        rx, ry, rw, rh = cv2.boundingRect(cnt)
        area = rw * rh
        if area < 0.05 * bw * bh or rw < 0.35 * bw or rh < 0.20 * bh:
            continue
        fill = cv2.contourArea(cnt) / area if area else 0.0
        if fill < 0.10:
            continue
        rects.append((rx, ry, rw, rh))
    if rects:
        rects = _merge_close_candidates(rects, bh, bw, max_gap_ratio=0.12, min_overlap_ratio=0.10)
        best = max(rects, key=lambda r: r[2] * r[3])
        rx, ry, rw, rh = best
        if rw * rh >= 0.12 * bw * bh:
            if accept_trim(rx, ry, rw, rh):
                trim_candidates.append((x + rx, y + ry, rw, rh))
    if not trim_candidates:
        return crop_box
    # Prefer the proposal that keeps the most area.
    return max(trim_candidates, key=lambda r: r[2] * r[3])
def _second_pass_refine(arr: np.ndarray, crop_box: tuple) -> tuple:
    """Trim text bands from the top and/or bottom of a crop.

    Text is re-detected inside the crop. A text band that starts within the
    outer 10% of the height is trimmed, together with further text bands
    separated from it by less than a small gap. Trims under 8% of the height
    are ignored and the combined trim is capped at 55% of the height.
    Returns ``(x, y, w, h)`` in ``arr`` coordinates; width and x are never
    changed.
    """
    x, y, bw, bh = crop_box
    sub = arr[y:y + bh, x:x + bw]
    if sub.size == 0:
        return crop_box
    h, w = sub.shape[:2]
    if h < 100:
        return crop_box
    sub_boxes = detect_text_boxes(sub)
    if not sub_boxes:
        return crop_box
    # Rasterize padded text boxes, then reduce to a per-row text coverage.
    text_mask = np.zeros((h, w), dtype=np.float32)
    pad = max(4, min(h, w) // 200)
    for (bx, by_, bw_, bh_) in sub_boxes:
        x0 = max(0, bx - pad)
        y0 = max(0, by_ - pad)
        x1 = min(w, bx + bw_ + pad)
        y1 = min(h, by_ + bh_ + pad)
        text_mask[y0:y1, x0:x1] = 1.0
    row_text = text_mask.mean(axis=1)
    window = max(20, h // 30)
    kernel_1d = np.ones(window, dtype=np.float32) / window
    smooth = np.convolve(row_text, kernel_1d, mode="same")
    # A row is "text" when its smoothed coverage exceeds 6%.
    is_text = smooth > 0.06
    # Text must begin within this many rows of an edge to trigger a trim.
    margin = int(0.10 * h)

    # ---- top edge ----
    top_trim = 0
    start_top = 0
    # NOTE(review): the `else` is read as belonging to the `for` loop; for
    # h >= 100 (margin >= 10) an if/else reading gives the same result.
    for r in range(margin):
        if is_text[r]:
            start_top = r
            break
    else:
        start_top = -1
    if start_top != -1:
        top_trim = start_top
        # Extend through the contiguous run of text rows.
        for r in range(start_top, h):
            if not is_text[r]:
                break
            top_trim = r + 1
        # Absorb further text runs that start within `gap_limit` rows of the
        # current trim line; the search bound moves as `top_trim` grows.
        gap_limit = max(15, h // 40)
        scan = top_trim
        while scan < min(h, top_trim + gap_limit):
            if is_text[scan]:
                for r in range(scan, h):
                    if not is_text[r]:
                        break
                    top_trim = r + 1
                scan = top_trim
            else:
                scan += 1

    # ---- bottom edge (mirror of the above, scanning upward) ----
    bottom_trim = 0
    start_bottom = -1
    for r in range(h - 1, h - 1 - margin, -1):
        if is_text[r]:
            start_bottom = r
            break
    if start_bottom != -1:
        bottom_trim = h - start_bottom - 1
        for r in range(start_bottom, -1, -1):
            if not is_text[r]:
                break
            bottom_trim = h - r
        gap_limit = max(15, h // 40)
        scan = h - bottom_trim - 1
        while scan >= max(0, h - bottom_trim - gap_limit):
            if is_text[scan]:
                for r in range(scan, -1, -1):
                    if not is_text[r]:
                        break
                    bottom_trim = h - r
                scan = h - bottom_trim - 1
            else:
                scan -= 1

    # Ignore trims too small to matter.
    min_trim_px = int(0.08 * h)
    if top_trim < min_trim_px:
        top_trim = 0
    if bottom_trim < min_trim_px:
        bottom_trim = 0
    if top_trim == 0 and bottom_trim == 0:
        return crop_box
    # Never remove more than 55% of the height; scale both trims down
    # proportionally if they would.
    total_trim = top_trim + bottom_trim
    if total_trim > 0.55 * h:
        scale = (0.55 * h) / total_trim
        top_trim = int(top_trim * scale)
        bottom_trim = int(bottom_trim * scale)
    new_top = top_trim
    new_bottom = h - bottom_trim
    new_h = new_bottom - new_top
    if new_h < 80:
        return crop_box
    return (x, y + new_top, bw, new_h)
# ──────────────────────────────────────────────────────────────
# Embedded image search
# ──────────────────────────────────────────────────────────────
def _find_embedded_image(
    image: np.ndarray,
    text_boxes: list[tuple],
    min_area_ratio: float = 0.05,
    min_side_px: int = 80,
    gen_min_area_ratio: float = 0.04,
) -> list[tuple]:
    """Locate embedded picture/video regions inside a screenshot.

    Raw proposals come from the texture and contour generators and only need
    to reach ``gen_min_area_ratio`` of the image. After de-duplication and
    neighbour merging each proposal is tightened, re-expanded and finally
    required to cover at least ``min_area_ratio`` of the image. The two
    thresholds differ so that small adjacent pieces (for example two
    side-by-side thumbnails) can be found separately, merged, and then judged
    as one region.

    Returns ``(x, y, w, h)`` boxes sorted from top to bottom; an empty list
    when nothing qualifies.
    """
    rows, cols = image.shape[:2]
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        sat = hsv[:, :, 1]
        val = hsv[:, :, 2]
    else:
        gray = image
        sat = np.zeros_like(gray)
        val = gray

    # Rasterize the text boxes with a small safety margin around each.
    text_mask = np.zeros((rows, cols), dtype=np.uint8)
    margin = max(6, min(rows, cols) // 200)
    for (tx, ty, tw, th) in text_boxes:
        text_mask[max(0, ty - margin):min(rows, ty + th + margin),
                  max(0, tx - margin):min(cols, tx + tw + margin)] = 1

    has_wallpaper = _is_repeating_pattern(gray)

    proposals = list(
        _texture_candidates(gray, text_mask, gen_min_area_ratio, min_side_px)
    )
    proposals.extend(
        _contour_candidates(gray, gen_min_area_ratio, min_side_px)
    )
    if not proposals:
        return []

    # A near-whole-image proposal would swallow every genuine sub-region in
    # the overlap merge, so it is removed beforehand.
    pre_max = 0.92 * rows * cols
    proposals = [box for box in proposals if box[2] * box[3] <= pre_max]
    if not proposals:
        return []

    proposals = _merge_close_candidates(
        _merge_overlapping(proposals), rows, cols
    )

    trim_strip = max(4, min(rows, cols) // 200)
    refined = []
    for (px, py, pw, ph) in proposals:
        tx, ty, tw, th = _refine_crop(gray, px, py, pw, ph, strip=trim_strip)
        if tw < min_side_px or th < min_side_px:
            continue
        refined.append(
            _expand_crop(image, sat, val, text_mask, tx, ty, tw, th)
        )
    if not refined:
        return []

    img_area = rows * cols
    # Allow less of the frame when a repeating wallpaper was detected.
    max_area_ratio = 0.80 if has_wallpaper else 0.92
    lower = min_area_ratio * img_area
    upper = max_area_ratio * img_area
    accepted = [box for box in refined if lower <= box[2] * box[3] <= upper]
    return sorted(accepted, key=lambda box: box[1])
# ──────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────
def preprocess(pil_image: Image.Image) -> PreprocessResult:
    """Decide how an input image should be handed to the detector.

    Stages:
      1. Apply EXIF orientation and convert to RGB.
      2. Tier 1: cheap screenshot signals. Non-candidates are returned as
         ``status="full"``.
      3. Tier 2: text detection, then a reels-style UI check (fixed crop),
         then an embedded-media search with per-candidate refinement.
      4. If the first crop looks like a document, or no media is found and
         text covers more than ``TEXT_ONLY_FRACTION`` of the image, return
         ``status="text_only"``; otherwise fall back to the full image.
    """
    # Orientation must be applied before any geometry-dependent check: phone
    # photos often store landscape pixels together with a rotation tag.
    pil_image = ImageOps.exif_transpose(pil_image).convert("RGB")
    arr = np.array(pil_image)
    h, w = arr.shape[:2]
    tier1 = _is_candidate_screenshot(arr)

    def build(status, image, crop_box, text_fraction, **debug):
        """Assemble a result; tier-1 diagnostics are appended to ``debug``."""
        return PreprocessResult(
            image=image,
            status=status,
            crop_box=crop_box,
            text_fraction=text_fraction,
            debug={**debug, **tier1},
        )

    if not tier1["is_candidate"]:
        return build("full", pil_image, None, 0.0, tier=1)

    boxes = detect_text_boxes(arr)
    text_fraction = _box_union_fraction(boxes, h, w)

    if _is_reels_ui(arr):
        # Keep the top-left 85% x 75% of the frame.
        cw = int(w * 0.85)
        ch = int(h * 0.75)
        return build(
            "cropped", pil_image.crop((0, 0, cw, ch)), (0, 0, cw, ch),
            text_fraction, tier=2, n_text_boxes=len(boxes), reels_ui=True,
        )

    embedded_candidates = _find_embedded_image(
        arr, boxes, min_area_ratio=EMBEDDED_MIN_AREA
    )
    if not embedded_candidates:
        if text_fraction > TEXT_ONLY_FRACTION:
            return build("text_only", None, None, text_fraction,
                         tier=2, n_text_boxes=len(boxes))
        return build("full", pil_image, None, text_fraction,
                     tier=2, fallback=True)

    final_crops = []
    cropped_images = []
    for candidate in embedded_candidates:
        tightened = _refine_to_saturated_media(arr, candidate, boxes)
        if tightened != candidate:
            box = tightened
        elif candidate[0] <= 2 and candidate[2] >= w - 4:
            # Candidate spans the full width: try trimming app chrome.
            box = _trim_full_width_ui_chrome(arr, candidate)
        else:
            box = _second_pass_refine(arr, candidate)
        x, y, bw, bh = box
        final_crops.append((x, y, bw, bh))
        cropped_images.append(pil_image.crop((x, y, x + bw, y + bh)))

    total_crop_area = sum(bw * bh for _, _, bw, bh in final_crops)
    crop_pct = round(100.0 * total_crop_area / (h * w), 1)

    # The document check looks at the first (top-most) crop only.
    crop_arr = np.array(cropped_images[0])
    crop_h, crop_w = crop_arr.shape[:2]
    crop_text_frac = _box_union_fraction(
        detect_text_boxes(crop_arr), crop_h, crop_w
    )
    mean_saturation = float(
        cv2.cvtColor(crop_arr, cv2.COLOR_RGB2HSV)[:, :, 1].mean()
    )
    text_on_pale_background = crop_text_frac > 0.15 and mean_saturation < 30
    if text_on_pale_background or crop_text_frac > 0.40:
        return build(
            "text_only", None, None, text_fraction,
            tier=2, n_text_boxes=len(boxes),
            crop_text_frac=f"{crop_text_frac:.1%}",
            crop_pct=f"{crop_pct}%",
        )

    several = len(cropped_images) > 1
    return build(
        "cropped",
        cropped_images if several else cropped_images[0],
        final_crops if len(final_crops) > 1 else final_crops[0],
        text_fraction,
        tier=2, n_text_boxes=len(boxes),
        crop_pct=f"{crop_pct}%", n_crops=len(final_crops),
    )