File size: 2,356 Bytes
c3e4914
 
 
 
 
 
 
 
 
 
 
 
d2354a4
c3e4914
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ceee32
c3e4914
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Hugging Face model id handed to from_pretrained() for both the image
# processor and the classification model in load_detector().
HF_NAME = "microsoft/swinv2-base-patch4-window16-256"
# Fine-tuned checkpoint: "model.safetensors" in the parent of the directory
# that contains this file.
WEIGHTS_PATH = Path(__file__).parent.parent / "model.safetensors"
# Output width of the replacement classifier head (index 0 = "real",
# index 1 = "fake", per the label maps set in load_detector()).
NUM_LABELS = 2
# Logits are divided by this value before the softmax in predict_image();
# a value above 1 flattens the resulting probabilities.
SOFTMAX_TEMPERATURE = 2.0

# The model and its inputs are moved to this device; inference is CPU-only.
_device = torch.device("cpu")
# Lazily populated by load_detector() on first use, then reused.
_processor = None
_model = None


def _strip_prefixes(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` with one wrapper prefix removed per key.

    Each key is checked against ``"module."`` and then ``"model."``; the first
    one that matches is cut off and no further prefix is examined, so a key
    such as ``"module.model.x"`` becomes ``"model.x"``. Keys without either
    prefix are copied unchanged. Values are passed through untouched, and if
    two keys collapse to the same name the later entry wins.
    """
    wrapper_prefixes = ("module.", "model.")

    def _unwrap(name):
        # Pick the first matching prefix (in declaration order); the empty
        # string stands in for "no match" so the slice below is a no-op.
        hit = next((pre for pre in wrapper_prefixes if name.startswith(pre)), "")
        return name[len(hit):]

    return {_unwrap(name): value for name, value in state_dict.items()}


def load_detector():
    """Return the ``(processor, model)`` pair, building it on the first call.

    The first call loads the image processor and the pretrained model named
    by ``HF_NAME`` (cache directory taken from ``$HF_HOME``, falling back to
    ``/tmp/hf-cache``), swaps the classifier for a fresh ``NUM_LABELS``-way
    linear head, and then loads the checkpoint at ``WEIGHTS_PATH`` on top.
    The result is stored in the module-level ``_processor`` / ``_model``
    globals and returned directly by every later call.

    The checkpoint is loaded with ``strict=False``: missing or unexpected
    keys are printed (count plus the first ten names) instead of raising.
    """
    global _processor, _model

    # Already initialised: reuse the cached objects.
    if _model is not None:
        return _processor, _model

    hub_cache = os.environ.get("HF_HOME", "/tmp/hf-cache")
    image_processor = AutoImageProcessor.from_pretrained(HF_NAME, cache_dir=hub_cache)
    detector = AutoModelForImageClassification.from_pretrained(HF_NAME, cache_dir=hub_cache)

    # Re-label the model as a binary classifier. num_labels is assigned
    # before the label maps, matching the original statement order.
    index_to_name = {0: "real", 1: "fake"}
    name_to_index = {"real": 0, "fake": 1}
    detector.num_labels = NUM_LABELS
    detector.config.num_labels = NUM_LABELS
    detector.config.id2label = index_to_name
    detector.config.label2id = name_to_index

    # Replace the head first so the checkpoint's classifier tensors have a
    # layer of matching shape to load into.
    detector.classifier = nn.Linear(detector.swinv2.num_features, NUM_LABELS)

    checkpoint = _strip_prefixes(load_file(str(WEIGHTS_PATH)))
    report = detector.load_state_dict(checkpoint, strict=False)
    missing = report.missing_keys
    unexpected = report.unexpected_keys
    if missing:
        print(f"[load_detector] missing keys ({len(missing)}): {missing[:10]}")
    if unexpected:
        print(f"[load_detector] unexpected keys ({len(unexpected)}): {unexpected[:10]}")

    detector.eval()
    detector.to(_device)

    # Publish to the module cache only after everything above succeeded.
    _processor = image_processor
    _model = detector
    return _processor, _model


@torch.no_grad()
def predict_image(image: Image.Image) -> float:
    """Score a single image and return the probability of the "fake" class.

    The image is converted to RGB when it is in any other mode, run through
    the processor and model obtained from ``load_detector()``, and the logits
    are divided by ``SOFTMAX_TEMPERATURE`` before the softmax. The value
    returned is the softmax output at label index 1 for the first (only)
    batch entry, as a plain Python float.
    """
    processor, model = load_detector()

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(_device)

    scaled_logits = model(**batch).logits / SOFTMAX_TEMPERATURE
    fake_probability = torch.softmax(scaled_logits, dim=-1)[0, 1]
    return float(fake_probability.item())