# autool-test/app/core/dispatcher.py
# 2026-06-17 11:13:11 +08:00
#
# 557 lines
# 22 KiB
# Python
# Executable File

import time
import asyncio
import fnmatch
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any
from app.models import APIKey, KeyStatus
from app.core.key_manager import KeyManager
from app.utils.redis_client import RedisClient
class Dispatcher:
    """Selects API keys from the pool and tracks their usage counters in Redis."""

    # TTL, in seconds, applied to each per-day RPD counter key (90 days).
    DAILY_RPD_TTL_SECONDS = 90 * 86400
    # Redis counter incremented whenever a 429 is reported; the new value is
    # stored as the penalty of the key that hit the rate limit.
    RATE_LIMIT_PENALTY_SEQ_KEY = "keypool:priority:429:seq"
    # Static provider ranks used when ordering keys (lower sorts first).
    # All listed providers currently share rank 0; a provider that is not
    # listed falls back to len(PROVIDER_PRIORITIES).
    PROVIDER_PRIORITIES = {
        "modelscope": 0,
        "nvidia": 0,
        "siliconflow": 0,
        "llama-local": 0,
    }
def __init__(self, key_manager: KeyManager):
    """Keep a reference to the key manager and attach the shared Redis client."""
    self.key_manager = key_manager
    shared_client = RedisClient.get_instance()
    self.redis = shared_client
