| """Visualization helpers for single-case PanCancerSeg inference.""" |
|
|
| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| import SimpleITK as sitk |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
|
|
# Default overlay colour as an (R, G, B) triple of 0-255 channel values: pure red.
DEFAULT_OVERLAY_COLOR = (0xFF, 0x00, 0x00)
|
|
|
|
def preprocess_volume(volume, wl, ww):
    """Apply CT windowing and return uint8 data in [0, 255].

    Intensities are mapped linearly from the window
    ``[wl - ww / 2, wl + ww / 2]`` onto ``[0, 255]``. Values outside the window
    saturate at 0 or 255. The mapping depends only on ``wl`` and ``ww``, so the
    same window yields the same gray levels for every volume, regardless of
    that volume's own intensity range. NaN voxels are written as 0.

    Args:
        volume: Array-like of intensities (e.g. Hounsfield units).
        wl: Window level (centre of the window).
        ww: Window width; must be positive and finite.

    Returns:
        ``np.uint8`` array with the same shape as ``volume``.

    Raises:
        ValueError: If ``ww`` is not a positive finite number.
    """
    ww = float(ww)
    if not (np.isfinite(ww) and ww > 0):
        raise ValueError(f"Window width must be positive and finite, got {ww}")
    lower_bound = float(wl) - ww / 2

    volume = np.asarray(volume, dtype=np.float32)
    # Scale against the fixed window, not the data's own min/max. A data-driven
    # stretch would change contrast whenever a volume does not reach both
    # window edges, making identical wl/ww settings look different per case.
    scaled = (volume - lower_bound) / ww * 255.0
    # NaN survives np.clip and has no defined uint8 value; send it to 0.
    scaled = np.nan_to_num(scaled, copy=False, nan=0.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)
|
|
|
|
def overlay_mask(gray_slice, mask_slice, color=DEFAULT_OVERLAY_COLOR, alpha=0.5):
    """Tint the foreground of one grayscale slice and return it as an RGB image.

    The slice is converted to ``uint8`` and replicated into three channels.
    Pixels where ``mask_slice > 0`` are blended as
    ``pixel * (1 - alpha) + color * alpha``, then clipped to [0, 255].
    All other pixels keep their gray value.

    Raises:
        ValueError: If ``gray_slice`` is not 2-D.
    """
    base = np.asarray(gray_slice, dtype=np.uint8)
    if base.ndim != 2:
        raise ValueError(f"Expected a 2D grayscale slice, got shape {base.shape}")

    # Fresh (H, W, 3) array, so it is safe to modify in place below.
    result = np.repeat(base[..., np.newaxis], 3, axis=-1)
    selected = mask_slice > 0
    if np.any(selected):
        tint = np.asarray(color, dtype=np.float32) * alpha
        mixed = result[selected].astype(np.float32) * (1 - alpha) + tint
        result[selected] = np.clip(mixed, 0, 255).astype(np.uint8)
    return result
