WorkTimer's picture
Add PanCancerSeg Gradio inference app
ce1057b verified
"""Visualization helpers for single-case PanCancerSeg inference."""
from pathlib import Path
import cv2
import numpy as np
import SimpleITK as sitk
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Default overlay color as an (R, G, B) triple: pure red. Frames are composed
# in RGB and only converted to BGR at the point where they are written via OpenCV.
DEFAULT_OVERLAY_COLOR = (255, 0, 0)
def preprocess_volume(volume, wl, ww):
    """Window a volume around a level/width pair and rescale it to uint8.

    Intensities are clamped to ``[wl - ww / 2, wl + ww / 2]`` and then handed
    to ``_normalize_to_uint8``, which stretches the minimum/maximum that are
    actually present after clamping onto ``[0, 255]``.

    Args:
        volume: NumPy array of intensities (any numeric dtype).
        wl: Window level (center of the window).
        ww: Window width.

    Returns:
        A uint8 array with the same shape as ``volume``.
    """
    half_width = ww / 2
    # copy=False avoids duplicating volumes that are already float32.
    as_float = volume.astype(np.float32, copy=False)
    windowed = np.clip(as_float, wl - half_width, wl + half_width)
    return _normalize_to_uint8(windowed)
def overlay_mask(gray_slice, mask_slice, color=DEFAULT_OVERLAY_COLOR, alpha=0.5):
    """Blend a colored, semi-transparent mask onto one grayscale slice.

    Args:
        gray_slice: 2D array of gray values; converted to uint8.
        mask_slice: Array aligned with ``gray_slice``; entries ``> 0`` are
            treated as foreground.
        color: RGB triple used for the foreground tint.
        alpha: Weight of ``color`` in the blend (``1 - alpha`` goes to the
            underlying gray value).

    Returns:
        A new ``(H, W, 3)`` uint8 RGB array.

    Raises:
        ValueError: If ``gray_slice`` is not two-dimensional.
    """
    gray_slice = np.asarray(gray_slice, dtype=np.uint8)
    if gray_slice.ndim != 2:
        raise ValueError(f"Expected a 2D grayscale slice, got shape {gray_slice.shape}")

    # Replicate the gray channel three times to obtain an RGB image.
    rgb = np.stack((gray_slice, gray_slice, gray_slice), axis=-1)
    foreground = mask_slice > 0
    if not foreground.any():
        # Nothing to tint; the freshly stacked array is already a new object.
        return rgb

    tint = np.asarray(color, dtype=np.float32)
    result = rgb.copy()
    # Blend in float32, then clamp and truncate back to uint8.
    mixed = result[foreground].astype(np.float32) * (1 - alpha) + tint * alpha
    result[foreground] = np.clip(mixed, 0, 255).astype(np.uint8)
    return result
