import asyncio
import fnmatch
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

from app.core.key_manager import KeyManager
from app.models import APIKey, KeyStatus
from app.utils.redis_client import RedisClient


class Dispatcher:
    """Picks API keys from the pool and tracks their runtime counters in Redis.

    Counters used (all under the ``keypool:`` prefix):

    * ``keypool:concurrency:{id}``            - in-flight leases
    * ``keypool:usage:rpm:{id}:{minute}``     - requests per minute bucket
    * ``keypool:usage:tpm:{id}:{minute}``     - tokens per minute bucket
    * ``keypool:usage:rpd:{id}:{YYYYMMDD}``   - requests per UTC day
    * ``keypool:priority:429:{id}``           - demotion sequence after a 429

    RPM/TPM are estimated with a weighted sliding window:
    ``current_bucket + previous_bucket * prev_weight`` where ``prev_weight``
    shrinks linearly from 1 to 0 over the current minute.
    """

    # Daily RPD counters are retained for 90 days so history can be read back.
    DAILY_RPD_TTL_SECONDS = 90 * 86400
    # A minute bucket is read as the "previous minute" for the whole of the
    # following minute, so it must outlive two full minutes (plus slack).
    # A shorter TTL silently drops the previous bucket and under-counts usage.
    MINUTE_COUNTER_TTL_SECONDS = 125
    # Delay between selection attempts while waiting for a key to free up.
    _SELECT_POLL_INTERVAL_SECONDS = 0.5
    RATE_LIMIT_PENALTY_SEQ_KEY = "keypool:priority:429:seq"
    # Lower value = preferred. Providers not listed rank after all listed ones.
    PROVIDER_PRIORITIES = {
        "modelscope": 0,
        "nvidia": 0,
        "siliconflow": 0,
        "llama-local": 0,
    }

    def __init__(self, key_manager: KeyManager):
        """Store the key manager and grab the shared Redis client instance."""
        self.key_manager = key_manager
        self.redis = RedisClient.get_instance()

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_day_str(offset_days: int = 0) -> str:
        """Return the UTC day as ``YYYYMMDD``, shifted by ``offset_days``."""
        current_day = datetime.now(timezone.utc) + timedelta(days=offset_days)
        return current_day.strftime("%Y%m%d")

    @staticmethod
    def _get_rate_limit_penalty_key(key_id: str) -> str:
        """Return the Redis key holding the 429 demotion sequence of a key."""
        return f"keypool:priority:429:{key_id}"

    @staticmethod
    def _get_concurrency_key(key_id: str) -> str:
        """Return the Redis key holding the in-flight lease count of a key."""
        return f"keypool:concurrency:{key_id}"

    @staticmethod
    def _get_minute_usage_key(metric: str, key_id: str, minute: int) -> str:
        """Return the per-minute usage bucket key (``metric`` is rpm or tpm)."""
        return f"keypool:usage:{metric}:{key_id}:{minute}"

    @staticmethod
    def _get_rpd_key(key_id: str, day: str) -> str:
        """Return the per-day request counter key for ``day`` (``YYYYMMDD``)."""
        return f"keypool:usage:rpd:{key_id}:{day}"

    @staticmethod
    def _to_int(value: Any) -> int:
        """Convert a raw Redis reply to int; missing/empty replies count as 0."""
        return int(value) if value else 0

    @staticmethod
    def _prev_minute_weight(now: float) -> float:
        """Return the sliding-window weight of the previous minute bucket.

        At second 0 of the minute the weight is 1.0; it falls linearly
        towards 0 as the current minute elapses.
        """
        seconds_elapsed = now % 60
        return (60 - seconds_elapsed) / 60

    @classmethod
    def _get_provider_priority(cls, provider: str) -> int:
        """Return the static rank of ``provider`` (unknown providers rank last)."""
        return cls.PROVIDER_PRIORITIES.get(provider.lower(), len(cls.PROVIDER_PRIORITIES))

    @staticmethod
    def _calculate_load_score(key: APIKey) -> float:
        """Return the highest utilisation ratio across all limits of ``key``.

        A limit that is not positive contributes 0 to the score.
        """
        conc_ratio = key.current_concurrency / key.max_concurrency if key.max_concurrency > 0 else 0
        rpm_ratio = key.current_rpm / key.rpm_limit if key.rpm_limit > 0 else 0
        tpm_ratio = key.current_tpm / key.tpm_limit if key.tpm_limit > 0 else 0
        rpd_ratio = key.current_rpd / key.rpd_limit if key.rpd_limit > 0 else 0
        return max(conc_ratio, rpm_ratio, tpm_ratio, rpd_ratio)

    def _selection_sort_key(self, key: APIKey) -> tuple[int, int, int, float]:
        """Return the ordering tuple used to pick the best key (smallest wins).

        Keys without a 429 penalty sort first, by provider rank then load.
        Penalised keys sort after them, in the order they were demoted.
        """
        provider_priority = self._get_provider_priority(key.provider)
        load = self._calculate_load_score(key)
        if key.rate_limit_penalty > 0:
            return (1, key.rate_limit_penalty, provider_priority, load)
        return (0, provider_priority, 0, load)

    @staticmethod
    def _get_rejection_reason(key: APIKey, force_select: bool) -> Optional[str]:
        """Return why ``key`` cannot be used right now, or None if it can.

        Reads the ``current_*`` runtime attributes already stored on ``key``.
        NOTE(review): a zero ``rpm_limit``/``tpm_limit``/``max_concurrency``
        rejects the key, whereas ``rpd_limit == 0`` disables the RPD check -
        confirm this asymmetry is intended.
        """
        label = f"{key.id}({key.owner})"
        # An explicit model request (force_select) bypasses the status check.
        if key.status != KeyStatus.ACTIVE and not force_select:
            return f"{label}: Inactive (status={key.status})"
        if key.current_concurrency >= key.max_concurrency:
            return f"{label}: Concurrency limit ({key.current_concurrency}/{key.max_concurrency})"
        if key.current_rpm >= key.rpm_limit:
            return f"{label}: RPM limit ({key.current_rpm}/{key.rpm_limit})"
        if key.current_tpm >= key.tpm_limit:
            return f"{label}: TPM limit ({key.current_tpm}/{key.tpm_limit})"
        if key.rpd_limit > 0 and key.current_rpd >= key.rpd_limit:
            return f"{label}: RPD limit ({key.current_rpd}/{key.rpd_limit})"
        return None

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    async def _delete_matching_keys(self, pattern: str, batch_size: int = 200) -> int:
        """Delete every Redis key matching ``pattern`` and return the count.

        Uses ``scan_iter`` when the client offers it, deleting in batches of
        ``batch_size``. Otherwise falls back to a client exposing a ``values``
        attribute that is iterated for key names (presumably an in-memory
        stand-in - verify against RedisClient).

        Raises:
            RuntimeError: if the client supports neither access path.
        """
        if hasattr(self.redis, "scan_iter"):
            deleted = 0
            batch = []
            async for key in self.redis.scan_iter(match=pattern, count=batch_size):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += await self.redis.delete(*batch)
                    batch.clear()
            if batch:
                deleted += await self.redis.delete(*batch)
            return deleted

        if not hasattr(self.redis, "values"):
            raise RuntimeError("Redis client does not support scan_iter")

        # fnmatchcase keeps matching case-sensitive on every platform, like
        # Redis MATCH; plain fnmatch() normalises case on some systems.
        keys = [key for key in self.redis.values if fnmatch.fnmatchcase(key, pattern)]
        return await self.redis.delete(*keys) if keys else 0

    async def reset_today_rpd_usage(self) -> int:
        """Delete today's (UTC) RPD counters for all keys; return how many."""
        day = self._get_day_str()
        return await self._delete_matching_keys(f"keypool:usage:rpd:*:{day}")

    # ------------------------------------------------------------------
    # Selection
    # ------------------------------------------------------------------

    async def get_key_stats(self, key: APIKey) -> APIKey:
        """Fetch runtime stats from Redis for a single key.

        Populates ``current_concurrency``, ``rate_limit_penalty`` and
        ``current_rpd`` on ``key`` and returns the same object.
        """
        # One pipeline round-trip for all three counters.
        pipe = self.redis.pipeline()
        pipe.get(self._get_concurrency_key(key.id))
        pipe.get(self._get_rate_limit_penalty_key(key.id))
        pipe.get(self._get_rpd_key(key.id, self._get_day_str()))
        results = await pipe.execute()

        key.current_concurrency = self._to_int(results[0])
        key.rate_limit_penalty = self._to_int(results[1])
        key.current_rpd = self._to_int(results[2])
        return key

    async def select_key(self, model_name: Optional[str] = None, timeout: float = 30.0) -> tuple[Optional[APIKey], str]:
        """Select the best available key, waiting up to ``timeout`` seconds.

        If ``model_name`` is None, keys for all models are considered. When no
        key is available immediately, selection is retried by simple polling
        until the timeout elapses. At least one attempt is always made.

        When ``model_name`` is given, the status check is bypassed
        (``force_select``), so non-ACTIVE keys for that model may be returned.

        Returns:
            ``(key, "")`` on success, or ``(None, reason)`` on timeout.
        """
        # An explicitly requested model bypasses the status check.
        force_select = bool(model_name)
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + timeout

        while True:
            key, error = await self._try_select_key_once(model_name, force_select=force_select)
            if key:
                return key, ""

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None, f"Timeout waiting for available key: {error}"

            # Never sleep past the deadline.
            await asyncio.sleep(min(self._SELECT_POLL_INTERVAL_SECONDS, remaining))

    async def _try_select_key_once(self, model_name: Optional[str] = None, force_select: bool = False) -> tuple[Optional[APIKey], str]:
        """Try to select a key once, without waiting.

        Strategy:
        1. Filter by status and hard limits (concurrency, RPM, TPM, RPD).
        2. Prefer keys that never hit 429.
        3. Among those, prefer providers with a lower PROVIDER_PRIORITIES
           rank; providers not listed there rank last.
        4. For keys that did hit 429, use the demotion order as a queue.
        5. Use load as the final tie breaker inside the same priority tier.

        As a side effect, the ``current_*`` and ``rate_limit_penalty``
        attributes of every candidate are refreshed from Redis.

        Args:
            model_name: Optional model name to filter candidates by.
            force_select: If True, bypass the status check.

        Returns:
            ``(key, "")`` on success, or ``(None, reason)``.
        """
        candidates = self.key_manager.get_candidate_keys(model_name)
        if not candidates:
            reason = "No keys configured for this model" if model_name else "No keys configured in the pool"
            return None, reason

        now = time.time()
        current_minute = int(now / 60)
        prev_minute = current_minute - 1
        current_day = self._get_day_str()
        key_ids = [k.id for k in candidates]

        # Column-wise layout: n concurrency keys, then n penalty keys, etc.
        all_keys = (
            [self._get_concurrency_key(kid) for kid in key_ids]
            + [self._get_rate_limit_penalty_key(kid) for kid in key_ids]
            + [self._get_minute_usage_key("rpm", kid, current_minute) for kid in key_ids]
            + [self._get_minute_usage_key("rpm", kid, prev_minute) for kid in key_ids]
            + [self._get_minute_usage_key("tpm", kid, current_minute) for kid in key_ids]
            + [self._get_minute_usage_key("tpm", kid, prev_minute) for kid in key_ids]
            + [self._get_rpd_key(kid, current_day) for kid in key_ids]
        )
        values = await self.redis.mget(all_keys)

        n = len(candidates)
        (
            concurrency_values,
            penalty_values,
            rpm_curr_values,
            rpm_prev_values,
            tpm_curr_values,
            tpm_prev_values,
            rpd_values,
        ) = (values[col * n:(col + 1) * n] for col in range(7))

        prev_weight = self._prev_minute_weight(now)
        to_int = self._to_int

        available_keys = []
        rejection_reasons = []
        for i, key in enumerate(candidates):
            key.current_concurrency = to_int(concurrency_values[i])
            key.rate_limit_penalty = to_int(penalty_values[i])
            key.current_rpm = int(to_int(rpm_curr_values[i]) + to_int(rpm_prev_values[i]) * prev_weight)
            key.current_tpm = int(to_int(tpm_curr_values[i]) + to_int(tpm_prev_values[i]) * prev_weight)
            key.current_rpd = to_int(rpd_values[i])

            rejection = self._get_rejection_reason(key, force_select)
            if rejection is None:
                available_keys.append(key)
            else:
                rejection_reasons.append(rejection)

        if not available_keys:
            reason = "; ".join(rejection_reasons) if rejection_reasons else "No available keys found"
            return None, f"All keys unavailable: {reason}"

        best_key = min(available_keys, key=self._selection_sort_key)
        return best_key, ""

    # ------------------------------------------------------------------
    # Lease lifecycle
    # ------------------------------------------------------------------

    async def acquire_lease(self, key: APIKey) -> bool:
        """Increment concurrency, RPM and RPD for ``key`` and re-check limits.

        The counters are incremented first and the limits verified against
        the post-increment values, to guard against races between selection
        and acquisition. If any limit is exceeded, the three increments are
        rolled back.

        Returns:
            True if the lease was acquired, False if a limit was exceeded.
        """
        now = time.time()
        current_minute = int(now / 60)
        concurrency_key = self._get_concurrency_key(key.id)
        rpm_key = self._get_minute_usage_key("rpm", key.id, current_minute)
        rpm_prev_key = self._get_minute_usage_key("rpm", key.id, current_minute - 1)
        rpd_key = self._get_rpd_key(key.id, self._get_day_str())

        pipe = self.redis.pipeline()
        pipe.incr(concurrency_key)                                # results[0]
        pipe.incr(rpm_key)                                        # results[1]
        # Bucket must survive the whole next minute (see constant).
        pipe.expire(rpm_key, self.MINUTE_COUNTER_TTL_SECONDS)     # results[2]
        pipe.get(rpm_prev_key)                                    # results[3]
        pipe.incr(rpd_key)                                        # results[4]
        # Keep daily counters long enough for dashboards/history to read
        # past days (DAILY_RPD_TTL_SECONDS).
        pipe.expire(rpd_key, self.DAILY_RPD_TTL_SECONDS)          # results[5]
        results = await pipe.execute()

        curr_conc = results[0]
        curr_rpm = results[1]
        prev_rpm = self._to_int(results[3])
        curr_rpd = results[4]
        estimated_rpm = curr_rpm + prev_rpm * self._prev_minute_weight(now)

        over_limit = (
            curr_conc > key.max_concurrency
            or estimated_rpm > key.rpm_limit
            or (key.rpd_limit > 0 and curr_rpd > key.rpd_limit)
        )
        if not over_limit:
            return True

        # Undo the optimistic increments made above.
        for counter_key in (concurrency_key, rpm_key, rpd_key):
            await self.redis.decr(counter_key)
        return False

    async def release_lease(self, key: APIKey):
        """Decrement the concurrency counter of ``key``."""
        await self.redis.decr(self._get_concurrency_key(key.id))

    async def record_usage(self, key: APIKey, tokens: int):
        """Record token usage (TPM) for ``key``; non-positive counts are ignored."""
        if tokens <= 0:
            return

        current_minute = int(time.time() / 60)
        tpm_key = self._get_minute_usage_key("tpm", key.id, current_minute)

        pipe = self.redis.pipeline()
        pipe.incrby(tpm_key, tokens)
        # Bucket must survive the whole next minute (see constant).
        pipe.expire(tpm_key, self.MINUTE_COUNTER_TTL_SECONDS)
        await pipe.execute()

    async def report_failure(self, key: APIKey, is_rate_limit: bool = False, error_message: str = ""):
        """Report a failure for ``key``.

        A 429 (``is_rate_limit=True``) moves the key to the back of the
        priority queue by stamping it with the next global sequence number.
        Other failures are currently ignored.
        """
        # Accepted for interface compatibility; not used at present.
        del error_message
        if is_rate_limit:
            seq = await self.redis.incr(self.RATE_LIMIT_PENALTY_SEQ_KEY)
            await self.redis.set(self._get_rate_limit_penalty_key(key.id), seq)

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    async def get_all_key_stats(self) -> List[Dict[str, Any]]:
        """Return per-key statistics: status, limits and current usage.

        A key whose stored status is ACTIVE but which carries a 429 penalty
        is reported with status COOLDOWN. ``cooldown_remaining`` is always 0.
        """
        keys = self.key_manager.get_all_keys()
        if not keys:
            return []

        now = time.time()
        current_minute = int(now / 60)
        prev_minute = current_minute - 1
        current_day = self._get_day_str()

        # Row-wise layout: 7 consecutive Redis keys per API key.
        fields_per_key = 7
        redis_keys = []
        for key in keys:
            redis_keys.extend([
                self._get_concurrency_key(key.id),
                self._get_rate_limit_penalty_key(key.id),
                self._get_minute_usage_key("rpm", key.id, current_minute),
                self._get_minute_usage_key("rpm", key.id, prev_minute),
                self._get_minute_usage_key("tpm", key.id, current_minute),
                self._get_minute_usage_key("tpm", key.id, prev_minute),
                self._get_rpd_key(key.id, current_day),
            ])

        # Fetch all values in a single round-trip.
        values = await self.redis.mget(redis_keys)
        prev_weight = self._prev_minute_weight(now)

        stats = []
        for i, key in enumerate(keys):
            base_idx = i * fields_per_key
            (
                curr_conc,
                rate_limit_penalty,
                rpm_curr,
                rpm_prev,
                tpm_curr,
                tpm_prev,
                rpd_curr,
            ) = (self._to_int(values[base_idx + offset]) for offset in range(fields_per_key))

            estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
            estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)

            # Effective status: a penalised ACTIVE key is shown as COOLDOWN.
            status = key.status
            if key.status == KeyStatus.ACTIVE and rate_limit_penalty > 0:
                status = KeyStatus.COOLDOWN

            stats.append({
                "id": key.id,
                "model_name": key.model_name,
                "provider": key.provider,
                "status": status,
                "enabled": key.enabled,
                "config_id": getattr(key, 'config_id', None),
                "owner": getattr(key, 'owner', None),
                "endpoint_idx": getattr(key, 'endpoint_idx', None),
                "api_base": getattr(key, 'api_base', None),
                "limits": {
                    "rpm": key.rpm_limit,
                    "tpm": key.tpm_limit,
                    "rpd": key.rpd_limit,
                    "max_concurrency": key.max_concurrency
                },
                "usage": {
                    "current_concurrency": curr_conc,
                    "current_rpm": estimated_rpm,
                    "current_tpm": estimated_tpm,
                    "current_rpd": rpd_curr
                },
                "cooldown_remaining": 0
            })

        return stats

    async def get_usage_summary(self) -> Dict[str, Any]:
        """Return usage totals across all keys.

        Includes total RPM, TPM, RPD (today and yesterday, UTC), concurrency,
        and counts of active / cooldown / disabled keys. All values are 0
        when no keys are configured.
        """
        summary: Dict[str, Any] = {
            "total_rpm": 0,
            "total_tpm": 0,
            "total_rpd": 0,
            "total_yesterday_rpd": 0,
            "total_concurrency": 0,
            "active_keys": 0,
            "cooldown_keys": 0,
            "disabled_keys": 0,
            "total_keys": 0
        }

        keys = self.key_manager.get_all_keys()
        if not keys:
            return summary

        now = time.time()
        current_minute = int(now / 60)
        prev_minute = current_minute - 1
        current_day = self._get_day_str()
        yesterday_day = self._get_day_str(-1)

        # Row-wise layout: 8 consecutive Redis keys per API key.
        fields_per_key = 8
        redis_keys = []
        for key in keys:
            redis_keys.extend([
                self._get_concurrency_key(key.id),
                self._get_rate_limit_penalty_key(key.id),
                self._get_minute_usage_key("rpm", key.id, current_minute),
                self._get_minute_usage_key("rpm", key.id, prev_minute),
                self._get_minute_usage_key("tpm", key.id, current_minute),
                self._get_minute_usage_key("tpm", key.id, prev_minute),
                self._get_rpd_key(key.id, current_day),
                self._get_rpd_key(key.id, yesterday_day),
            ])

        values = await self.redis.mget(redis_keys)
        prev_weight = self._prev_minute_weight(now)

        for i, key in enumerate(keys):
            base_idx = i * fields_per_key
            (
                curr_conc,
                rate_limit_penalty,
                rpm_curr,
                rpm_prev,
                tpm_curr,
                tpm_prev,
                rpd_curr,
                rpd_yesterday,
            ) = (self._to_int(values[base_idx + offset]) for offset in range(fields_per_key))

            summary["total_rpm"] += int(rpm_curr + rpm_prev * prev_weight)
            summary["total_tpm"] += int(tpm_curr + tpm_prev * prev_weight)
            summary["total_rpd"] += rpd_curr
            summary["total_yesterday_rpd"] += rpd_yesterday
            summary["total_concurrency"] += curr_conc

            if key.status == KeyStatus.ACTIVE and rate_limit_penalty == 0:
                summary["active_keys"] += 1
            elif key.status == KeyStatus.DISABLED:
                summary["disabled_keys"] += 1
            elif rate_limit_penalty > 0:
                summary["cooldown_keys"] += 1
            else:
                summary["active_keys"] += 1

        summary["total_keys"] = len(keys)
        return summary

    async def get_daily_usage_history(self, days: int = 7) -> List[Dict[str, Any]]:
        """Aggregate total RPD usage for the most recent ``days`` days.

        Uses the same UTC day buckets as runtime RPD limiting. Entries are
        ordered oldest first, ending with today. Returns an empty list when
        there are no keys or ``days`` is not positive.
        """
        keys = self.key_manager.get_all_keys()
        if not keys or days <= 0:
            return []

        day_labels = [self._get_day_str(-offset) for offset in range(days - 1, -1, -1)]
        # Day-major layout: all keys for day 0, then all keys for day 1, ...
        redis_keys = [
            self._get_rpd_key(key.id, day_str)
            for day_str in day_labels
            for key in keys
        ]
        values = await self.redis.mget(redis_keys)

        key_count = len(keys)
        history = []
        for day_index, day_str in enumerate(day_labels):
            base_idx = day_index * key_count
            day_total = sum(
                self._to_int(value)
                for value in values[base_idx:base_idx + key_count]
            )
            history.append({
                "day": day_str,
                "label": f"{day_str[4:6]}-{day_str[6:8]}",
                "total_rpd": day_total,
            })

        return history