Image Classification
Transformers
PyTorch
English
sybil
medical
cancer
ct-scan
risk-prediction
healthcare
vision
Instructions for using Lab-Rasool/sybil with libraries, inference providers, notebooks, and local apps:
- Libraries
- Transformers
How to use Lab-Rasool/sybil with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-classification", model="Lab-Rasool/sybil")
pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Lab-Rasool/sybil", dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
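The notebook below walks through batch embedding extraction on an NLST-style DICOM tree: it downloads the model repository from the Hub, scans for CT series while filtering out localizer/scout scans, extracts a 512-dimensional embedding per scan (captured after the ReLU layer and before Dropout, averaged across the 5-model ensemble), and writes the results to Parquet with JSON metadata and periodic checkpoints.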
```python
from huggingface_hub import snapshot_download
import sys
import os
import time
import torch
import numpy as np
import pandas as pd
import json
import re
import pydicom
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from pathlib import Path

# Download the model repository and make its modules importable
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig


def load_model(device_id=0):
    """
    Load and initialize the Sybil model once.

    Args:
        device_id: GPU device ID to load the model on

    Returns:
        Tuple of (initialized SybilHFWrapper model, torch device)
    """
    print(f"Loading Sybil model on GPU {device_id}...")
    config = SybilConfig()
    model = SybilHFWrapper(config)

    # Move the model to the requested GPU
    device = torch.device(f'cuda:{device_id}')

    # CRITICAL: set the model's internal device attribute so that
    # preprocessing moves data to the correct GPU
    model.device = device

    # Move all ensemble members to the GPU and switch to eval mode
    for m in model.models:
        m.to(device)
        m.eval()

    print(f"Model loaded successfully on GPU {device_id}!")
    print(f"  Model internal device: {model.device}")
    return model, device
```
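A minimal usage sketch (not part of the original notebook), assuming at least one CUDA device is available:

```python
# Load the ensemble once on GPU 0 and reuse the handle for every scan.
model, device = load_model(device_id=0)
```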
```python
def is_localizer_scan(dicom_folder):
    """
    Check whether a DICOM folder contains a localizer/scout scan.
    Based on preprocessing.py logic.

    Returns:
        Tuple of (is_localizer, reason)
    """
    folder_path = Path(dicom_folder)
    folder_name = folder_path.name.lower()
    localizer_keywords = ['localizer', 'scout', 'topogram', 'surview', 'scanogram']

    # Check the folder name first
    if any(keyword in folder_name for keyword in localizer_keywords):
        return True, f"Folder name contains localizer keyword: {folder_name}"

    try:
        dcm_files = list(folder_path.glob("*.dcm"))
        if not dcm_files:
            return False, "No DICOM files found"

        # Check the first few DICOM files for localizer metadata
        sample_files = dcm_files[:min(3, len(dcm_files))]
        for dcm_file in sample_files:
            try:
                dcm = pydicom.dcmread(str(dcm_file), stop_before_pixels=True)

                # Check the ImageType field
                if hasattr(dcm, 'ImageType'):
                    image_type_str = ' '.join(str(val).lower() for val in dcm.ImageType)
                    if any(keyword in image_type_str for keyword in localizer_keywords):
                        return True, f"ImageType indicates localizer: {dcm.ImageType}"

                # Check the SeriesDescription field
                if hasattr(dcm, 'SeriesDescription'):
                    if any(keyword in dcm.SeriesDescription.lower() for keyword in localizer_keywords):
                        return True, f"SeriesDescription indicates localizer: {dcm.SeriesDescription}"
            except Exception:
                continue
    except Exception:
        pass

    return False, "Not a localizer scan"
```
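Illustrative call (the path here is hypothetical):

```python
# Returns a (bool, reason) tuple explaining the decision.
is_loc, reason = is_localizer_scan("/NLST/106639/01-02-1999-NLST-LSS-45699/scout-series")
if is_loc:
    print(f"Skipping localizer: {reason}")
```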
```python
def extract_timepoint_from_path(scan_dir):
    """
    Extract the timepoint from a scan directory path based on year:
    1999 -> T0, 2000 -> T1, 2001 -> T2, etc.
    Looks for folder names that start with a date in MM-DD-YYYY format.

    Args:
        scan_dir: Directory path string

    Returns:
        Timepoint string (e.g., 'T0', 'T1', 'T2') or None if not found
    """
    # Split the path into components
    path_parts = scan_dir.split('/')

    # Look for date patterns like "01-02-2000-NLST-LSS"
    # Pattern: date in MM-DD-YYYY format at the start of a folder name
    date_pattern = r'^\d{2}-\d{2}-(19\d{2}|20\d{2})'
    base_year = 1999

    for part in path_parts:
        # Check for a date pattern (e.g., "01-02-2000-NLST-LSS-50335")
        match = re.match(date_pattern, part)
        if match:
            year = int(match.group(1))
            if 1999 <= year <= 2010:  # Reasonable range for NLST
                timepoint_num = year - base_year
                print(f"  DEBUG: Found year {year} in '{part}' -> T{timepoint_num}")
                return f'T{timepoint_num}'
    return None
```
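A quick worked example (illustrative path following the NLST layout used throughout this script): the folder name starts with the date 01-02-2000, so the year 2000 maps to T1 (2000 - 1999 = 1).

```python
tp = extract_timepoint_from_path("/NLST/106639/01-02-2000-NLST-LSS-50335/scan")
assert tp == 'T1'
```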
```python
def extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device):
    """
    Extract an embedding from a single ensemble model.

    Args:
        model_idx: Index of the model in the ensemble
        ensemble_model: Single model from the ensemble
        pixel_values: Preprocessed pixel values tensor (already on the correct device)
        device: Device to run on (e.g., cuda:0, cuda:1)

    Returns:
        numpy array of embeddings from this model
    """
    embeddings_buffer = []

    def create_hook(buffer):
        def hook(module, input, output):
            # Capture the output of the ReLU layer (before dropout)
            buffer.append(output.detach().cpu())
        return hook

    # Register a hook on the ReLU layer (AFTER pooling, BEFORE dropout/classification)
    hook_handle = ensemble_model.relu.register_forward_hook(create_hook(embeddings_buffer))

    # Run the forward pass on THIS model only, with a keyword argument
    with torch.no_grad():
        _ = ensemble_model(pixel_values=pixel_values)

    # Remove the hook
    hook_handle.remove()

    # Get the embedding (should have shape [1, 512])
    if embeddings_buffer:
        embedding = embeddings_buffer[0].numpy().squeeze()
        print(f"Model {model_idx + 1}: Embedding shape = {embedding.shape}")
        return embedding
    return None
```
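The forward-hook pattern used above captures an intermediate activation without modifying the model. For readers unfamiliar with it, here is a self-contained toy version of the same capture-then-remove pattern (nothing Sybil-specific):

```python
import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
captured = []

# The hook fires on every forward pass of the ReLU module.
handle = toy[1].register_forward_hook(lambda mod, inp, out: captured.append(out.detach()))
toy(torch.randn(1, 8))
handle.remove()  # always remove, or the hook keeps accumulating outputs

print(captured[0].shape)  # torch.Size([1, 4])
```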
```python
def extract_embeddings(model, dicom_paths, device, use_parallel=True):
    """
    Extract embeddings from the layer after ReLU, before Dropout.
    Optionally runs the ensemble models in parallel for speed.

    Args:
        model: Pre-loaded SybilHFWrapper model
        dicom_paths: List of DICOM file paths
        device: Device to run on (e.g., cuda:0, cuda:1)
        use_parallel: If True, process ensemble models in parallel

    Returns:
        numpy array of shape (512,) - embeddings averaged across the ensemble
    """
    # Preprocess ONCE (not once per ensemble member).
    # The model's preprocessing handles moving data to the correct device.
    with torch.no_grad():
        # preprocess_dicom returns the tensor fed to each ensemble model
        pixel_values = model.preprocess_dicom(dicom_paths)

    if use_parallel:
        # Run all ensemble models in parallel with a ThreadPoolExecutor
        all_embeddings = []
        with ThreadPoolExecutor(max_workers=len(model.models)) as executor:
            # Submit every model with the SAME preprocessed input
            futures = [
                executor.submit(extract_embedding_single_model, model_idx, ensemble_model, pixel_values, device)
                for model_idx, ensemble_model in enumerate(model.models)
            ]
            # Collect the results (in submission order)
            for future in futures:
                embedding = future.result()
                if embedding is not None:
                    all_embeddings.append(embedding)
    else:
        # Sequential processing (original implementation)
        all_embeddings = []
        for model_idx, ensemble_model in enumerate(model.models):
            embedding = extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device)
            if embedding is not None:
                all_embeddings.append(embedding)

    # Average the embeddings across the ensemble
    averaged_embedding = np.mean(all_embeddings, axis=0)
    return averaged_embedding
```
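Usage sketch (paths are hypothetical; `model` and `device` come from `load_model` above):

```python
dicom_paths = sorted(str(p) for p in Path("/NLST/106639/01-02-1999-NLST-LSS-45699/series").glob("*.dcm"))
embedding = extract_embeddings(model, dicom_paths, device)
print(embedding.shape)  # (512,)
```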
```python
def check_directory_for_dicoms(dirpath):
    """
    Check a single directory for valid DICOM files.

    Returns:
        (dirpath, num_files, subject_id, filter_reason), or None if invalid.
    """
    try:
        # Quick check: does this directory contain .dcm files?
        dcm_files = [f for f in os.listdir(dirpath)
                     if f.endswith('.dcm') and os.path.isfile(os.path.join(dirpath, f))]
        if not dcm_files:
            return None

        num_files = len(dcm_files)

        # Filter out scans with 1-2 DICOM files (likely localizers)
        if num_files <= 2:
            return (dirpath, num_files, None, 'too_few_slices')

        # Check whether it is a localizer scan
        is_loc, _ = is_localizer_scan(dirpath)
        if is_loc:
            return (dirpath, num_files, None, 'localizer')

        # Extract the subject ID (PID) from the path.
        # Path structure: /NLST/<PID>/<date-info>/<scan-info>
        # Example: /NLST/106639/01-02-1999-NLST-LSS-45699/1.000000-0OPLGEHSQXAnullna...
        path_parts = dirpath.rstrip('/').split('/')

        # Find the PID: it is the component right after the 'NLST' directory
        try:
            nlst_idx = path_parts.index('NLST')
            subject_id = path_parts[nlst_idx + 1]
        except (ValueError, IndexError):
            # Fall back to positional logic if the path structure differs
            subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

        return (dirpath, num_files, subject_id, 'valid')
    except Exception:
        return None
```
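The PID parsing rule is worth seeing in isolation (hypothetical path following the /NLST/&lt;PID&gt;/&lt;date&gt;/&lt;scan&gt; layout):

```python
parts = "/NLST/106639/01-02-1999-NLST-LSS-45699/series".rstrip('/').split('/')
pid = parts[parts.index('NLST') + 1]
print(pid)  # 106639
```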
```python
def save_directory_cache(dicom_dirs, cache_file):
    """
    Save the list of DICOM directories to a cache file.

    Args:
        dicom_dirs: List of directory paths
        cache_file: Path to the cache file
    """
    print(f"\n💾 Saving directory cache to {cache_file}...")
    cache_data = {
        "timestamp": datetime.now().isoformat(),
        "num_directories": len(dicom_dirs),
        "directories": dicom_dirs
    }
    with open(cache_file, 'w') as f:
        json.dump(cache_data, f, indent=2)
    print(f"✓ Cache saved with {len(dicom_dirs)} directories\n")


def load_directory_cache(cache_file):
    """
    Load the list of DICOM directories from a cache file.

    Args:
        cache_file: Path to the cache file

    Returns:
        List of directory paths, or None if the cache doesn't exist or is invalid
    """
    if not os.path.exists(cache_file):
        return None
    try:
        with open(cache_file, 'r') as f:
            cache_data = json.load(f)
        dicom_dirs = cache_data.get("directories", [])
        timestamp = cache_data.get("timestamp", "unknown")
        print(f"\n✓ Loaded directory cache from {cache_file}")
        print(f"  Cache created: {timestamp}")
        print(f"  Directories: {len(dicom_dirs)}\n")
        return dicom_dirs
    except Exception as e:
        print(f"⚠️ Failed to load cache: {e}")
        return None
```
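The cache is a plain JSON round trip, so a quick sanity check looks like this (filenames and paths are illustrative):

```python
save_directory_cache(["/NLST/106639/01-02-1999-NLST-LSS-45699/series"], "directory_cache.json")
dirs = load_directory_cache("directory_cache.json")  # returns the same list
```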
```python
def find_dicom_directories(root_dir, max_subjects=None, num_workers=12, cache_file=None, filter_pids=None):
    """
    Walk the directory tree and find all directories containing DICOM files.
    Uses parallel processing for much faster scanning of large directory trees.
    Only returns leaf directories (directories with .dcm files, not their parents),
    and filters out localizer scans and scans with 1-2 DICOM files.

    Args:
        root_dir: Root directory to search
        max_subjects: Optional maximum number of unique subjects to process (None = all)
        num_workers: Number of parallel workers for directory scanning (default: 12)
        cache_file: Optional path to a cache file for saving/loading the directory list
        filter_pids: Optional set of PIDs to filter (only include these subjects)

    Returns:
        List of directory paths containing .dcm files
    """
    # Try to load from the cache first
    if cache_file:
        cached_dirs = load_directory_cache(cache_file)
        if cached_dirs is not None:
            print("✓ Using cached directory list (skipping scan)")

            # Apply the PID filter if specified
            if filter_pids:
                print(f"  Filtering to {len(filter_pids)} PIDs from CSV...")
                filtered_dirs = []
                for d in cached_dirs:
                    # Extract the PID from the path: /NLST/<PID>/<date>/<scan>
                    path_parts = d.rstrip('/').split('/')
                    try:
                        nlst_idx = path_parts.index('NLST')
                        subject_id = path_parts[nlst_idx + 1]
                    except (ValueError, IndexError):
                        subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]
                    if subject_id in filter_pids:
                        filtered_dirs.append(d)
                print(f"  ✓ Found {len(filtered_dirs)} scans matching PIDs")
                return filtered_dirs

            # Still apply the max_subjects limit if specified
            if max_subjects:
                subjects_seen = set()
                filtered_dirs = []
                for d in cached_dirs:
                    # Extract the PID from the path: /NLST/<PID>/<date>/<scan>
                    path_parts = d.rstrip('/').split('/')
                    try:
                        nlst_idx = path_parts.index('NLST')
                        subject_id = path_parts[nlst_idx + 1]
                    except (ValueError, IndexError):
                        subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

                    # Add the scan if: (1) we are already collecting this subject,
                    # or (2) we are still under the subject limit
                    if subject_id in subjects_seen:
                        filtered_dirs.append(d)
                    elif len(subjects_seen) < max_subjects:
                        # New subject and under the limit - start collecting it
                        subjects_seen.add(subject_id)
                        filtered_dirs.append(d)

                    # Stop once we have enough subjects
                    if len(subjects_seen) >= max_subjects:
                        # Keep the remaining scans that belong to these subjects
                        for remaining_d in cached_dirs[cached_dirs.index(d) + 1:]:
                            remaining_parts = remaining_d.rstrip('/').split('/')
                            try:
                                remaining_nlst_idx = remaining_parts.index('NLST')
                                remaining_subject_id = remaining_parts[remaining_nlst_idx + 1]
                            except (ValueError, IndexError):
                                remaining_subject_id = remaining_parts[-3] if len(remaining_parts) >= 3 else remaining_parts[-1]
                            if remaining_subject_id in subjects_seen:
                                filtered_dirs.append(remaining_d)
                        break
                print(f"  ✓ Limited to {len(subjects_seen)} subjects ({len(filtered_dirs)} total scans)")
                return filtered_dirs

            return cached_dirs

    print(f"Starting parallel directory scan with {num_workers} workers...")
    if filter_pids:
        print(f"⚡ FAST MODE: Only scanning {len(filter_pids)} PIDs (skipping others)")
    else:
        print("Scanning ALL subjects (this may take a while)")

    # Phase 1: fast scan to find all directories with DICOM files,
    # skipping subject directories not in filter_pids for a large speedup
    print("\nPhase 1: Scanning filesystem for DICOM directories...")
    start_time = datetime.now()

    # Collect candidate directories first (fast), with early filtering
    all_dirs = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # EARLY FILTER: with filter_pids, only descend into matching PID directories
        if filter_pids:
            path_parts = dirpath.rstrip('/').split('/')
            try:
                nlst_idx = path_parts.index('NLST')
                # If this is a subject directory (one level below NLST)
                if len(path_parts) == nlst_idx + 2:
                    subject_id = path_parts[nlst_idx + 1]
                    # Skip this subject if it is not in the filter list
                    if subject_id not in filter_pids:
                        dirnames.clear()  # Don't descend into this subject's subdirs
                        continue
            except (ValueError, IndexError):
                pass

        # Quick check: if the directory has .dcm files, add it to the list
        if any(f.endswith('.dcm') for f in filenames):
            all_dirs.append(dirpath)

    print(f"Found {len(all_dirs)} potential DICOM directories in {(datetime.now() - start_time).total_seconds():.1f}s")

    # Phase 2: parallel validation and filtering
    print(f"\nPhase 2: Validating directories in parallel ({num_workers} workers)...")
    dicom_dirs = []
    subjects_found = set()
    filtered_stats = {'localizers': 0, 'too_few_slices': 0}

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit all directories for checking
        future_to_dir = {executor.submit(check_directory_for_dicoms, d): d for d in all_dirs}

        # Process results as they complete
        for i, future in enumerate(as_completed(future_to_dir), 1):
            # Print progress every 1000 directories
            if i % 1000 == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                rate = i / elapsed if elapsed > 0 else 0
                remaining = (len(all_dirs) - i) / rate if rate > 0 else 0
                print(f"  [{i}/{len(all_dirs)}] Found: {len(dicom_dirs)} scans from {len(subjects_found)} PIDs | "
                      f"Filtered: {filtered_stats['localizers'] + filtered_stats['too_few_slices']} | "
                      f"ETA: {remaining/60:.1f} min")
            try:
                result = future.result()
                if result is None:
                    continue
                dirpath, num_files, subject_id, status = result
                if status == 'too_few_slices':
                    filtered_stats['too_few_slices'] += 1
                elif status == 'localizer':
                    filtered_stats['localizers'] += 1
                elif status == 'valid':
                    # Check the PID filter
                    if filter_pids is not None and subject_id not in filter_pids:
                        continue
                    # Check the subject limit
                    if max_subjects is not None and subject_id not in subjects_found and len(subjects_found) >= max_subjects:
                        continue
                    subjects_found.add(subject_id)
                    dicom_dirs.append(dirpath)

                    # Periodically report progress (helpful for filtered runs)
                    if filter_pids and len(dicom_dirs) % 100 == 1:
                        print(f"  ✓ Found {len(dicom_dirs)} scans so far ({len(subjects_found)} unique PIDs)")

                    # Stop if we've hit the subject limit
                    if max_subjects is not None and len(subjects_found) >= max_subjects:
                        print(f"\n✓ Reached limit of {max_subjects} subjects. Stopping search.")
                        # Cancel the remaining futures
                        for f in future_to_dir:
                            f.cancel()
                        break
            except Exception:
                continue

    scan_time = (datetime.now() - start_time).total_seconds()
    print(f"\n{'='*80}")
    print(f"Directory Scan Complete in {scan_time:.1f}s ({scan_time/60:.1f} minutes)")
    print(f"{'='*80}")
    print("Filtering Summary:")
    print(f"  ✅ Valid scans found: {len(dicom_dirs)}")
    print(f"  🚫 Localizers filtered: {filtered_stats['localizers']}")
    print(f"  ⏭️ Too few slices (≤2) filtered: {filtered_stats['too_few_slices']}")
    print(f"  📊 Unique subjects: {len(subjects_found)}")
    print(f"  ⚡ Speed: {len(all_dirs)/scan_time:.0f} dirs/second")
    print(f"{'='*80}\n")

    # Save to the cache if specified
    if cache_file:
        save_directory_cache(dicom_dirs, cache_file)

    return dicom_dirs
```
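Usage sketch (arguments are illustrative):

```python
dicom_dirs = find_dicom_directories(
    "/path/to/NLST",
    max_subjects=10,            # small cap for a test run
    num_workers=8,
    cache_file="directory_cache.json",
)
```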
```python
def prepare_scan_metadata(scan_dir):
    """
    Prepare metadata for a scan without processing it.

    Args:
        scan_dir: Directory containing the DICOM files for one scan

    Returns:
        tuple: (dicom_file_paths, num_files, subject_id, scan_id)
    """
    # Count DICOM files (ensure they are actual files, not directories)
    dicom_files = [f for f in os.listdir(scan_dir)
                   if f.endswith('.dcm') and os.path.isfile(os.path.join(scan_dir, f))]
    num_dicom_files = len(dicom_files)
    if num_dicom_files == 0:
        raise ValueError("No valid DICOM files found")

    # Build the full paths to the DICOM files
    dicom_file_paths = [os.path.join(scan_dir, f) for f in dicom_files]

    # Parse the directory path to extract identifiers.
    # Path structure: /NLST/<PID>/<date-info>/<scan-info>
    path_parts = scan_dir.rstrip('/').split('/')
    scan_id = path_parts[-1] if path_parts[-1] else path_parts[-2]

    # Extract the PID from the path
    try:
        nlst_idx = path_parts.index('NLST')
        subject_id = path_parts[nlst_idx + 1]  # The PID follows 'NLST'
    except (ValueError, IndexError):
        # Fall back to positional logic
        subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

    return dicom_file_paths, num_dicom_files, subject_id, scan_id


def save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_num):
    """
    Save a checkpoint of embeddings and metadata.

    Args:
        all_embeddings: List of embedding arrays
        all_metadata: List of metadata dictionaries
        failed: List of failed scans
        output_dir: Output directory
        checkpoint_num: Checkpoint number
    """
    print(f"\n💾 Saving checkpoint {checkpoint_num}...")

    # Convert the embeddings to an array
    embeddings_array = np.array(all_embeddings)

    # Create the DataFrame
    df_data = {
        'case_number': [m['case_number'] for m in all_metadata],
        'subject_id': [m['subject_id'] for m in all_metadata],
        'scan_id': [m['scan_id'] for m in all_metadata],
        'timepoint': [m.get('timepoint') for m in all_metadata],
        'dicom_directory': [m['dicom_directory'] for m in all_metadata],
        'num_dicom_files': [m['num_dicom_files'] for m in all_metadata],
        'embedding_index': [m['embedding_index'] for m in all_metadata],
        'embedding': list(embeddings_array)
    }
    df = pd.DataFrame(df_data)

    # Save the checkpoint parquet
    checkpoint_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_embeddings.parquet")
    df.to_parquet(checkpoint_path, index=False, compression='snappy')
    print(f"  ✓ Saved embeddings checkpoint: {checkpoint_path}")

    # Save the checkpoint metadata
    checkpoint_metadata = {
        "checkpoint_num": checkpoint_num,
        "timestamp": datetime.now().isoformat(),
        "total_scans": len(all_embeddings),
        "num_failed_scans": len(failed),  # count; the full list is under "failed_scans"
        "embedding_shape": list(embeddings_array.shape),
        "scans": all_metadata,
        "failed_scans": failed
    }
    metadata_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(checkpoint_metadata, f, indent=2)
    print(f"  ✓ Saved metadata checkpoint: {metadata_path}")
    print(f"💾 Checkpoint {checkpoint_num} complete!\n")
```
```python
def process_scan(model, device, scan_dir):
    """
    Process a single scan directory and extract embeddings.

    Args:
        model: Pre-loaded SybilHFWrapper model
        device: Device to run on (e.g., cuda:0, cuda:1)
        scan_dir: Directory containing the DICOM files for one scan

    Returns:
        tuple: (embeddings, scan_metadata)
    """
    dicom_file_paths, num_dicom_files, subject_id, scan_id = prepare_scan_metadata(scan_dir)
    print(f"\nProcessing: {scan_dir}")
    print(f"DICOM files: {num_dicom_files}")

    # Extract the embeddings
    embeddings = extract_embeddings(model, dicom_file_paths, device)
    print(f"Embedding shape: {embeddings.shape}")

    # Extract the timepoint from the path (e.g., 1999 -> T0, 2000 -> T1)
    timepoint = extract_timepoint_from_path(scan_dir)
    if timepoint:
        print(f"Timepoint: {timepoint}")
    else:
        print("Timepoint: Not detected")

    # Create the metadata for this scan
    scan_metadata = {
        "case_number": subject_id,  # Case number (e.g., 205749)
        "subject_id": subject_id,
        "scan_id": scan_id,
        "timepoint": timepoint,  # T0, T1, T2, etc., or None
        "dicom_directory": scan_dir,
        "num_dicom_files": num_dicom_files,
        "embedding_index": None,  # Set later
        "statistics": {
            "mean": float(np.mean(embeddings)),
            "std": float(np.std(embeddings)),
            "min": float(np.min(embeddings)),
            "max": float(np.max(embeddings))
        }
    }
    return embeddings, scan_metadata
```
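Putting the pieces together for a single scan (hypothetical path):

```python
embeddings, meta = process_scan(model, device, "/NLST/106639/01-02-1999-NLST-LSS-45699/series")
print(meta["timepoint"], embeddings.shape)  # e.g., T0 (512,)
```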
```python
# Main execution
if __name__ == "__main__":
    import argparse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Extract Sybil embeddings from DICOM scans')

    # Input/Output
    parser.add_argument('--root-dir', type=str, required=True,
                        help='Root directory containing DICOM files (e.g., /path/to/NLST)')
    parser.add_argument('--pid-csv', type=str, default=None,
                        help='CSV file with a "pid" column to filter subjects (e.g., subsets/hybridModels-train.csv)')
    parser.add_argument('--output-dir', type=str, default='embeddings_output',
                        help='Output directory for embeddings (default: embeddings_output)')
    parser.add_argument('--max-subjects', type=int, default=None,
                        help='Maximum number of subjects to process (for testing)')

    # Performance tuning
    parser.add_argument('--num-gpus', type=int, default=1,
                        help='Number of GPUs to use (default: 1)')
    parser.add_argument('--num-parallel', type=int, default=1,
                        help='Number of scans to process in parallel (default: 1, recommended: 1-4 depending on GPU memory)')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='Number of parallel workers for directory scanning (default: 4, recommended: 4-12 depending on storage speed)')
    parser.add_argument('--checkpoint-interval', type=int, default=1000,
                        help='Save a checkpoint every N scans (default: 1000)')
    args = parser.parse_args()

    # ==========================================
    # CONFIGURATION
    # ==========================================
    root_dir = args.root_dir
    output_dir = args.output_dir
    max_subjects = args.max_subjects
    num_gpus = args.num_gpus
    num_parallel_scans = args.num_parallel
    num_scan_workers = args.num_workers
    checkpoint_interval = args.checkpoint_interval

    # Prefer the main cache file from the full run, if present
    main_cache = "embeddings_output_full/directory_cache.json"
    if os.path.exists(main_cache):
        cache_file = main_cache
        print(f"✓ Found main directory cache: {main_cache}")
    else:
        cache_file = os.path.join(output_dir, "directory_cache.json")

    # Verify that the root directory exists
    if not os.path.exists(root_dir):
        raise ValueError(f"Root directory does not exist: {root_dir}")

    # Load PIDs from the CSV if provided
    filter_pids = None
    if args.pid_csv:
        print(f"Loading subject PIDs from: {args.pid_csv}")
        csv_data = pd.read_csv(args.pid_csv)
        filter_pids = set(str(pid) for pid in csv_data['pid'].unique())
        print(f"  Found {len(filter_pids)} unique PIDs to extract")
        print(f"  Examples: {list(filter_pids)[:5]}")

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Print the configuration
    print(f"\n{'='*80}")
    print("CONFIGURATION")
    print(f"{'='*80}")
    print(f"Root directory: {root_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Number of GPUs: {num_gpus}")
    print(f"Parallel scans: {num_parallel_scans} (recommended: 1-4 depending on GPU memory)")
    print(f"Directory scan workers: {num_scan_workers} (recommended: 4-12 depending on storage)")
    print(f"Checkpoint interval: {checkpoint_interval} scans")
    if filter_pids:
        print(f"Filtering to: {len(filter_pids)} PIDs from CSV")
    if max_subjects:
        print(f"Max subjects: {max_subjects}")
    print(f"{'='*80}\n")

    # Warn about memory requirements
    if num_parallel_scans > 1:
        estimated_vram = (num_parallel_scans // num_gpus) * 10
        print("⚠️ MEMORY WARNING:")
        print(f"  Parallel processing requires ~{estimated_vram}GB VRAM per GPU")
        print("  If you encounter OOM errors, reduce --num-parallel to 1-2")
        print(f"  Current: {num_parallel_scans} scans across {num_gpus} GPU(s)\n")

    # Find all directories containing DICOM files.
    # Uses the cached directory list if available, otherwise scans and saves a cache.
    dicom_dirs = find_dicom_directories(root_dir, max_subjects=max_subjects,
                                        num_workers=num_scan_workers, cache_file=cache_file,
                                        filter_pids=filter_pids)
    if len(dicom_dirs) == 0:
        raise ValueError(f"No directories with DICOM files found in {root_dir}")

    print(f"\n{'='*80}")
    print(f"Found {len(dicom_dirs)} directories containing DICOM files")
    print(f"{'='*80}\n")

    # Load a model on each GPU
    print(f"🎮 Detected {num_gpus} GPU(s)")
    print(f"🚀 Will process {num_parallel_scans} scans in parallel ({num_parallel_scans // num_gpus} per GPU)")
    print(f"💾 Checkpoints will be saved every {checkpoint_interval} scans\n")

    models_and_devices = []
    for gpu_id in range(num_gpus):
        model, device = load_model(gpu_id)
        models_and_devices.append((model, device, gpu_id))

    # Process each scan directory and collect all embeddings
    all_embeddings = []
    all_metadata = []
    failed = []
    checkpoint_counter = 0

    if num_parallel_scans > 1:
        # Parallel processing of multiple scans across multiple GPUs
        print(f"Processing {num_parallel_scans} scans in parallel across {num_gpus} GPU(s)...")
        print(f"Note: this requires ~{(num_parallel_scans // num_gpus) * 10}GB VRAM per GPU.\n")

        from functools import partial

        # Process scans in batches so checkpoints land on batch boundaries
        batch_size = checkpoint_interval
        num_batches = (len(dicom_dirs) + batch_size - 1) // batch_size
        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(dicom_dirs))
            batch_dirs = dicom_dirs[start_idx:end_idx]

            print(f"\n{'='*80}")
            print(f"Processing batch {batch_idx + 1}/{num_batches} (scans {start_idx + 1} to {end_idx})")
            print(f"{'='*80}\n")

            # Use a ThreadPoolExecutor for parallel scan processing.
            # IMPORTANT: max_workers limits concurrent execution to prevent OOM.
            with ThreadPoolExecutor(max_workers=num_parallel_scans) as executor:
                # Submit scans in controlled batches to avoid memory issues:
                # only max_workers scans at once, then more as they complete
                future_to_info = {}
                scan_queue = list(enumerate(batch_dirs))
                scans_submitted = 0

                # Submit the initial batch (up to max_workers scans)
                while scan_queue and scans_submitted < num_parallel_scans:
                    i, scan_dir = scan_queue.pop(0)
                    # Select a GPU in round-robin fashion
                    gpu_idx = i % num_gpus
                    model, device, gpu_id = models_and_devices[gpu_idx]
                    # Bind the model and device into the worker function
                    process_func = partial(process_scan, model, device)
                    future = executor.submit(process_func, scan_dir)
                    future_to_info[future] = (start_idx + i + 1, scan_dir, gpu_id)
                    scans_submitted += 1

                # Process results as they complete and submit new scans
                while future_to_info:
                    # Poll for completed futures
                    done_futures = [f for f in list(future_to_info.keys()) if f.done()]
                    if not done_futures:
                        time.sleep(0.1)
                        continue

                    # Handle the completed futures
                    for future in done_futures:
                        scan_num, scan_dir, gpu_id = future_to_info.pop(future)
                        try:
                            print(f"[{scan_num}/{len(dicom_dirs)}] Processing on GPU {gpu_id}...")
                            embeddings, scan_metadata = future.result()
                            # Record the index of this scan's embedding
                            scan_metadata["embedding_index"] = len(all_embeddings)
                            all_embeddings.append(embeddings)
                            all_metadata.append(scan_metadata)
                        except Exception as e:
                            print(f"ERROR processing {scan_dir}: {e}")
                            failed.append({"scan_dir": scan_dir, "error": str(e)})

                        # Submit the next scan from the queue
                        if scan_queue:
                            i, next_scan_dir = scan_queue.pop(0)
                            gpu_idx = i % num_gpus
                            model, device, gpu_id = models_and_devices[gpu_idx]
                            process_func = partial(process_scan, model, device)
                            new_future = executor.submit(process_func, next_scan_dir)
                            future_to_info[new_future] = (start_idx + i + 1, next_scan_dir, gpu_id)

            # Save a checkpoint after each batch
            checkpoint_counter += 1
            save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter)
            print(f"Progress: {len(all_embeddings)}/{len(dicom_dirs)} scans completed "
                  f"({len(all_embeddings)/len(dicom_dirs)*100:.1f}%)\n")
    else:
        # Sequential processing (original behavior)
        model, device, gpu_id = models_and_devices[0]  # Use the first GPU
        for i, scan_dir in enumerate(dicom_dirs, 1):
            try:
                print(f"\n[{i}/{len(dicom_dirs)}] Processing scan...")
                # Process the scan and get the results
                embeddings, scan_metadata = process_scan(model, device, scan_dir)
                # Record the index of this scan's embedding
                scan_metadata["embedding_index"] = len(all_embeddings)
                all_embeddings.append(embeddings)
                all_metadata.append(scan_metadata)
                # Save a checkpoint every checkpoint_interval scans
                if i % checkpoint_interval == 0:
                    checkpoint_counter += 1
                    save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter)
            except Exception as e:
                print(f"ERROR processing {scan_dir}: {e}")
                failed.append({"scan_dir": scan_dir, "error": str(e)})

    # Convert the embeddings list to a numpy array of shape (num_scans, embedding_dim)
    embeddings_array = np.array(all_embeddings)
    embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0])

    # Build a DataFrame with metadata plus the embeddings as a single array column
    df_data = {
        'case_number': [m['case_number'] for m in all_metadata],
        'subject_id': [m['subject_id'] for m in all_metadata],
        'scan_id': [m['scan_id'] for m in all_metadata],
        'timepoint': [m.get('timepoint') for m in all_metadata],  # T0, T1, T2, etc.
        'dicom_directory': [m['dicom_directory'] for m in all_metadata],
        'num_dicom_files': [m['num_dicom_files'] for m in all_metadata],
        'embedding_index': [m['embedding_index'] for m in all_metadata],
        'embedding': list(embeddings_array)  # Stored as a list of arrays
    }
    df = pd.DataFrame(df_data)

    # Save the final, complete file as Parquet
    embeddings_filename = "all_embeddings.parquet"
    embeddings_path = os.path.join(output_dir, embeddings_filename)
    df.to_parquet(embeddings_path, index=False, compression='snappy')
    print(f"\n✅ Saved FINAL embeddings to Parquet: {embeddings_path}")

    # Create the comprehensive metadata JSON
    dataset_metadata = {
        "dataset_info": {
            "root_directory": root_dir,
            "total_scans": len(all_embeddings),
            "failed_scans": len(failed),
            "embedding_shape": list(embeddings_array.shape),
            "embedding_dim": embedding_dim,
            "extraction_timestamp": datetime.now().isoformat(),
            "file_format": "parquet"
        },
        "model_info": {
            "model": "Lab-Rasool/sybil",
            "layer": "after_relu_before_dropout",
            "ensemble_averaged": True,
            "num_ensemble_models": 5
        },
        "embeddings_file": embeddings_filename,
        "parquet_schema": {
            "metadata_columns": ["case_number", "subject_id", "scan_id", "timepoint", "dicom_directory", "num_dicom_files", "embedding_index"],
            "embedding_column": "embedding",
            "embedding_shape": f"({embedding_dim},)",
            "total_columns": 8,
            "timepoint_info": "T0=1999, T1=2000, T2=2001, etc. Extracted from the year in the path. Can be None if not detected."
        },
        "filtering_info": {
            "localizer_detection": "Scans identified as localizers (by folder name or DICOM metadata) are filtered out",
            "min_slices": "Scans with ≤2 DICOM files are filtered out (likely localizers)",
            "accepted_scans": len(all_embeddings)
        },
        "scans": all_metadata,
        "failed_scans": failed
    }
    metadata_filename = "dataset_metadata.json"
    metadata_path = os.path.join(output_dir, metadata_filename)
    with open(metadata_path, 'w') as f:
        json.dump(dataset_metadata, f, indent=2)
    print(f"✅ Saved FINAL metadata: {metadata_path}")

    # Summary
    print(f"\n{'='*80}")
    print("PROCESSING COMPLETE")
    print(f"{'='*80}")
    print(f"Successfully processed: {len(all_embeddings)}/{len(dicom_dirs)} scans")
    print(f"Failed: {len(failed)}/{len(dicom_dirs)} scans")
    print(f"\nEmbeddings array shape: {embeddings_array.shape}")
    print(f"Saved embeddings to: {embeddings_path}")
    print(f"Saved metadata to: {metadata_path}")

    # Timepoint summary
    timepoint_counts = {}
    for m in all_metadata:
        tp = m.get('timepoint')
        timepoint_counts[tp] = timepoint_counts.get(tp, 0) + 1
    if timepoint_counts:
        print("\n📅 Timepoint Distribution:")
        for tp in sorted(timepoint_counts.keys(), key=lambda x: (x is None, x)):
            count = timepoint_counts[tp]
            if tp is None:
                print(f"  Unknown/Not detected: {count} scans")
            else:
                print(f"  {tp}: {count} scans")

    if failed:
        print(f"\nFailed scans: {len(failed)}")
        for fail_info in failed[:5]:  # Show the first 5 failures
            print(f"  - {fail_info['scan_dir']}")
            print(f"    Error: {fail_info['error']}")
        if len(failed) > 5:
            print(f"  ... and {len(failed) - 5} more failures")

    print(f"\n{'='*80}")
    print("For downstream training, load embeddings with:")
    print("  import pandas as pd")
    print("  import numpy as np")
    print(f"  df = pd.read_parquet('{embeddings_path}')")
    print(f"  # Total rows: {len(df)}, Total columns: {len(df.columns)}")
    print("  # Extract embeddings array: embeddings = np.stack(df['embedding'].values)")
    print(f"  # Shape: {embeddings_array.shape}")
    print(f"  # Access individual: df.loc[0, 'embedding'] -> array of shape ({embedding_dim},)")
    print(f"{'='*80}")
```