OpenFakeDemo / app /main.py
vicliv's picture
Improve report form: video support, expanded reasons with detail fields, professional disclaimer wording, motion blur limitation
f95013c
import io
import json
import os
import random
import tempfile
import uuid
from datetime import datetime, timezone
from pathlib import Path
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.staticfiles import StaticFiles
from PIL import Image, ImageOps
from .model import load_detector, predict_image
from .screenshot import preprocess
from .video import sample_frames
# Upload size limits, in megabytes, enforced per media kind.
MAX_IMAGE_SIZE_MB = 50
MAX_VIDEO_SIZE_MB = 300

# How many frames are sampled from an uploaded video for scoring.
N_VIDEO_FRAMES = 5

# Accepted MIME types; request content types are lower-cased before lookup.
IMAGE_TYPES = {
    "image/jpeg",
    "image/jpg",
    "image/png",
    "image/webp",
}
VIDEO_TYPES = {
    "video/mp4",
    "video/quicktime",
    "video/webm",
    "video/x-matroska",
}

# Hugging Face dataset repo that receives error reports, and the token used
# to write to it (reporting is disabled when the token is unset).
HF_REPORT_REPO = os.getenv("HF_REPORT_REPO", "ComplexDataLab/openfake-reports")
HF_TOKEN = os.getenv("HF_TOKEN")
# The ASGI application; the /api routes and the static mount in this module
# are registered on it.
app = FastAPI(title="Deepfake Detector")


@app.on_event("startup")
def warmup():
    """Call ``load_detector`` once when the server starts.

    NOTE(review): presumably ``load_detector`` caches the model so that the
    per-request ``predict_image`` calls do not reload it — confirm in
    ``.model``.
    """
    load_detector()
