Spaces:
Running
Running
File size: 2,356 Bytes
c3e4914 d2354a4 c3e4914 0ceee32 c3e4914 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import os
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face model id used to load both the image processor and the base model.
HF_NAME = "microsoft/swinv2-base-patch4-window16-256"
# Fine-tuned weights file, resolved one level above the directory containing this module.
WEIGHTS_PATH = Path(__file__).parent.parent / "model.safetensors"
# Output size of the replacement classifier head (index 0 = "real", index 1 = "fake").
NUM_LABELS = 2
# Logits are divided by this value before softmax in predict_image; a value
# above 1 flattens the resulting probability distribution.
SOFTMAX_TEMPERATURE = 2.0
# Inference device; the model and its inputs are both moved here.
_device = torch.device("cpu")
# Module-level cache slots, filled by load_detector() on its first call.
_processor = None
_model = None
def _strip_prefixes(state_dict: dict) -> dict:
"""Strip common wrapper prefixes (DDP, Lightning) from state dict keys."""
prefixes = ("module.", "model.")
cleaned = {}
for k, v in state_dict.items():
new_k = k
for p in prefixes:
if new_k.startswith(p):
new_k = new_k[len(p):]
break
cleaned[new_k] = v
return cleaned
def load_detector():
    """Build the real/fake image detector once and return ``(processor, model)``.

    On the first call this:
      1. loads the image processor and base classification model for
         ``HF_NAME`` via ``from_pretrained``, using the ``HF_HOME`` environment
         variable (default ``/tmp/hf-cache``) as the cache directory;
      2. reconfigures the model for ``NUM_LABELS`` classes ("real"/"fake") and
         replaces its classifier with a fresh ``nn.Linear`` head;
      3. loads the weights from ``WEIGHTS_PATH`` (after prefix stripping) with
         ``strict=False``, printing up to ten missing/unexpected keys each;
      4. puts the model in eval mode on ``_device`` and caches both objects in
         the module globals.

    Later calls return the cached pair without repeating any of the above.

    Returns:
        Tuple ``(processor, model)``.
    """
    global _processor, _model

    # Fast path: everything was already built by an earlier call.
    if _model is not None:
        return _processor, _model

    hub_cache = os.environ.get("HF_HOME", "/tmp/hf-cache")
    image_processor = AutoImageProcessor.from_pretrained(
        HF_NAME, cache_dir=hub_cache
    )
    detector = AutoModelForImageClassification.from_pretrained(
        HF_NAME, cache_dir=hub_cache
    )

    # Re-head the pretrained backbone for binary classification. The new
    # linear layer is randomly initialised here and is expected to be
    # overwritten by the checkpoint loaded below.
    detector.num_labels = NUM_LABELS
    detector.config.num_labels = NUM_LABELS
    detector.config.id2label = {0: "real", 1: "fake"}
    detector.config.label2id = {"real": 0, "fake": 1}
    detector.classifier = nn.Linear(detector.swinv2.num_features, NUM_LABELS)

    checkpoint = _strip_prefixes(load_file(str(WEIGHTS_PATH)))

    # strict=False: key mismatches are reported below rather than raised.
    load_report = detector.load_state_dict(checkpoint, strict=False)
    missing = load_report.missing_keys
    unexpected = load_report.unexpected_keys
    if missing:
        print(f"[load_detector] missing keys ({len(missing)}): {missing[:10]}")
    if unexpected:
        print(f"[load_detector] unexpected keys ({len(unexpected)}): {unexpected[:10]}")

    detector.eval()
    detector.to(_device)

    _processor, _model = image_processor, detector
    return _processor, _model
@torch.no_grad()
def predict_image(image: Image.Image) -> float:
    """Return the detector's probability that *image* is fake, in [0, 1].

    The image is converted to RGB when it is in any other mode, run through
    the processor and model returned by ``load_detector()``, and the logits
    are divided by ``SOFTMAX_TEMPERATURE`` before softmax. The value returned
    is the probability at class index 1 ("fake") for the single input image.
    Gradient tracking is disabled for the whole call.

    Args:
        image: PIL image to classify.

    Returns:
        Temperature-scaled softmax probability of the "fake" class.
    """
    processor, model = load_detector()

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(_device)

    scaled_logits = model(**batch).logits / SOFTMAX_TEMPERATURE
    class_probs = scaled_logits.softmax(dim=-1)

    # Row 0 is the only image in the batch; column 1 is the "fake" class.
    fake_prob = class_probs[0, 1]
    return float(fake_prob.item())
|