def find_key_slices(mask_vol):
    """Return named representative z-slices for a mask in [z, y, x] order.

    When the mask has foreground (values ``> 0``) the result contains:

    * ``"centroid"``: rounded mean z of all foreground voxels,
    * ``"max_area"``: the slice with the most foreground voxels,
    * ``"extent25"`` / ``"extent75"``: slices 25 % and 75 % of the way
      between the first and last slice that contain foreground.

    When the mask is empty, the same keys are filled with slices spread
    around the middle of the volume so that callers still get previews.

    Args:
        mask_vol: 3D array-like in [z, y, x] order.

    Returns:
        Dict mapping the four labels above to integer slice indices.

    Raises:
        ValueError: If ``mask_vol`` is not 3D or has no slices.
    """
    mask_vol = np.asarray(mask_vol)
    if mask_vol.ndim != 3:
        raise ValueError(f"Expected a 3D mask volume, got shape {mask_vol.shape}")
    depth = mask_vol.shape[0]
    if depth == 0:
        raise ValueError("Cannot select key slices from an empty z-dimension")

    # Per-slice foreground voxel counts. int64 is requested explicitly so the
    # weighted sum below cannot overflow where the default integer is 32-bit.
    areas = (mask_vol > 0).sum(axis=(1, 2), dtype=np.int64)
    z_indices = np.flatnonzero(areas)  # sorted indices of occupied slices

    if z_indices.size:
        # The mean z over all foreground voxels equals the area-weighted mean
        # of the slice indices, so there is no need to materialize an (N, 3)
        # coordinate array with np.argwhere for large masks.
        weighted_z = int(np.dot(np.arange(depth, dtype=np.int64), areas))
        centroid_z = int(round(weighted_z / int(areas.sum())))
        min_z = int(z_indices[0])
        max_z = int(z_indices[-1])
        span = max_z - min_z
        # Every value below lies within [min_z, max_z], which is itself inside
        # [0, depth - 1], so no clamping is required in this branch.
        return {
            "centroid": centroid_z,
            "max_area": int(areas.argmax()),
            "extent25": int(round(min_z + 0.25 * span)),
            "extent75": int(round(min_z + 0.75 * span)),
        }

    # Empty mask: pick slices around the middle; clamp because the offsets can
    # step outside very shallow volumes.
    middle = depth // 2
    offset = max(1, depth // 10)
    return {
        "centroid": middle,
        "max_area": _clip_slice(middle - offset, depth),
        "extent25": _clip_slice(middle + offset, depth),
        "extent75": _clip_slice(middle + 2 * offset, depth),
    }
def generate_slice_images(
    image_uint8,
    mask_vol,
    output_dir,
    case_name,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
):
    """Save side-by-side PNGs for representative slices.

    For every slice chosen by ``find_key_slices`` a two-panel figure is
    written: the plain grayscale slice on the left and the same slice with
    the segmentation overlay on the right.

    Args:
        image_uint8: Windowed image volume indexed by z on the first axis,
            displayed with a fixed 0-255 gray range.
        mask_vol: Segmentation volume aligned with ``image_uint8``; values
            ``> 0`` are foreground.
        output_dir: Directory for the PNGs; created if missing.
        case_name: Prefix for file names and the figure title.
        color: RGB overlay color.
        alpha: Overlay opacity.

    Returns:
        Dict mapping each key-slice label to the ``Path`` of its PNG.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    outputs = {}
    for label, z_idx in find_key_slices(mask_vol).items():
        gray_slice = image_uint8[z_idx]
        overlay = overlay_mask(gray_slice, mask_vol[z_idx] > 0, color=color, alpha=alpha)
        out_path = output_dir / f"{case_name}_slice_{label}.png"

        fig, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=150)
        try:
            axes[0].imshow(gray_slice, cmap="gray", vmin=0, vmax=255)
            axes[0].set_title("Image")
            axes[0].axis("off")
            axes[1].imshow(overlay)
            axes[1].set_title("Segmentation overlay")
            axes[1].axis("off")
            fig.suptitle(f"{case_name} | z={z_idx}")
            fig.tight_layout()
            fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor="white")
        finally:
            # pyplot keeps every open figure in a global registry; close it
            # even when drawing or saving fails so figures do not accumulate
            # in a long-running process.
            plt.close(fig)
        outputs[label] = out_path
    return outputs
def generate_video(
    image_uint8,
    mask_vol,
    output_dir,
    case_name,
    cancer_type,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
    fps=10,
):
    """Generate an MP4 scroll-through overlay video.

    The z-range is chosen by ``_video_z_range``; each slice in that range is
    rendered with ``_make_video_frame`` and appended to the video.

    Args:
        image_uint8: Windowed image volume indexed by z on the first axis.
        mask_vol: Segmentation volume aligned with ``image_uint8``.
        output_dir: Directory for the video; created if missing.
        case_name: Prefix for the output file name.
        cancer_type: Text shown in the per-frame annotation.
        color: RGB overlay color.
        alpha: Overlay opacity.
        fps: Frames per second of the resulting video.

    Returns:
        ``Path`` of the written ``<case_name>_overlay.mp4`` file.

    Raises:
        RuntimeError: Propagated from ``_open_video_writer`` when no MP4
            codec can be opened.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    video_path = output_dir / f"{case_name}_overlay.mp4"

    start_z, end_z = _video_z_range(mask_vol)
    depth = image_uint8.shape[0]

    def render(z_idx):
        # Single place that builds an annotated RGB frame for one slice.
        return _make_video_frame(
            image_uint8[z_idx],
            mask_vol[z_idx],
            color,
            alpha,
            z_idx,
            depth,
            cancer_type,
        )

    # The first frame is rendered up front because the writer needs the final
    # (possibly upscaled) frame size before it can be opened.
    first_frame = render(start_z)
    height, width = first_frame.shape[:2]
    writer = _open_video_writer(video_path, fps, width, height)
    try:
        # Frame annotations are drawn in RGB space; convert only when writing to OpenCV.
        writer.write(cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
        for z_idx in range(start_z + 1, end_z + 1):
            writer.write(cv2.cvtColor(render(z_idx), cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the handle is closed and the container finalized,
        # even if rendering or writing one of the frames raises.
        writer.release()
    return video_path
def generate_outputs(
    image_path,
    mask_path,
    output_dir,
    case_name,
    cancer_type,
    wl,
    ww,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
    fps=10,
):
    """Load an image/segmentation pair and produce all visual outputs.

    Both files are read with SimpleITK, checked for matching array shapes,
    windowed with ``preprocess_volume`` and then passed to
    ``generate_slice_images`` and ``generate_video``.

    Args:
        image_path: Path of the image volume.
        mask_path: Path of the segmentation volume.
        output_dir: Directory that receives the PNGs and the MP4.
        case_name: Prefix used in output file names.
        cancer_type: Text shown in the video annotation.
        wl: Window level for intensity windowing.
        ww: Window width for intensity windowing.
        color: RGB overlay color.
        alpha: Overlay opacity.
        fps: Frames per second of the video.

    Returns:
        ``{"slices": <label -> PNG path dict>, "video": <MP4 path>}``.

    Raises:
        ValueError: If the image and segmentation arrays differ in shape.
    """
    source_image = sitk.ReadImage(str(image_path))
    source_mask = sitk.ReadImage(str(mask_path))
    image_array = sitk.GetArrayFromImage(source_image)
    mask_array = sitk.GetArrayFromImage(source_mask)

    shapes_agree = image_array.shape == mask_array.shape
    if not shapes_agree:
        raise ValueError(
            "Image and segmentation shapes do not match: "
            f"image={image_array.shape}, segmentation={mask_array.shape}. "
            "Both arrays are expected in [z, y, x] order."
        )

    windowed = preprocess_volume(image_array, wl, ww)

    results = {}
    results["slices"] = generate_slice_images(
        windowed, mask_array, output_dir, case_name, color=color, alpha=alpha
    )
    results["video"] = generate_video(
        windowed,
        mask_array,
        output_dir,
        case_name,
        cancer_type,
        color=color,
        alpha=alpha,
        fps=fps,
    )
    return results
def _normalize_to_uint8(volume):
v_min = float(np.min(volume))
v_max = float(np.max(volume))
if not np.isfinite(v_min) or not np.isfinite(v_max) or v_max <= v_min:
return np.zeros(volume.shape, dtype=np.uint8)
normalized = (volume - v_min) / (v_max - v_min) * 255.0
return np.clip(normalized, 0, 255).astype(np.uint8)
def _clip_slice(index, depth):
return int(np.clip(index, 0, depth - 1))
def _video_z_range(mask_vol, padding=10, empty_window=80):
depth = mask_vol.shape[0]
mask = mask_vol > 0
if np.any(mask):
z_indices = np.where(np.any(mask, axis=(1, 2)))[0]
return (
max(0, int(z_indices.min()) - padding),
min(depth - 1, int(z_indices.max()) + padding),
)
if depth <= empty_window:
return 0, depth - 1
middle = depth // 2
half = empty_window // 2
return max(0, middle - half), min(depth - 1, middle + half)
def _make_video_frame(gray_slice, mask_slice, color, alpha, z_idx, depth, cancer_type):
    """Render one annotated RGB video frame for a single slice.

    The slice is overlaid with the mask, upscaled if it is small, and then
    labeled in the top-left corner with ``"Slice <z+1>/<depth> | <cancer_type>"``
    as white text on a filled black box.
    """
    canvas = _upscale_if_small(
        overlay_mask(gray_slice, mask_slice, color=color, alpha=alpha)
    )

    caption = f"Slice {z_idx + 1}/{depth} | {cancer_type}"
    typeface = cv2.FONT_HERSHEY_SIMPLEX
    # Grow the text with the frame, but never below a readable minimum.
    scale = max(0.6, min(canvas.shape[:2]) / 900)
    stroke = max(1, int(round(scale * 2)))
    (text_w, text_h), baseline = cv2.getTextSize(caption, typeface, scale, stroke)

    margin = 6
    origin_x = 12
    origin_y = 12 + text_h  # putText anchors at the text baseline, not its top

    box_top_left = (origin_x - margin, origin_y - text_h - margin)
    box_bottom_right = (origin_x + text_w + margin, origin_y + baseline + margin)
    # thickness=-1 draws a filled rectangle behind the caption.
    cv2.rectangle(canvas, box_top_left, box_bottom_right, (0, 0, 0), thickness=-1)
    cv2.putText(
        canvas,
        caption,
        (origin_x, origin_y),
        typeface,
        scale,
        (255, 255, 255),
        stroke,
        cv2.LINE_AA,
    )
    return canvas
def _upscale_if_small(frame, min_short_side=512):
height, width = frame.shape[:2]
short_side = min(height, width)
if short_side >= min_short_side:
return frame
scale = min_short_side / short_side
new_size = (int(round(width * scale)), int(round(height * scale)))
return cv2.resize(frame, new_size, interpolation=cv2.INTER_LINEAR)
def _open_video_writer(video_path, fps, width, height):
    """Open a ``cv2.VideoWriter`` for ``video_path``, trying codecs in order.

    H.264 (``avc1``) is attempted first, then MPEG-4 (``mp4v``). The first
    writer that reports itself as opened is returned; writers that fail to
    open are released before the next attempt.

    Raises:
        RuntimeError: If none of the codecs yields an opened writer.
    """
    candidates = (
        ("avc1", "H.264/avc1"),
        ("mp4v", "MPEG-4/mp4v"),
    )
    destination = str(video_path)
    frame_rate = float(fps)
    frame_size = (width, height)

    for code, _description in candidates:
        handle = cv2.VideoWriter(
            destination, cv2.VideoWriter_fourcc(*code), frame_rate, frame_size
        )
        if handle.isOpened():
            return handle
        handle.release()

    tried = ", ".join(description for _, description in candidates)
    raise RuntimeError(
        f"Could not open MP4 writer at {video_path}. Tried "
        + tried
        + ". Install an OpenCV build with MP4 codec support or try another machine."
    )