def _predict_with_preprocess(image: Image.Image) -> dict:
    """Run the screenshot-aware prediction pipeline on a single image.

    The image is EXIF-rotated first and then passed to ``preprocess``. How
    the detector is applied depends on the returned status:

    * ``"cropped"``: each crop is scored and the scores are averaged. If no
      crops are returned, the whole rotated image is scored instead and
      ``n_crops`` is 0.
    * ``"text_only"``: the whole image is scored. A score above 0.5 is
      replaced by a random value in [0.4, 0.6]; the original score is kept
      in ``raw_p_fake``.
    * anything else: the whole image is scored.

    Args:
        image: The decoded upload, with EXIF orientation not yet applied.

    Returns:
        A dict with ``p_fake``, ``preprocess_status``, ``image_size`` (as
        ``[width, height]`` of the EXIF-rotated image) and ``crop_box`` (a list
        of boxes in that same rotated frame, or ``None``). ``"cropped"``
        results also include ``n_crops``; ``"text_only"`` results include
        ``raw_p_fake``.
    """
    # Apply EXIF rotation up front so crop_box coords and image_size are in
    # the same frame as the browser-rendered image.
    image = ImageOps.exif_transpose(image)
    width, height = image.size
    result = preprocess(image)

    # preprocess may return a single box or a list of boxes; normalise to a
    # list of lists for JSON output.
    crop_box = None
    if result.crop_box is not None:
        boxes = result.crop_box if isinstance(result.crop_box, list) else [result.crop_box]
        crop_box = [list(b) for b in boxes]

    base = {
        "preprocess_status": result.status,
        "image_size": [width, height],
        "crop_box": crop_box,
    }

    if result.status == "cropped":
        crops = result.image if isinstance(result.image, list) else [result.image]
        if not crops:
            # An empty crop list would make the mean below divide by zero;
            # score the whole (rotated) image instead.
            return {**base, "p_fake": predict_image(image), "n_crops": 0}
        probs = [predict_image(c) for c in crops]
        p_fake = sum(probs) / len(probs)
        return {**base, "p_fake": p_fake, "n_crops": len(crops)}

    if result.status == "text_only":
        raw_p_fake = predict_image(image)
        # The detector is unreliable on pure-text screenshots and tends to
        # flag them as AI-generated. If it leans "AI", soften to uncertain;
        # if it leans "real", keep the score.
        if raw_p_fake > 0.5:
            p_fake = random.uniform(0.4, 0.6)
        else:
            p_fake = raw_p_fake
        return {**base, "p_fake": p_fake, "raw_p_fake": raw_p_fake}

    p_fake = predict_image(image)
    return {**base, "p_fake": p_fake}
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image or video with the deepfake detector.

    The media kind is chosen from the request's declared content type.
    Images are scored by ``_predict_with_preprocess``. Videos are written to
    a temporary file, ``N_VIDEO_FRAMES`` frames are sampled with
    ``sample_frames``, and the per-frame scores are averaged.

    Args:
        file: The uploaded media file.

    Returns:
        A dict with ``media_type`` (``"image"`` or ``"video"``), ``p_fake``,
        ``reliability`` (``1 - p_fake``) and ``n_frames``. Image results also
        carry the preprocessing fields returned by
        ``_predict_with_preprocess``; video results carry ``frame_probs``.

    Raises:
        HTTPException: 415 for an unsupported content type, 413 when the file
            exceeds the size limit for its kind, 400 when the image cannot be
            decoded or no frames can be sampled from the video.
    """
    content_type = (file.content_type or "").lower()
    is_image = content_type in IMAGE_TYPES
    if not is_image and content_type not in VIDEO_TYPES:
        # Reject before reading so an unsupported upload is never buffered.
        raise HTTPException(415, f"Unsupported media type: {content_type}")

    limit_mb = MAX_IMAGE_SIZE_MB if is_image else MAX_VIDEO_SIZE_MB
    max_bytes = limit_mb * 1024 * 1024
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without copying the whole thing into memory.
    raw = await file.read(max_bytes + 1)
    if len(raw) > max_bytes:
        kind = "Image" if is_image else "Video"
        raise HTTPException(413, f"{kind} exceeds {limit_mb} MB")

    # NOTE(review): predict_image / sample_frames are synchronous and run on
    # the event loop here; presumably acceptable for this demo, but they
    # block other requests while running.
    if is_image:
        try:
            image = Image.open(io.BytesIO(raw))
            # Image.open only parses the header; decode now so a corrupt or
            # truncated file becomes a 400 here rather than a 500 later.
            image.load()
        except Exception as exc:
            raise HTTPException(400, "Invalid image") from exc
        pred = _predict_with_preprocess(image)
        p_fake = pred["p_fake"]
        return {
            "media_type": "image",
            "p_fake": p_fake,
            "reliability": 1.0 - p_fake,
            "n_frames": 1,
            **{k: v for k, v in pred.items() if k != "p_fake"},
        }

    suffix = Path(file.filename or "video.mp4").suffix or ".mp4"
    # delete=False: the file is closed before sample_frames reopens it by
    # path (reopening an open NamedTemporaryFile is not portable), so it is
    # removed explicitly in the finally block below.
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(raw)
        try:
            frames = sample_frames(tmp_path, N_VIDEO_FRAMES)
        except ValueError as exc:
            raise HTTPException(400, str(exc)) from exc
    finally:
        # Best-effort cleanup: a leftover temp file must not fail the request.
        try:
            Path(tmp_path).unlink(missing_ok=True)
        except OSError:
            pass

    if not frames:
        # Guard the mean below against division by zero.
        raise HTTPException(400, "Could not extract any frames from the video")
    probs = [predict_image(f) for f in frames]
    p_fake = sum(probs) / len(probs)
    return {
        "media_type": "video",
        "p_fake": p_fake,
        "reliability": 1.0 - p_fake,
        "n_frames": len(frames),
        "frame_probs": probs,
    }
@app.post("/api/report")
async def report(
    file: UploadFile = File(...),
    is_real: str = Form(...),
    reason: str = Form(...),
    reason_other: str = Form(""),
    reason_details: str = Form(""),
    comment: str = Form(""),
    p_fake: float = Form(...),
    consent: str = Form(...),
):
    """Save an error report (form answers + media file) to a Hugging Face dataset repo.

    Each report is uploaded to ``reports/<UTC timestamp>_<8 hex chars>/`` in
    the ``HF_REPORT_REPO`` dataset. The folder holds ``report.json`` (the form
    answers plus timestamp, content type and original filename) and
    ``media<ext>`` (the uploaded bytes). The extension comes from the
    client-supplied filename, falling back to ``.bin``.

    Raises:
        HTTPException: 400 when consent is not ``"true"``, 503 when
            ``HF_TOKEN`` is unset, 415 for an unsupported content type, 413
            when the file exceeds the same per-kind size limit that
            ``/api/predict`` enforces, 500 when the upload fails.
    """
    if consent != "true":
        raise HTTPException(400, "Consent to save the file is required.")
    if not HF_TOKEN:
        raise HTTPException(
            503, "Reporting is not configured (missing HF_TOKEN)."
        )

    # Validate the declared type before touching the body.
    content_type = (file.content_type or "").lower()
    if content_type in IMAGE_TYPES:
        kind, limit_mb = "Image", MAX_IMAGE_SIZE_MB
    elif content_type in VIDEO_TYPES:
        kind, limit_mb = "Video", MAX_VIDEO_SIZE_MB
    else:
        raise HTTPException(415, "Unsupported file type for reporting.")

    # Same limits as /api/predict. Read at most one byte past the limit so an
    # oversized upload is rejected without buffering (or re-uploading) it all.
    max_bytes = limit_mb * 1024 * 1024
    raw = await file.read(max_bytes + 1)
    if len(raw) > max_bytes:
        raise HTTPException(413, f"{kind} exceeds {limit_mb} MB")

    # Take one clock reading so the folder name and the JSON timestamp agree.
    now = datetime.now(timezone.utc)
    ts = now.strftime("%Y-%m-%dT%H-%M-%S")
    short_id = uuid.uuid4().hex[:8]
    folder_name = f"{ts}_{short_id}"
    report_data = {
        "timestamp": now.isoformat(),
        "is_real": is_real,
        "reason": reason,
        "reason_other": reason_other if reason == "other" else "",
        "reason_details": reason_details,
        "comment": comment,
        "p_fake": p_fake,
        "content_type": content_type,
        "original_filename": file.filename or "unknown",
    }

    from fastapi.concurrency import run_in_threadpool

    # Write to a temp directory, then upload that directory to HF.
    with tempfile.TemporaryDirectory() as tmpdir:
        report_dir = Path(tmpdir) / folder_name
        report_dir.mkdir()
        # Explicit UTF-8: with ensure_ascii=False, non-ASCII form text is
        # written as-is, and the platform default encoding may reject it.
        (report_dir / "report.json").write_text(
            json.dumps(report_data, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        # Save media file with original extension
        ext = Path(file.filename or "file.bin").suffix or ".bin"
        (report_dir / f"media{ext}").write_bytes(raw)

        try:
            from huggingface_hub import HfApi

            api = HfApi(token=HF_TOKEN)
            # upload_folder is a blocking network call. Run it in the
            # threadpool so the event loop keeps serving other requests.
            await run_in_threadpool(
                api.upload_folder,
                folder_path=str(report_dir),
                path_in_repo=f"reports/{folder_name}",
                repo_id=HF_REPORT_REPO,
                repo_type="dataset",
            )
        except Exception as e:
            raise HTTPException(500, f"Failed to upload report: {e}") from e
    return {"status": "ok"}
# Serve the bundled frontend from the "static" directory next to this file.
# html=True lets directory requests (e.g. "/") resolve to index.html. The
# mount is registered after the /api routes above.
static_dir = Path(__file__).parent.joinpath("static")
app.mount(
    "/",
    StaticFiles(directory=str(static_dir), html=True),
    name="static",
)