|
|
|
|
def find_key_slices(mask_vol):
    """Return named representative z-slices for a mask in [z, y, x] order.

    When the mask has foreground (voxels > 0), the result holds:
        - ``"centroid"``: mean z of all foreground voxels, rounded.
        - ``"max_area"``: the slice with the most foreground voxels (first on
          ties).
        - ``"extent25"`` / ``"extent75"``: the slices 25% / 75% of the way from
          the first to the last foreground slice.
    When the mask is empty, the same keys are returned for slices spread
    around the middle of the volume, clamped to the valid range.

    Raises:
        ValueError: If ``mask_vol`` is not 3-D or has no z-slices.
    """
    mask_vol = np.asarray(mask_vol)
    if mask_vol.ndim != 3:
        raise ValueError(f"Expected a 3D mask volume, got shape {mask_vol.shape}")

    depth = mask_vol.shape[0]
    if depth == 0:
        raise ValueError("Cannot select key slices from an empty z-dimension")

    # Per-slice foreground counts are enough for every statistic below. This
    # avoids the (n_voxels, 3) coordinate array that np.argwhere would allocate.
    areas = (mask_vol > 0).sum(axis=(1, 2))
    total = int(areas.sum())
    if total > 0:
        occupied = np.flatnonzero(areas)
        min_z = int(occupied[0])
        max_z = int(occupied[-1])
        # Voxel-weighted mean z, computed with exact integer sums.
        centroid_z = int(round(float(np.dot(np.arange(depth), areas)) / total))
        return {
            "centroid": _clip_slice(centroid_z, depth),
            "max_area": int(areas.argmax()),
            "extent25": _clip_slice(round(min_z + 0.25 * (max_z - min_z)), depth),
            "extent75": _clip_slice(round(min_z + 0.75 * (max_z - min_z)), depth),
        }

    # No foreground: pick a spread of slices around the middle instead.
    middle = depth // 2
    offset = max(1, depth // 10)
    return {
        "centroid": middle,
        "max_area": _clip_slice(middle - offset, depth),
        "extent25": _clip_slice(middle + offset, depth),
        "extent75": _clip_slice(middle + 2 * offset, depth),
    }
|
|
|
|
def generate_slice_images(
    image_uint8,
    mask_vol,
    output_dir,
    case_name,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
):
    """Save side-by-side PNGs (plain image | overlay) for representative slices.

    Slices are chosen by ``find_key_slices(mask_vol)``. Each PNG is written as
    ``{case_name}_slice_{label}.png`` at 150 dpi on a white background.

    Args:
        image_uint8: Image volume indexed ``[z, y, x]``; presumably already
            windowed to 0-255 (it is shown with vmin=0, vmax=255).
        mask_vol: Segmentation volume indexed the same way; voxels > 0 are
            treated as foreground.
        output_dir: Destination directory; created if missing.
        case_name: Used in the figure title and in the file names.
        color: RGB overlay colour.
        alpha: Overlay opacity.

    Returns:
        Dict mapping each key-slice label to the ``Path`` of its PNG.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    outputs = {}
    for label, z_idx in find_key_slices(mask_vol).items():
        gray_slice = image_uint8[z_idx]
        mask_slice = mask_vol[z_idx] > 0
        overlay = overlay_mask(gray_slice, mask_slice, color=color, alpha=alpha)
        out_path = output_dir / f"{case_name}_slice_{label}.png"

        fig, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=150)
        try:
            axes[0].imshow(gray_slice, cmap="gray", vmin=0, vmax=255)
            axes[0].set_title("Image")
            axes[0].axis("off")
            axes[1].imshow(overlay)
            axes[1].set_title("Segmentation overlay")
            axes[1].axis("off")
            fig.suptitle(f"{case_name} | z={z_idx}")
            fig.tight_layout()
            fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor="white")
        finally:
            # Close even if drawing or saving fails, so figures don't pile up
            # in pyplot's global figure registry across calls.
            plt.close(fig)
        outputs[label] = out_path

    return outputs
|
|
|
|
def generate_video(
    image_uint8,
    mask_vol,
    output_dir,
    case_name,
    cancer_type,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
    fps=10,
):
    """Generate an MP4 scroll-through overlay video.

    Frames cover the inclusive z-range from ``_video_z_range(mask_vol)``. Each
    frame is built by ``_make_video_frame``, which labels it with the 1-based
    slice number, the volume depth and ``cancer_type``.

    Args:
        image_uint8: Image volume indexed ``[z, y, x]``.
        mask_vol: Segmentation volume indexed the same way.
        output_dir: Destination directory; created if missing.
        case_name: Used for the file name ``{case_name}_overlay.mp4``.
        cancer_type: Text shown in each frame's annotation.
        color: RGB overlay colour.
        alpha: Overlay opacity.
        fps: Frames per second of the output video.

    Returns:
        ``Path`` of the written MP4 file.

    Raises:
        RuntimeError: If no MP4 writer could be opened (from
            ``_open_video_writer``).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    video_path = output_dir / f"{case_name}_overlay.mp4"

    depth = image_uint8.shape[0]
    start_z, end_z = _video_z_range(mask_vol)

    def render_bgr(z_idx):
        # Render one slice as an RGB frame, then convert it to the BGR order
        # that OpenCV's VideoWriter expects.
        frame = _make_video_frame(
            image_uint8[z_idx],
            mask_vol[z_idx],
            color,
            alpha,
            z_idx,
            depth,
            cancer_type,
        )
        return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # The first frame fixes the video dimensions (frames may be upscaled).
    first_frame = render_bgr(start_z)
    height, width = first_frame.shape[:2]

    writer = _open_video_writer(video_path, fps, width, height)
    try:
        writer.write(first_frame)
        for z_idx in range(start_z + 1, end_z + 1):
            writer.write(render_bgr(z_idx))
    finally:
        # Always finalize the container and free the file handle, even if
        # rendering a later frame raises.
        writer.release()
    return video_path
|
|
|
|
def generate_outputs(
    image_path,
    mask_path,
    output_dir,
    case_name,
    cancer_type,
    wl,
    ww,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
    fps=10,
):
    """Load an image/segmentation pair and write the PNG previews and MP4 video.

    Both files are read with SimpleITK and converted to NumPy arrays, which
    are expected in ``[z, y, x]`` order. The image is windowed with
    ``wl``/``ww`` via ``preprocess_volume`` before anything is rendered.

    Returns:
        ``{"slices": <label -> PNG path>, "video": <MP4 path>}``.

    Raises:
        ValueError: If the image and segmentation arrays differ in shape.
    """
    image_vol = sitk.GetArrayFromImage(sitk.ReadImage(str(image_path)))
    mask_vol = sitk.GetArrayFromImage(sitk.ReadImage(str(mask_path)))

    if mask_vol.shape != image_vol.shape:
        raise ValueError(
            "Image and segmentation shapes do not match: "
            f"image={image_vol.shape}, segmentation={mask_vol.shape}. "
            "Both arrays are expected in [z, y, x] order."
        )

    windowed = preprocess_volume(image_vol, wl, ww)
    # Dict values are evaluated in order: PNGs first, then the video.
    return {
        "slices": generate_slice_images(
            windowed, mask_vol, output_dir, case_name, color=color, alpha=alpha
        ),
        "video": generate_video(
            windowed,
            mask_vol,
            output_dir,
            case_name,
            cancer_type,
            color=color,
            alpha=alpha,
            fps=fps,
        ),
    }
|
|
|
|
| def _normalize_to_uint8(volume): |
| v_min = float(np.min(volume)) |
| v_max = float(np.max(volume)) |
| if not np.isfinite(v_min) or not np.isfinite(v_max) or v_max <= v_min: |
| return np.zeros(volume.shape, dtype=np.uint8) |
| normalized = (volume - v_min) / (v_max - v_min) * 255.0 |
| return np.clip(normalized, 0, 255).astype(np.uint8) |
|
|
|
|
| def _clip_slice(index, depth): |
| return int(np.clip(index, 0, depth - 1)) |
|
|
|
|
| def _video_z_range(mask_vol, padding=10, empty_window=80): |
| depth = mask_vol.shape[0] |
| mask = mask_vol > 0 |
| if np.any(mask): |
| z_indices = np.where(np.any(mask, axis=(1, 2)))[0] |
| return ( |
| max(0, int(z_indices.min()) - padding), |
| min(depth - 1, int(z_indices.max()) + padding), |
| ) |
|
|
| if depth <= empty_window: |
| return 0, depth - 1 |
| middle = depth // 2 |
| half = empty_window // 2 |
| return max(0, middle - half), min(depth - 1, middle + half) |
|
|
|
|
def _make_video_frame(gray_slice, mask_slice, color, alpha, z_idx, depth, cancer_type):
    """Build one annotated RGB video frame for slice ``z_idx``.

    The slice is overlaid with its mask, upscaled if small, and stamped with
    ``"Slice <z_idx + 1>/<depth> | <cancer_type>"`` in white text on a filled
    black box near the top-left corner. The font scale grows with the frame's
    shorter side and never drops below 0.6.
    """
    canvas = _upscale_if_small(
        overlay_mask(gray_slice, mask_slice, color=color, alpha=alpha)
    )

    label = f"Slice {z_idx + 1}/{depth} | {cancer_type}"
    face = cv2.FONT_HERSHEY_SIMPLEX
    scale = max(0.6, min(canvas.shape[:2]) / 900)
    stroke = max(1, int(round(scale * 2)))
    (text_w, text_h), baseline = cv2.getTextSize(label, face, scale, stroke)

    # (left, bottom) is the text origin; the box extends `pad` pixels around it.
    left = 12
    bottom = 12 + text_h
    pad = 6
    box_top_left = (left - pad, bottom - text_h - pad)
    box_bottom_right = (left + text_w + pad, bottom + baseline + pad)
    cv2.rectangle(canvas, box_top_left, box_bottom_right, (0, 0, 0), thickness=-1)
    cv2.putText(canvas, label, (left, bottom), face, scale, (255, 255, 255), stroke, cv2.LINE_AA)
    return canvas
|
|
|
|
| def _upscale_if_small(frame, min_short_side=512): |
| height, width = frame.shape[:2] |
| short_side = min(height, width) |
| if short_side >= min_short_side: |
| return frame |
| scale = min_short_side / short_side |
| new_size = (int(round(width * scale)), int(round(height * scale))) |
| return cv2.resize(frame, new_size, interpolation=cv2.INTER_LINEAR) |
|
|
|
|
def _open_video_writer(video_path, fps, width, height):
    """Open a ``cv2.VideoWriter`` for ``video_path``, trying H.264 then MPEG-4.

    Each candidate writer that fails to open is released before the next codec
    is tried.

    Raises:
        RuntimeError: If neither codec can open a writer.
    """
    codecs = (
        ("avc1", "H.264/avc1"),
        ("mp4v", "MPEG-4/mp4v"),
    )
    target = str(video_path)
    frame_size = (width, height)

    for code, _description in codecs:
        candidate = cv2.VideoWriter(
            target, cv2.VideoWriter_fourcc(*code), float(fps), frame_size
        )
        if candidate.isOpened():
            return candidate
        candidate.release()

    tried = ", ".join(description for _, description in codecs)
    raise RuntimeError(
        f"Could not open MP4 writer at {video_path}. Tried {tried}. "
        "Install an OpenCV build with MP4 codec support or try another machine."
    )
|
|