@staticmethod
def _get_day_str(offset_days: int = 0) -> str:
current_day = datetime.now(timezone.utc) + timedelta(days=offset_days)
return current_day.strftime("%Y%m%d")
@staticmethod
def _get_rate_limit_penalty_key(key_id: str) -> str:
return f"keypool:priority:429:{key_id}"
@classmethod
def _get_provider_priority(cls, provider: str) -> int:
return cls.PROVIDER_PRIORITIES.get(provider.lower(), len(cls.PROVIDER_PRIORITIES))
async def _delete_matching_keys(self, pattern: str, batch_size: int = 200) -> int:
if hasattr(self.redis, "scan_iter"):
batch = []
deleted = 0
async for key in self.redis.scan_iter(match=pattern, count=batch_size):
batch.append(key)
if len(batch) >= batch_size:
deleted += await self.redis.delete(*batch)
batch.clear()
if batch:
deleted += await self.redis.delete(*batch)
return deleted
if not hasattr(self.redis, "values"):
raise RuntimeError("Redis client does not support scan_iter")
keys = [key for key in self.redis.values if fnmatch.fnmatch(key, pattern)]
return await self.redis.delete(*keys) if keys else 0
async def reset_today_rpd_usage(self) -> int:
day = self._get_day_str()
return await self._delete_matching_keys(f"keypool:usage:rpd:*:{day}")
async def get_key_stats(self, key: APIKey):
"""Fetches runtime stats from Redis for a single key."""
concurrency_key = f"keypool:concurrency:{key.id}"
penalty_key = self._get_rate_limit_penalty_key(key.id)
current_day = self._get_day_str()
rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
# Pipeline for efficiency
pipe = self.redis.pipeline()
pipe.get(concurrency_key)
pipe.get(penalty_key)
pipe.get(rpd_key)
results = await pipe.execute()
current_concurrency = int(results[0]) if results[0] else 0
rate_limit_penalty = int(results[1]) if results[1] else 0
current_rpd = int(results[2]) if results[2] else 0
key.current_concurrency = current_concurrency
key.rate_limit_penalty = rate_limit_penalty
key.current_rpd = current_rpd
return key
async def select_key(self, model_name: Optional[str] = None, timeout: float = 30.0) -> tuple[Optional[APIKey], str]:
"""
Selects the best available key for the given model.
If model_name is None, selects from ALL available keys.
If no keys are available immediately, it will wait (queue) up to `timeout` seconds.
Strategy:
1. Try to find an available key immediately.
2. If none found, wait for a short interval and retry (polling/queue simulation).
In a production system, we might use a real Redis list/pubsub for queuing,
but for now, we'll use a local async loop with exponential backoff or simple polling.
Returns: (key, error_reason)
Note: If model_name is explicitly specified (force_select=True), the status check
for DISABLED keys is bypassed, allowing access to disabled models.
"""
start_time = time.time()
# If model_name is explicitly specified, force select (bypass enabled check)
force_select = bool(model_name)
while True:
# 1. Try to select a key
key, error = await self._try_select_key_once(model_name, force_select=force_select)
if key:
return key, ""
# 2. Check timeout
if time.time() - start_time > timeout:
return None, f"Timeout waiting for available key: {error}"
# 3. Wait before retrying (simple polling).
await asyncio.sleep(0.5)
async def _try_select_key_once(self, model_name: Optional[str] = None, force_select: bool = False) -> tuple[Optional[APIKey], str]:
"""
Internal method to try selecting a key once without waiting.
Strategy:
1. Filter by status and hard limits.
2. Prefer keys that never hit 429.
3. Prefer providers by static rank: modelscope > nvidia > siliconflow.
4. For keys that did hit 429, use the demotion order as a queue.
5. Use load as the final tie breaker inside the same priority tier.
Args:
model_name: Optional model name to filter by
force_select: If True, bypass the status check (allows accessing disabled models)
Returns: (key, error_reason)
"""
candidates = self.key_manager.get_candidate_keys(model_name)
if not candidates:
return None, "No keys configured for this model" if model_name else "No keys configured in the pool"
# Prepare keys for mget
concurrency_keys = [f"keypool:concurrency:{k.id}" for k in candidates]
penalty_keys = [self._get_rate_limit_penalty_key(k.id) for k in candidates]
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
rpm_curr_keys = [f"keypool:usage:rpm:{k.id}:{current_minute}" for k in candidates]
rpm_prev_keys = [f"keypool:usage:rpm:{k.id}:{prev_minute}" for k in candidates]
tpm_curr_keys = [f"keypool:usage:tpm:{k.id}:{current_minute}" for k in candidates]
tpm_prev_keys = [f"keypool:usage:tpm:{k.id}:{prev_minute}" for k in candidates]
current_day = self._get_day_str()
rpd_keys = [f"keypool:usage:rpd:{k.id}:{current_day}" for k in candidates]
all_keys = concurrency_keys + penalty_keys + rpm_curr_keys + rpm_prev_keys + tpm_curr_keys + tpm_prev_keys + rpd_keys
if not all_keys:
return None, "No keys available (empty stats)"
values = await self.redis.mget(all_keys)
n = len(candidates)
concurrency_values = values[:n]
penalty_values = values[n:2*n]
rpm_curr_values = values[2*n:3*n]
rpm_prev_values = values[3*n:4*n]
tpm_curr_values = values[4*n:5*n]
tpm_prev_values = values[5*n:6*n]
rpd_values = values[6*n:7*n]
# Calculate window weight
# If we are at second 0 of the minute, weight of prev is 1.0 (actually 0.99...)
# If we are at second 59, weight of prev is ~0
# Formula: weight = (60 - seconds_elapsed) / 60
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
available_keys = []
rejection_reasons = []
for i, key in enumerate(candidates):
current_conc = int(concurrency_values[i]) if concurrency_values[i] else 0
rate_limit_penalty = int(penalty_values[i]) if penalty_values[i] else 0
r_curr = int(rpm_curr_values[i]) if rpm_curr_values[i] else 0
r_prev = int(rpm_prev_values[i]) if rpm_prev_values[i] else 0
t_curr = int(tpm_curr_values[i]) if tpm_curr_values[i] else 0
t_prev = int(tpm_prev_values[i]) if tpm_prev_values[i] else 0
current_rpd = int(rpd_values[i]) if rpd_values[i] else 0
current_rpm = int(r_curr + r_prev * prev_weight)
current_tpm = int(t_curr + t_prev * prev_weight)
key.current_concurrency = current_conc
key.rate_limit_penalty = rate_limit_penalty
key.current_rpm = current_rpm
key.current_tpm = current_tpm
key.current_rpd = current_rpd
if key.status != KeyStatus.ACTIVE:
# If force_select is True (explicit model specified), bypass status check
if not force_select:
rejection_reasons.append(f"{key.id}({key.owner}): Inactive (status={key.status})")
continue
if key.current_concurrency >= key.max_concurrency:
rejection_reasons.append(f"{key.id}({key.owner}): Concurrency limit ({key.current_concurrency}/{key.max_concurrency})")
continue
if current_rpm >= key.rpm_limit:
rejection_reasons.append(f"{key.id}({key.owner}): RPM limit ({current_rpm}/{key.rpm_limit})")
continue
if current_tpm >= key.tpm_limit:
rejection_reasons.append(f"{key.id}({key.owner}): TPM limit ({current_tpm}/{key.tpm_limit})")
continue
if key.rpd_limit > 0 and current_rpd >= key.rpd_limit:
rejection_reasons.append(f"{key.id}({key.owner}): RPD limit ({current_rpd}/{key.rpd_limit})")
continue
available_keys.append(key)
if not available_keys:
reason = "; ".join(rejection_reasons) if rejection_reasons else "No available keys found"
return None, f"All keys unavailable: {reason}"
def calculate_load_score(k: APIKey) -> float:
conc_ratio = k.current_concurrency / k.max_concurrency if k.max_concurrency > 0 else 0
rpm_ratio = k.current_rpm / k.rpm_limit if k.rpm_limit > 0 else 0
tpm_ratio = k.current_tpm / k.tpm_limit if k.tpm_limit > 0 else 0
rpd_ratio = k.current_rpd / k.rpd_limit if k.rpd_limit > 0 else 0
return max(conc_ratio, rpm_ratio, tpm_ratio, rpd_ratio)
def selection_key(k: APIKey) -> tuple[int, int, int, float]:
provider_priority = self._get_provider_priority(k.provider)
if k.rate_limit_penalty > 0:
return (1, k.rate_limit_penalty, provider_priority, calculate_load_score(k))
return (0, provider_priority, 0, calculate_load_score(k))
best_key = min(available_keys, key=selection_key)
return best_key, ""
async def acquire_lease(self, key: APIKey) -> bool:
"""
Increments concurrency, RPM and RPD for the key.
Double check limit in Redis to be safe (race conditions).
"""
key_concurrency_key = f"keypool:concurrency:{key.id}"
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
current_day = self._get_day_str()
key_rpm_key = f"keypool:usage:rpm:{key.id}:{current_minute}"
key_rpm_prev_key = f"keypool:usage:rpm:{key.id}:{prev_minute}"
key_rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
pipe = self.redis.pipeline()
pipe.incr(key_concurrency_key)
pipe.incr(key_rpm_key)
pipe.expire(key_rpm_key, 65)
pipe.get(key_rpm_prev_key)
pipe.incr(key_rpd_key)
# Keep daily counters for two days so the dashboard can read yesterday's RPD.
pipe.expire(key_rpd_key, self.DAILY_RPD_TTL_SECONDS)
results = await pipe.execute()
curr_conc = results[0]
curr_rpm_val = results[1]
prev_rpm_val = int(results[3]) if results[3] else 0
curr_rpd_val = results[4]
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
estimated_rpm = curr_rpm_val + (prev_rpm_val * prev_weight)
if curr_conc > key.max_concurrency:
await self.redis.decr(key_concurrency_key)
await self.redis.decr(key_rpm_key)
await self.redis.decr(key_rpd_key)
return False
if estimated_rpm > key.rpm_limit:
await self.redis.decr(key_concurrency_key)
await self.redis.decr(key_rpm_key)
await self.redis.decr(key_rpd_key)
return False
if key.rpd_limit > 0 and curr_rpd_val > key.rpd_limit:
await self.redis.decr(key_concurrency_key)
await self.redis.decr(key_rpm_key)
await self.redis.decr(key_rpd_key)
return False
return True
async def release_lease(self, key: APIKey):
"""Decrements concurrency."""
key_concurrency_key = f"keypool:concurrency:{key.id}"
await self.redis.decr(key_concurrency_key)
async def record_usage(self, key: APIKey, tokens: int):
"""Records Token usage (TPM)."""
if tokens <= 0:
return
current_minute = int(time.time() / 60)
key_tpm_key = f"keypool:usage:tpm:{key.id}:{current_minute}"
pipe = self.redis.pipeline()
pipe.incrby(key_tpm_key, tokens)
pipe.expire(key_tpm_key, 65)
await pipe.execute()
async def report_failure(self, key: APIKey, is_rate_limit: bool = False, error_message: str = ""):
"""
Reports a failure. A 429 moves the key to the back of the priority queue.
"""
del error_message
if is_rate_limit:
seq = await self.redis.incr(self.RATE_LIMIT_PENALTY_SEQ_KEY)
await self.redis.set(self._get_rate_limit_penalty_key(key.id), seq)
async def get_all_key_stats(self) -> List[Dict[str, Any]]:
"""
Retrieves statistics for all keys, including status, concurrency, and usage limits.
"""
keys = self.key_manager.get_all_keys()
if not keys:
return []
# Prepare keys for batch fetching from Redis
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
current_day = self._get_day_str()
redis_keys = []
for key in keys:
redis_keys.extend([
f"keypool:concurrency:{key.id}",
self._get_rate_limit_penalty_key(key.id),
f"keypool:usage:rpm:{key.id}:{current_minute}",
f"keypool:usage:rpm:{key.id}:{prev_minute}",
f"keypool:usage:tpm:{key.id}:{current_minute}",
f"keypool:usage:tpm:{key.id}:{prev_minute}",
f"keypool:usage:rpd:{key.id}:{current_day}"
])
if not redis_keys:
return []
# Fetch all values in a single round-trip
values = await self.redis.mget(redis_keys)
stats = []
# Process results in chunks of 7
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
for i, key in enumerate(keys):
base_idx = i * 7
curr_conc = int(values[base_idx]) if values[base_idx] else 0
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
# Determine effective status
status = key.status
if key.status == KeyStatus.ACTIVE and rate_limit_penalty > 0:
status = KeyStatus.COOLDOWN
key_stat = {
"id": key.id,
"model_name": key.model_name,
"provider": key.provider,
"status": status,
"enabled": key.enabled,
"config_id": getattr(key, 'config_id', None),
"owner": getattr(key, 'owner', None),
"endpoint_idx": getattr(key, 'endpoint_idx', None),
"api_base": getattr(key, 'api_base', None),
"limits": {
"rpm": key.rpm_limit,
"tpm": key.tpm_limit,
"rpd": key.rpd_limit,
"max_concurrency": key.max_concurrency
},
"usage": {
"current_concurrency": curr_conc,
"current_rpm": estimated_rpm,
"current_tpm": estimated_tpm,
"current_rpd": rpd_curr
},
"cooldown_remaining": 0
}
stats.append(key_stat)
return stats
async def get_usage_summary(self) -> Dict[str, Any]:
"""
Get overall usage summary across all keys.
Returns total RPM, TPM, RPD, concurrency, and active keys count.
"""
keys = self.key_manager.get_all_keys()
if not keys:
return {
"total_rpm": 0,
"total_tpm": 0,
"total_rpd": 0,
"total_yesterday_rpd": 0,
"total_concurrency": 0,
"active_keys": 0,
"cooldown_keys": 0,
"disabled_keys": 0,
"total_keys": 0
}
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
current_day = self._get_day_str()
yesterday_day = self._get_day_str(-1)
redis_keys = []
for key in keys:
redis_keys.extend([
f"keypool:concurrency:{key.id}",
self._get_rate_limit_penalty_key(key.id),
f"keypool:usage:rpm:{key.id}:{current_minute}",
f"keypool:usage:rpm:{key.id}:{prev_minute}",
f"keypool:usage:tpm:{key.id}:{current_minute}",
f"keypool:usage:tpm:{key.id}:{prev_minute}",
f"keypool:usage:rpd:{key.id}:{current_day}",
f"keypool:usage:rpd:{key.id}:{yesterday_day}"
])
if not redis_keys:
return {
"total_rpm": 0,
"total_tpm": 0,
"total_rpd": 0,
"total_yesterday_rpd": 0,
"total_concurrency": 0,
"active_keys": 0,
"cooldown_keys": 0,
"disabled_keys": 0,
"total_keys": 0
}
values = await self.redis.mget(redis_keys)
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
total_rpm = 0
total_tpm = 0
total_rpd = 0
total_yesterday_rpd = 0
total_concurrency = 0
active_keys = 0
cooldown_keys = 0
disabled_keys = 0
for i, key in enumerate(keys):
base_idx = i * 8
curr_conc = int(values[base_idx]) if values[base_idx] else 0
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
rpd_yesterday = int(values[base_idx + 7]) if values[base_idx + 7] else 0
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
total_rpm += estimated_rpm
total_tpm += estimated_tpm
total_rpd += rpd_curr
total_yesterday_rpd += rpd_yesterday
total_concurrency += curr_conc
if key.status == KeyStatus.ACTIVE and rate_limit_penalty == 0:
active_keys += 1
elif key.status == KeyStatus.DISABLED:
disabled_keys += 1
elif rate_limit_penalty > 0:
cooldown_keys += 1
else:
active_keys += 1
return {
"total_rpm": total_rpm,
"total_tpm": total_tpm,
"total_rpd": total_rpd,
"total_yesterday_rpd": total_yesterday_rpd,
"total_concurrency": total_concurrency,
"active_keys": active_keys,
"cooldown_keys": cooldown_keys,
"disabled_keys": disabled_keys,
"total_keys": len(keys)
}
async def get_daily_usage_history(self, days: int = 7) -> List[Dict[str, Any]]:
"""
Aggregate total RPD usage for the most recent N days.
Uses the same UTC day buckets as runtime RPD limiting.
"""
keys = self.key_manager.get_all_keys()
if not keys or days <= 0:
return []
day_labels = [self._get_day_str(-offset) for offset in range(days - 1, -1, -1)]
redis_keys = []
for day_str in day_labels:
for key in keys:
redis_keys.append(f"keypool:usage:rpd:{key.id}:{day_str}")
values = await self.redis.mget(redis_keys)
history = []
key_count = len(keys)
for day_index, day_str in enumerate(day_labels):
base_idx = day_index * key_count
day_total = 0
for key_offset in range(key_count):
value = values[base_idx + key_offset]
day_total += int(value) if value else 0
history.append({
"day": day_str,
"label": f"{day_str[4:6]}-{day_str[6:8]}",
"total_rpd": day_total,
})
return history