Text Classification
Scikit-learn
Joblib
English
llm-routing
model-selection
budget-optimization
nearest-neighbor
Instructions to use JiaqiXue/R2-Router-RouterArena with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Scikit-learn
How to use JiaqiXue/R2-Router-RouterArena with Scikit-learn:
```python
from huggingface_hub import hf_hub_download
import joblib

model = joblib.load(
    hf_hub_download("JiaqiXue/R2-Router-RouterArena", "sklearn_model.joblib")
)
# Only load pickle files from sources you trust.
# Read more about it here: https://skops.readthedocs.io/en/stable/persistence.html
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| R2-Router: LLM Router with Joint Model-Budget Optimization | |
| Self-contained inference module. Routes queries to the optimal (model, token_budget) | |
| pair by predicting per-query quality and cost using KNN. | |
| Usage: | |
| from router import R2Router | |
| # Option A: Local vLLM (loads Qwen3-0.6B on first call) | |
| router = R2Router.from_pretrained(path) | |
| result = router.route_text("What is the capital of France?") | |
| # Option B: Remote vLLM server (no local GPU needed for embedding) | |
| # Start server: vllm serve Qwen/Qwen3-0.6B --runner pooling | |
| router = R2Router.from_pretrained(path, embed_url="http://localhost:8000") | |
| result = router.route_text("What is the capital of France?") | |
| # Option C: Pre-computed embedding | |
| result = router.route(embedding) # np.ndarray (1024,) | |
| """ | |
| import os | |
| import json | |
| import numpy as np | |
| import joblib | |
| from typing import Dict, List, Optional, Union | |
| from sklearn.neighbors import KNeighborsRegressor | |
class R2Router:
    """
    R2-Router: Routes queries to the optimal (LLM, token_budget) pair.

    A KNN regressor predicts per-query answer quality for every
    (model, budget) combination, and a second KNN predicts expected
    output tokens per model. The router then selects the pair that
    maximizes:

        risk = (1 - lambda) * quality - lambda * tokens * price / 1e6
    """

    def __init__(
        self,
        quality_knns: Dict[str, Dict[str, "KNeighborsRegressor"]],
        token_knns: Dict[str, "KNeighborsRegressor"],
        model_prices: Dict[str, float],
        model_names: Dict[str, str],
        budgets: Dict[str, int],
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ):
        """
        Args:
            quality_knns: {model: {budget: fitted KNN quality predictor}}
            token_knns: {model: fitted KNN output-token predictor}
            model_prices: {model: price per million output tokens}
            model_names: {short_name: full model name}
            budgets: {budget_name: token limit}
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
            embed_url: vLLM server URL, e.g. "http://localhost:8000".
                If None, a local embedder is loaded lazily on first use.
        """
        self.quality_knns = quality_knns    # {model: {budget: KNN}}
        self.token_knns = token_knns        # {model: KNN}
        self.model_prices = model_prices    # {model: price_per_million_output_tokens}
        self.model_names = model_names      # {short_name: full_name}
        self.budgets = budgets              # {budget_name: token_limit}
        self.lambda_val = lambda_val
        self.embed_url = embed_url          # vLLM server URL, e.g. "http://localhost:8000"
        self._embedder = None               # lazily-created local vLLM model

    @staticmethod
    def _parse_model_config(config: Dict):
        """Extract (model_prices, model_names) dicts from a loaded config.json.

        Shared by both alternate constructors to avoid duplicated parsing.
        """
        model_prices = {
            mn: cfg["output_price_per_million"]
            for mn, cfg in config["models"].items()
        }
        model_names = {
            mn: cfg["full_name"]
            for mn, cfg in config["models"].items()
        }
        return model_prices, model_names

    @classmethod
    def from_pretrained(
        cls,
        path: str,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Load pre-trained KNN checkpoints.

        Args:
            path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
            embed_url: vLLM server URL for embedding (e.g., "http://localhost:8000").
                If None, loads Qwen3-0.6B locally on first route_text() call.
        """
        if not os.path.isdir(path):
            path = cls._download_from_hf(path)
        with open(os.path.join(path, "config.json")) as f:
            config = json.load(f)
        ckpt_dir = os.path.join(path, "checkpoints")
        quality_knns = {}
        token_knns = {}
        for model_name in config["models"]:
            quality_knns[model_name] = {}
            for budget_name in config["budgets"]:
                ckpt_path = os.path.join(
                    ckpt_dir, f"quality_knn_{model_name}_{budget_name}.joblib"
                )
                # Checkpoints may be absent for sparse (model, budget) pairs;
                # skip rather than fail so partial repos still load.
                if os.path.exists(ckpt_path):
                    quality_knns[model_name][budget_name] = joblib.load(ckpt_path)
            tok_path = os.path.join(ckpt_dir, f"token_knn_{model_name}.joblib")
            if os.path.exists(tok_path):
                token_knns[model_name] = joblib.load(tok_path)
        model_prices, model_names = cls._parse_model_config(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @classmethod
    def from_training_data(
        cls,
        path: str,
        k: int = 80,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Train KNN from scratch using the provided training data.

        Args:
            path: Local directory or HuggingFace repo ID
            k: Number of KNN neighbors
            lambda_val: Cost-accuracy tradeoff
            embed_url: vLLM server URL for embedding; None = local embedder
                (new optional parameter, matching from_pretrained)
        """
        if not os.path.isdir(path):
            path = cls._download_from_hf(path)
        with open(os.path.join(path, "config.json")) as f:
            config = json.load(f)
        X_train = np.load(os.path.join(path, "training_data", "embeddings.npy"))
        with open(os.path.join(path, "training_data", "labels.json")) as f:
            labels = json.load(f)
        print(f"Training router on {len(X_train)} samples (k={k})...")
        quality_knns = {}
        token_knns = {}
        n_quality = 0
        n_token = 0
        for model_name, model_labels in labels.items():
            quality_knns[model_name] = {}
            for budget_name, bdata in model_labels.items():
                # Missing labels are stored as null in JSON; mask them out.
                acc = np.array([x if x is not None else np.nan for x in bdata["accuracy"]])
                valid = ~np.isnan(acc)
                if valid.sum() < 3:
                    # Too few samples to fit a meaningful regressor.
                    continue
                knn = KNeighborsRegressor(
                    n_neighbors=min(k, int(valid.sum()) - 1),
                    metric="cosine",
                    weights="distance",
                )
                knn.fit(X_train[valid], acc[valid])
                quality_knns[model_name][budget_name] = knn
                n_quality += 1
            # Token predictor is trained on the "concise" budget's token counts only.
            if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
                tok = np.array(
                    [x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]]
                )
                valid = ~np.isnan(tok)
                if valid.sum() >= 3:
                    tknn = KNeighborsRegressor(
                        n_neighbors=min(k, int(valid.sum()) - 1),
                        metric="cosine",
                        weights="distance",
                    )
                    tknn.fit(X_train[valid], tok[valid])
                    token_knns[model_name] = tknn
                    n_token += 1
        print(f"Trained {n_quality} quality predictors + {n_token} token predictors for {len(quality_knns)} models.")
        model_prices, model_names = cls._parse_model_config(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @staticmethod
    def _download_from_hf(repo_id: str) -> str:
        """Download model from Hugging Face Hub; returns the local snapshot path."""
        try:
            from huggingface_hub import snapshot_download
            return snapshot_download(repo_id)
        except ImportError as exc:
            raise ImportError(
                "huggingface_hub is required to download from HF. "
                "Install with: pip install huggingface_hub"
            ) from exc

    def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
        """
        Embed queries using Qwen3-0.6B.

        If embed_url is set, uses a remote vLLM server (OpenAI-compatible API).
        Otherwise, loads Qwen3-0.6B locally via vLLM (on first call).

        Args:
            queries: Single query string or list of queries
        Returns:
            numpy array of shape (N, 1024)
        """
        if isinstance(queries, str):
            queries = [queries]
        if self.embed_url:
            return self._embed_remote(queries)
        return self._embed_local(queries)

    def _embed_remote(self, queries: List[str]) -> np.ndarray:
        """Embed via a running vLLM server (OpenAI-compatible embeddings API)."""
        import urllib.request
        url = self.embed_url.rstrip("/") + "/v1/embeddings"
        payload = json.dumps({
            "model": "Qwen/Qwen3-0.6B",
            "input": queries,
        }).encode()
        req = urllib.request.Request(
            url, data=payload,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            result = json.loads(resp.read())
        # The API may return items out of order; restore input order by index.
        embeddings = [
            item["embedding"]
            for item in sorted(result["data"], key=lambda x: x["index"])
        ]
        return np.array(embeddings)

    def _embed_local(self, queries: List[str]) -> np.ndarray:
        """Embed by loading Qwen3-0.6B locally via vLLM (lazy, cached on self)."""
        if self._embedder is None:
            try:
                from vllm import LLM
            except ImportError:
                raise ImportError(
                    "vLLM is required for local embedding. "
                    "Install with: uv pip install vllm\n"
                    "Or start a vLLM server and pass embed_url to from_pretrained()."
                )
            self._embedder = LLM(
                model="Qwen/Qwen3-0.6B",
                runner="pooling",
                trust_remote_code=True,
                dtype="half",
            )
        outputs = self._embedder.embed(queries)
        return np.array([o.outputs.embedding for o in outputs])

    def route_text(
        self,
        query: Union[str, List[str]],
        lambda_val: Optional[float] = None,
    ) -> Union[Dict, List[Dict]]:
        """
        Route text query(ies) end-to-end: embed with Qwen3-0.6B, then route.

        Args:
            query: Single query string or list of queries
            lambda_val: Override default lambda
        Returns:
            Routing decision dict (single) or list of dicts (batch)
        """
        embeddings = self.embed(query)
        if isinstance(query, str):
            return self.route(embeddings[0], lambda_val)
        return self.route_batch(embeddings, lambda_val)

    def route(
        self,
        embedding: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> Dict:
        """
        Route a query to the optimal (model, token_budget) pair.

        Args:
            embedding: Query embedding vector, shape (1024,) or (1, 1024)
            lambda_val: Override default lambda (higher = more cost-sensitive)
        Returns:
            Dict with keys: model, model_full_name, budget, token_limit,
            predicted_quality, predicted_tokens, predicted_cost, risk,
            all_options
        Raises:
            RuntimeError: If no (model, budget) predictor is available.
        """
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)
        lam = lambda_val if lambda_val is not None else self.lambda_val
        all_options = []
        for mn in self.quality_knns:
            price = self.model_prices.get(mn, 0)
            if mn in self.token_knns:
                # Clamp to >= 1 token so cost is never zero/negative.
                tok = max(1.0, float(self.token_knns[mn].predict(embedding)[0]))
            else:
                # No token predictor for this model: fall back to a flat estimate.
                tok = 50.0
            for budget_name, knn in self.quality_knns[mn].items():
                q = float(knn.predict(embedding)[0])
                risk = (1 - lam) * q - lam * tok * price / 1e6
                all_options.append({
                    "model": mn,
                    "model_full_name": self.model_names.get(mn, mn),
                    "budget": budget_name,
                    "token_limit": self.budgets.get(budget_name, budget_name),
                    "predicted_quality": q,
                    "predicted_tokens": tok,
                    "predicted_cost": tok * price / 1e6,
                    "risk": risk,
                })
        if not all_options:
            raise RuntimeError("No valid routing options")
        best = max(all_options, key=lambda x: x["risk"])
        best["all_options"] = all_options
        return best

    def route_batch(
        self,
        embeddings: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> List[Dict]:
        """
        Route a batch of queries.

        Args:
            embeddings: Query embeddings, shape (N, 1024)
            lambda_val: Override default lambda
        Returns:
            List of routing decisions, one per row of embeddings
        """
        return [self.route(embeddings[i], lambda_val) for i in range(len(embeddings))]