This commit is contained in:
yufengzhen 2026-06-17 11:13:11 +08:00
commit fb7f06ed03
14 changed files with 3050 additions and 0 deletions

3
.gitignore vendored Executable file
View File

@ -0,0 +1,3 @@
__pycache__/
test/*
keys.yaml

227
app/api/routes.py Executable file
View File

@ -0,0 +1,227 @@
from fastapi import APIRouter, HTTPException, Request, BackgroundTasks
from app.core.providers import get_provider_handler
from app.runtime import key_manager, dispatcher, stats_exporter
import httpx
import time
import logging
import traceback
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/v1/stats")
async def get_stats():
"""
Returns usage statistics for all API keys.
"""
stats = await dispatcher.get_all_key_stats()
return {"data": stats}
@router.get("/v1/usage")
async def get_usage():
"""
Returns overall usage summary.
"""
summary = await dispatcher.get_usage_summary()
return {"data": summary}
@router.post("/v1/stats/reset/rpd")
async def reset_today_rpd():
"""
Reset today's RPD counters in Redis and refresh dashboard output.
"""
day = dispatcher._get_day_str()
deleted_keys = await dispatcher.reset_today_rpd_usage()
stats_exporter.clear_daily_usage_day(day)
await stats_exporter.export_once()
return {
"success": True,
"day": day,
"deleted_keys": deleted_keys,
"message": f"Reset today's RPD counters for {deleted_keys} keys",
}
@router.post("/v1/keys/config/{config_id}/enable")
async def enable_config(config_id: str):
"""
Enable all keys for a specific config_id (auto-saves to config file).
"""
success = key_manager.set_config_enabled(config_id, True)
if success:
return {"success": True, "message": f"Config {config_id} enabled and saved"}
else:
raise HTTPException(status_code=404, detail=f"Config {config_id} not found")
@router.post("/v1/keys/config/{config_id}/disable")
async def disable_config(config_id: str):
"""
Disable all keys for a specific config_id (auto-saves to config file).
"""
success = key_manager.set_config_enabled(config_id, False)
if success:
return {"success": True, "message": f"Config {config_id} disabled and saved"}
else:
raise HTTPException(status_code=404, detail=f"Config {config_id} not found")
@router.post("/v1/keys/{key_id}/enable")
async def enable_key(key_id: str):
"""
Enable a specific key.
"""
success = key_manager.set_key_enabled(key_id, True)
if success:
return {"success": True, "message": f"Key {key_id} enabled"}
else:
raise HTTPException(status_code=404, detail=f"Key {key_id} not found")
@router.post("/v1/keys/{key_id}/disable")
async def disable_key(key_id: str):
"""
Disable a specific key.
"""
success = key_manager.set_key_enabled(key_id, False)
if success:
return {"success": True, "message": f"Key {key_id} disabled"}
else:
raise HTTPException(status_code=404, detail=f"Key {key_id} not found")
@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/enable")
async def enable_endpoint(config_id: str, endpoint_idx: int):
"""
Enable a specific endpoint for a config (auto-saves to config file).
"""
success = key_manager.set_endpoint_enabled(config_id, endpoint_idx, True)
if success:
return {"success": True, "message": f"Endpoint {endpoint_idx} of {config_id} enabled and saved"}
else:
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_idx} of {config_id} not found")
@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/disable")
async def disable_endpoint(config_id: str, endpoint_idx: int):
"""
Disable a specific endpoint for a config (auto-saves to config file).
"""
success = key_manager.set_endpoint_enabled(config_id, endpoint_idx, False)
if success:
return {"success": True, "message": f"Endpoint {endpoint_idx} of {config_id} disabled and saved"}
else:
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_idx} of {config_id} not found")
@router.get("/v1/configs")
async def get_configs():
"""
Returns all config IDs.
"""
configs = []
for config_id in key_manager.get_config_ids():
keys = key_manager.get_keys_by_config(config_id)
if keys:
configs.append({
"config_id": config_id,
"model_name": keys[0].model_name,
"provider": keys[0].provider,
"key_count": len(keys),
"enabled": keys[0].enabled
})
return {"data": configs}
@router.post("/v1/chat/completions")
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
"""
Unified entry point.
Expects JSON body with "model" field.
"""
try:
body = await request.json()
except:
logger.error(f"[Invalid JSON] detail={traceback.format_exc()}...")
raise HTTPException(status_code=400, detail="Invalid JSON")
model_name = body.get("model")
key, error_reason = await dispatcher.select_key(model_name)
if not key:
logger.error(f"[Key Not Found] model={model_name} | error={error_reason}...")
raise HTTPException(status_code=503, detail=error_reason)
if not model_name:
model_name = key.model_name
acquired = await dispatcher.acquire_lease(key)
if not acquired:
logger.error(f"[System busy (Concurrency Limit)] model={model_name} | provider={key.provider} | key={key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}...")
raise HTTPException(status_code=503, detail="System busy (Concurrency Limit)")
try:
handler = get_provider_handler(key)
url = handler.get_url()
headers = handler.get_headers()
params = {k: v for k, v in body.items() if k not in ["model", "messages", "stream"]}
payload = handler.get_payload(body.get("messages", []), params)
timeout = httpx.Timeout(300.0, connect=10.0)
async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.post(url, headers=headers, json=payload)
if resp.status_code != 200:
error_detail = resp.text[:500] if len(resp.text) > 500 else resp.text
logger.error(
f"[Request Failed] model={model_name} | key={key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'} | "
f"api_base={key.api_base} | "
f"status={resp.status_code} | response={error_detail}"
)
if resp.status_code == 429:
error_message = ""
try:
error_data = resp.json()
if isinstance(error_data, dict):
error_message = error_data.get("error", {}).get("message", "") or error_data.get("message", "")
except:
error_message = resp.text
await dispatcher.report_failure(key, is_rate_limit=True, error_message=error_message)
raise HTTPException(status_code=resp.status_code, detail=f"Provider Error: {error_detail}")
data = resp.json()
parsed_result = handler.parse_response(data)
total_tokens = parsed_result["usage"].get("total_tokens", 0)
background_tasks.add_task(dispatcher.record_usage, key, total_tokens)
response_data = {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": parsed_result["content"],
"reasoning_content": parsed_result.get("reasoning_content")
},
"finish_reason": "stop"
}],
"usage": parsed_result["usage"]
}
return response_data
except HTTPException:
raise
except Exception as e:
_key_info = f"{key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}" if key else 'N/A'
logger.exception(
f"[Internal Error] model={model_name} | "
f"key={_key_info} | api_base={getattr(key, 'api_base', 'N/A') if key else 'N/A'} | error={str(e)}"
)
raise HTTPException(status_code=500, detail=str(e))
finally:
await dispatcher.release_lease(key)

556
app/core/dispatcher.py Executable file
View File

@ -0,0 +1,556 @@
import time
import asyncio
import fnmatch
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any
from app.models import APIKey, KeyStatus
from app.core.key_manager import KeyManager
from app.utils.redis_client import RedisClient
class Dispatcher:
DAILY_RPD_TTL_SECONDS = 90 * 86400
RATE_LIMIT_PENALTY_SEQ_KEY = "keypool:priority:429:seq"
PROVIDER_PRIORITIES = {
"modelscope": 0,
"nvidia": 0,
"siliconflow": 0,
"llama-local": 0,
}
def __init__(self, key_manager: KeyManager):
self.key_manager = key_manager
self.redis = RedisClient.get_instance()
@staticmethod
def _get_day_str(offset_days: int = 0) -> str:
current_day = datetime.now(timezone.utc) + timedelta(days=offset_days)
return current_day.strftime("%Y%m%d")
@staticmethod
def _get_rate_limit_penalty_key(key_id: str) -> str:
return f"keypool:priority:429:{key_id}"
@classmethod
def _get_provider_priority(cls, provider: str) -> int:
return cls.PROVIDER_PRIORITIES.get(provider.lower(), len(cls.PROVIDER_PRIORITIES))
async def _delete_matching_keys(self, pattern: str, batch_size: int = 200) -> int:
if hasattr(self.redis, "scan_iter"):
batch = []
deleted = 0
async for key in self.redis.scan_iter(match=pattern, count=batch_size):
batch.append(key)
if len(batch) >= batch_size:
deleted += await self.redis.delete(*batch)
batch.clear()
if batch:
deleted += await self.redis.delete(*batch)
return deleted
if not hasattr(self.redis, "values"):
raise RuntimeError("Redis client does not support scan_iter")
keys = [key for key in self.redis.values if fnmatch.fnmatch(key, pattern)]
return await self.redis.delete(*keys) if keys else 0
async def reset_today_rpd_usage(self) -> int:
day = self._get_day_str()
return await self._delete_matching_keys(f"keypool:usage:rpd:*:{day}")
async def get_key_stats(self, key: APIKey):
"""Fetches runtime stats from Redis for a single key."""
concurrency_key = f"keypool:concurrency:{key.id}"
penalty_key = self._get_rate_limit_penalty_key(key.id)
current_day = self._get_day_str()
rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
# Pipeline for efficiency
pipe = self.redis.pipeline()
pipe.get(concurrency_key)
pipe.get(penalty_key)
pipe.get(rpd_key)
results = await pipe.execute()
current_concurrency = int(results[0]) if results[0] else 0
rate_limit_penalty = int(results[1]) if results[1] else 0
current_rpd = int(results[2]) if results[2] else 0
key.current_concurrency = current_concurrency
key.rate_limit_penalty = rate_limit_penalty
key.current_rpd = current_rpd
return key
async def select_key(self, model_name: Optional[str] = None, timeout: float = 30.0) -> tuple[Optional[APIKey], str]:
"""
Selects the best available key for the given model.
If model_name is None, selects from ALL available keys.
If no keys are available immediately, it will wait (queue) up to `timeout` seconds.
Strategy:
1. Try to find an available key immediately.
2. If none found, wait for a short interval and retry (polling/queue simulation).
In a production system, we might use a real Redis list/pubsub for queuing,
but for now, we'll use a local async loop with exponential backoff or simple polling.
Returns: (key, error_reason)
Note: If model_name is explicitly specified (force_select=True), the status check
for DISABLED keys is bypassed, allowing access to disabled models.
"""
start_time = time.time()
# If model_name is explicitly specified, force select (bypass enabled check)
force_select = bool(model_name)
while True:
# 1. Try to select a key
key, error = await self._try_select_key_once(model_name, force_select=force_select)
if key:
return key, ""
# 2. Check timeout
if time.time() - start_time > timeout:
return None, f"Timeout waiting for available key: {error}"
# 3. Wait before retrying (simple polling).
await asyncio.sleep(0.5)
async def _try_select_key_once(self, model_name: Optional[str] = None, force_select: bool = False) -> tuple[Optional[APIKey], str]:
"""
Internal method to try selecting a key once without waiting.
Strategy:
1. Filter by status and hard limits.
2. Prefer keys that never hit 429.
3. Prefer providers by static rank: modelscope > nvidia > siliconflow.
4. For keys that did hit 429, use the demotion order as a queue.
5. Use load as the final tie breaker inside the same priority tier.
Args:
model_name: Optional model name to filter by
force_select: If True, bypass the status check (allows accessing disabled models)
Returns: (key, error_reason)
"""
candidates = self.key_manager.get_candidate_keys(model_name)
if not candidates:
return None, "No keys configured for this model" if model_name else "No keys configured in the pool"
# Prepare keys for mget
concurrency_keys = [f"keypool:concurrency:{k.id}" for k in candidates]
penalty_keys = [self._get_rate_limit_penalty_key(k.id) for k in candidates]
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
rpm_curr_keys = [f"keypool:usage:rpm:{k.id}:{current_minute}" for k in candidates]
rpm_prev_keys = [f"keypool:usage:rpm:{k.id}:{prev_minute}" for k in candidates]
tpm_curr_keys = [f"keypool:usage:tpm:{k.id}:{current_minute}" for k in candidates]
tpm_prev_keys = [f"keypool:usage:tpm:{k.id}:{prev_minute}" for k in candidates]
current_day = self._get_day_str()
rpd_keys = [f"keypool:usage:rpd:{k.id}:{current_day}" for k in candidates]
all_keys = concurrency_keys + penalty_keys + rpm_curr_keys + rpm_prev_keys + tpm_curr_keys + tpm_prev_keys + rpd_keys
if not all_keys:
return None, "No keys available (empty stats)"
values = await self.redis.mget(all_keys)
n = len(candidates)
concurrency_values = values[:n]
penalty_values = values[n:2*n]
rpm_curr_values = values[2*n:3*n]
rpm_prev_values = values[3*n:4*n]
tpm_curr_values = values[4*n:5*n]
tpm_prev_values = values[5*n:6*n]
rpd_values = values[6*n:7*n]
# Calculate window weight
# If we are at second 0 of the minute, weight of prev is 1.0 (actually 0.99...)
# If we are at second 59, weight of prev is ~0
# Formula: weight = (60 - seconds_elapsed) / 60
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
available_keys = []
rejection_reasons = []
for i, key in enumerate(candidates):
current_conc = int(concurrency_values[i]) if concurrency_values[i] else 0
rate_limit_penalty = int(penalty_values[i]) if penalty_values[i] else 0
r_curr = int(rpm_curr_values[i]) if rpm_curr_values[i] else 0
r_prev = int(rpm_prev_values[i]) if rpm_prev_values[i] else 0
t_curr = int(tpm_curr_values[i]) if tpm_curr_values[i] else 0
t_prev = int(tpm_prev_values[i]) if tpm_prev_values[i] else 0
current_rpd = int(rpd_values[i]) if rpd_values[i] else 0
current_rpm = int(r_curr + r_prev * prev_weight)
current_tpm = int(t_curr + t_prev * prev_weight)
key.current_concurrency = current_conc
key.rate_limit_penalty = rate_limit_penalty
key.current_rpm = current_rpm
key.current_tpm = current_tpm
key.current_rpd = current_rpd
if key.status != KeyStatus.ACTIVE:
# If force_select is True (explicit model specified), bypass status check
if not force_select:
rejection_reasons.append(f"{key.id}({key.owner}): Inactive (status={key.status})")
continue
if key.current_concurrency >= key.max_concurrency:
rejection_reasons.append(f"{key.id}({key.owner}): Concurrency limit ({key.current_concurrency}/{key.max_concurrency})")
continue
if current_rpm >= key.rpm_limit:
rejection_reasons.append(f"{key.id}({key.owner}): RPM limit ({current_rpm}/{key.rpm_limit})")
continue
if current_tpm >= key.tpm_limit:
rejection_reasons.append(f"{key.id}({key.owner}): TPM limit ({current_tpm}/{key.tpm_limit})")
continue
if key.rpd_limit > 0 and current_rpd >= key.rpd_limit:
rejection_reasons.append(f"{key.id}({key.owner}): RPD limit ({current_rpd}/{key.rpd_limit})")
continue
available_keys.append(key)
if not available_keys:
reason = "; ".join(rejection_reasons) if rejection_reasons else "No available keys found"
return None, f"All keys unavailable: {reason}"
def calculate_load_score(k: APIKey) -> float:
conc_ratio = k.current_concurrency / k.max_concurrency if k.max_concurrency > 0 else 0
rpm_ratio = k.current_rpm / k.rpm_limit if k.rpm_limit > 0 else 0
tpm_ratio = k.current_tpm / k.tpm_limit if k.tpm_limit > 0 else 0
rpd_ratio = k.current_rpd / k.rpd_limit if k.rpd_limit > 0 else 0
return max(conc_ratio, rpm_ratio, tpm_ratio, rpd_ratio)
def selection_key(k: APIKey) -> tuple[int, int, int, float]:
provider_priority = self._get_provider_priority(k.provider)
if k.rate_limit_penalty > 0:
return (1, k.rate_limit_penalty, provider_priority, calculate_load_score(k))
return (0, provider_priority, 0, calculate_load_score(k))
best_key = min(available_keys, key=selection_key)
return best_key, ""
async def acquire_lease(self, key: APIKey) -> bool:
"""
Increments concurrency, RPM and RPD for the key.
Double check limit in Redis to be safe (race conditions).
"""
key_concurrency_key = f"keypool:concurrency:{key.id}"
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
current_day = self._get_day_str()
key_rpm_key = f"keypool:usage:rpm:{key.id}:{current_minute}"
key_rpm_prev_key = f"keypool:usage:rpm:{key.id}:{prev_minute}"
key_rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
pipe = self.redis.pipeline()
pipe.incr(key_concurrency_key)
pipe.incr(key_rpm_key)
pipe.expire(key_rpm_key, 65)
pipe.get(key_rpm_prev_key)
pipe.incr(key_rpd_key)
# Keep daily counters for two days so the dashboard can read yesterday's RPD.
pipe.expire(key_rpd_key, self.DAILY_RPD_TTL_SECONDS)
results = await pipe.execute()
curr_conc = results[0]
curr_rpm_val = results[1]
prev_rpm_val = int(results[3]) if results[3] else 0
curr_rpd_val = results[4]
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
estimated_rpm = curr_rpm_val + (prev_rpm_val * prev_weight)
if curr_conc > key.max_concurrency:
await self.redis.decr(key_concurrency_key)
await self.redis.decr(key_rpm_key)
await self.redis.decr(key_rpd_key)
return False
if estimated_rpm > key.rpm_limit:
await self.redis.decr(key_concurrency_key)
await self.redis.decr(key_rpm_key)
await self.redis.decr(key_rpd_key)
return False
if key.rpd_limit > 0 and curr_rpd_val > key.rpd_limit:
await self.redis.decr(key_concurrency_key)
await self.redis.decr(key_rpm_key)
await self.redis.decr(key_rpd_key)
return False
return True
async def release_lease(self, key: APIKey):
"""Decrements concurrency."""
key_concurrency_key = f"keypool:concurrency:{key.id}"
await self.redis.decr(key_concurrency_key)
async def record_usage(self, key: APIKey, tokens: int):
"""Records Token usage (TPM)."""
if tokens <= 0:
return
current_minute = int(time.time() / 60)
key_tpm_key = f"keypool:usage:tpm:{key.id}:{current_minute}"
pipe = self.redis.pipeline()
pipe.incrby(key_tpm_key, tokens)
pipe.expire(key_tpm_key, 65)
await pipe.execute()
async def report_failure(self, key: APIKey, is_rate_limit: bool = False, error_message: str = ""):
"""
Reports a failure. A 429 moves the key to the back of the priority queue.
"""
del error_message
if is_rate_limit:
seq = await self.redis.incr(self.RATE_LIMIT_PENALTY_SEQ_KEY)
await self.redis.set(self._get_rate_limit_penalty_key(key.id), seq)
async def get_all_key_stats(self) -> List[Dict[str, Any]]:
"""
Retrieves statistics for all keys, including status, concurrency, and usage limits.
"""
keys = self.key_manager.get_all_keys()
if not keys:
return []
# Prepare keys for batch fetching from Redis
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
current_day = self._get_day_str()
redis_keys = []
for key in keys:
redis_keys.extend([
f"keypool:concurrency:{key.id}",
self._get_rate_limit_penalty_key(key.id),
f"keypool:usage:rpm:{key.id}:{current_minute}",
f"keypool:usage:rpm:{key.id}:{prev_minute}",
f"keypool:usage:tpm:{key.id}:{current_minute}",
f"keypool:usage:tpm:{key.id}:{prev_minute}",
f"keypool:usage:rpd:{key.id}:{current_day}"
])
if not redis_keys:
return []
# Fetch all values in a single round-trip
values = await self.redis.mget(redis_keys)
stats = []
# Process results in chunks of 7
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
for i, key in enumerate(keys):
base_idx = i * 7
curr_conc = int(values[base_idx]) if values[base_idx] else 0
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
# Determine effective status
status = key.status
if key.status == KeyStatus.ACTIVE and rate_limit_penalty > 0:
status = KeyStatus.COOLDOWN
key_stat = {
"id": key.id,
"model_name": key.model_name,
"provider": key.provider,
"status": status,
"enabled": key.enabled,
"config_id": getattr(key, 'config_id', None),
"owner": getattr(key, 'owner', None),
"endpoint_idx": getattr(key, 'endpoint_idx', None),
"api_base": getattr(key, 'api_base', None),
"limits": {
"rpm": key.rpm_limit,
"tpm": key.tpm_limit,
"rpd": key.rpd_limit,
"max_concurrency": key.max_concurrency
},
"usage": {
"current_concurrency": curr_conc,
"current_rpm": estimated_rpm,
"current_tpm": estimated_tpm,
"current_rpd": rpd_curr
},
"cooldown_remaining": 0
}
stats.append(key_stat)
return stats
async def get_usage_summary(self) -> Dict[str, Any]:
"""
Get overall usage summary across all keys.
Returns total RPM, TPM, RPD, concurrency, and active keys count.
"""
keys = self.key_manager.get_all_keys()
if not keys:
return {
"total_rpm": 0,
"total_tpm": 0,
"total_rpd": 0,
"total_yesterday_rpd": 0,
"total_concurrency": 0,
"active_keys": 0,
"cooldown_keys": 0,
"disabled_keys": 0,
"total_keys": 0
}
now = time.time()
current_minute = int(now / 60)
prev_minute = current_minute - 1
current_day = self._get_day_str()
yesterday_day = self._get_day_str(-1)
redis_keys = []
for key in keys:
redis_keys.extend([
f"keypool:concurrency:{key.id}",
self._get_rate_limit_penalty_key(key.id),
f"keypool:usage:rpm:{key.id}:{current_minute}",
f"keypool:usage:rpm:{key.id}:{prev_minute}",
f"keypool:usage:tpm:{key.id}:{current_minute}",
f"keypool:usage:tpm:{key.id}:{prev_minute}",
f"keypool:usage:rpd:{key.id}:{current_day}",
f"keypool:usage:rpd:{key.id}:{yesterday_day}"
])
if not redis_keys:
return {
"total_rpm": 0,
"total_tpm": 0,
"total_rpd": 0,
"total_yesterday_rpd": 0,
"total_concurrency": 0,
"active_keys": 0,
"cooldown_keys": 0,
"disabled_keys": 0,
"total_keys": 0
}
values = await self.redis.mget(redis_keys)
seconds_elapsed = now % 60
prev_weight = (60 - seconds_elapsed) / 60
total_rpm = 0
total_tpm = 0
total_rpd = 0
total_yesterday_rpd = 0
total_concurrency = 0
active_keys = 0
cooldown_keys = 0
disabled_keys = 0
for i, key in enumerate(keys):
base_idx = i * 8
curr_conc = int(values[base_idx]) if values[base_idx] else 0
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
rpd_yesterday = int(values[base_idx + 7]) if values[base_idx + 7] else 0
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
total_rpm += estimated_rpm
total_tpm += estimated_tpm
total_rpd += rpd_curr
total_yesterday_rpd += rpd_yesterday
total_concurrency += curr_conc
if key.status == KeyStatus.ACTIVE and rate_limit_penalty == 0:
active_keys += 1
elif key.status == KeyStatus.DISABLED:
disabled_keys += 1
elif rate_limit_penalty > 0:
cooldown_keys += 1
else:
active_keys += 1
return {
"total_rpm": total_rpm,
"total_tpm": total_tpm,
"total_rpd": total_rpd,
"total_yesterday_rpd": total_yesterday_rpd,
"total_concurrency": total_concurrency,
"active_keys": active_keys,
"cooldown_keys": cooldown_keys,
"disabled_keys": disabled_keys,
"total_keys": len(keys)
}
async def get_daily_usage_history(self, days: int = 7) -> List[Dict[str, Any]]:
"""
Aggregate total RPD usage for the most recent N days.
Uses the same UTC day buckets as runtime RPD limiting.
"""
keys = self.key_manager.get_all_keys()
if not keys or days <= 0:
return []
day_labels = [self._get_day_str(-offset) for offset in range(days - 1, -1, -1)]
redis_keys = []
for day_str in day_labels:
for key in keys:
redis_keys.append(f"keypool:usage:rpd:{key.id}:{day_str}")
values = await self.redis.mget(redis_keys)
history = []
key_count = len(keys)
for day_index, day_str in enumerate(day_labels):
base_idx = day_index * key_count
day_total = 0
for key_offset in range(key_count):
value = values[base_idx + key_offset]
day_total += int(value) if value else 0
history.append({
"day": day_str,
"label": f"{day_str[4:6]}-{day_str[6:8]}",
"total_rpd": day_total,
})
return history

294
app/core/key_manager.py Executable file
View File

@ -0,0 +1,294 @@
import yaml
import time
from typing import List, Dict, Optional, Set, Any, Tuple
from app.models import APIKey, KeyStatus
class KeyManager:
def __init__(self, config_path: str):
self.config_path = config_path
self._keys: Dict[str, APIKey] = {}
self._keys_by_model: Dict[str, List[APIKey]] = {}
self._config_ids: Set[str] = set()
self._last_load_time: float = 0
self._raw_config: Dict[str, Any] = {}
self._config_enabled_map: Dict[str, bool] = {}
# Track endpoint enable status: {(config_id, idx): enabled}
self._endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
# Track endpoint URLs: {(config_id, idx): url}
self._endpoint_urls: Dict[Tuple[str, int], str] = {}
self.load_keys()
def load_keys(self):
"""Loads keys from the YAML configuration file."""
with open(self.config_path, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
self._raw_config = data
self._keys = {}
self._keys_by_model = {}
self._config_ids = set()
self._config_enabled_map = {}
configs = {}
for cfg in data.get('model_configs', []):
if 'id' in cfg:
configs[cfg['id']] = cfg
for key_group in data.get('keys', []):
config_id = key_group.get('config_id')
group_enabled = key_group.get('enabled', True)
keys_list = key_group.get('keys', [])
if config_id:
self._config_ids.add(config_id)
self._config_enabled_map[config_id] = group_enabled
if config_id and config_id in configs:
cfg = configs[config_id]
cfg_copy = cfg.copy()
cfg_copy.pop('id', None)
# Check if this config has multiple endpoints (for local llama.cpp load balancing)
api_endpoints = cfg_copy.get('api_endpoints', [])
if api_endpoints and len(api_endpoints) > 0:
# Generate virtual keys for each endpoint
for idx, endpoint_item in enumerate(api_endpoints):
# Handle both old format (string URL) and new format (dict with url/enabled)
if isinstance(endpoint_item, dict):
endpoint_url = endpoint_item.get('url', '')
endpoint_enabled = endpoint_item.get('enabled', True)
else:
endpoint_url = endpoint_item
endpoint_enabled = True
# Store endpoint info for dashboard
self._endpoint_enabled_map[(config_id, idx)] = endpoint_enabled
self._endpoint_urls[(config_id, idx)] = endpoint_url
final_data = {}
final_data.update(cfg_copy)
final_data.pop('api_endpoints', None) # Remove the list from individual key
final_data['config_id'] = config_id
# Endpoint is enabled only if both group and endpoint itself are enabled
final_data['enabled'] = group_enabled and endpoint_enabled
final_data['api_base'] = endpoint_url # Use individual endpoint
final_data['key'] = f"local-{config_id}-{idx}" # Virtual key for local service
final_data['endpoint_idx'] = idx # Track endpoint index
# Virtual key ID that includes endpoint info for tracking
endpoint_hash = hash(endpoint_url) & 0xFFFF
key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}"
final_data['id'] = key_id
if not group_enabled or not endpoint_enabled:
final_data['status'] = KeyStatus.DISABLED
try:
key = APIKey(**final_data)
self._keys[key.id] = key
if key.model_name not in self._keys_by_model:
self._keys_by_model[key.model_name] = []
self._keys_by_model[key.model_name].append(key)
except Exception as e:
print(f"Error loading virtual key {key_id} for endpoint {endpoint_url}: {e}")
else:
# Standard key loading (original logic)
for key_data in keys_list:
final_data = {}
final_data.update(cfg_copy)
final_data['config_id'] = config_id
final_data['enabled'] = group_enabled
if key_data.get('owner'):
final_data['owner'] = key_data['owner']
final_data['key'] = key_data['key']
if not group_enabled:
final_data['status'] = KeyStatus.DISABLED
key_id = f"{config_id}:{hash(key_data['key']) & 0xFFFFFFFF}"
final_data['id'] = key_id
try:
key = APIKey(**final_data)
self._keys[key.id] = key
if key.model_name not in self._keys_by_model:
self._keys_by_model[key.model_name] = []
self._keys_by_model[key.model_name].append(key)
except Exception as e:
print(f"Error loading key {key_id}: {e}")
else:
for key_data in keys_list:
final_data = {}
final_data.update(key_group)
final_data.pop('keys', None)
final_data.update(key_data)
if 'id' not in final_data:
final_data['id'] = key_data.get('key', 'unknown')[:16]
if not final_data.get('enabled', True):
final_data['status'] = KeyStatus.DISABLED
try:
key = APIKey(**final_data)
self._keys[key.id] = key
if key.model_name not in self._keys_by_model:
self._keys_by_model[key.model_name] = []
self._keys_by_model[key.model_name].append(key)
except Exception as e:
print(f"Error loading key {final_data.get('id', 'unknown')}: {e}")
self._last_load_time = time.time()
def _save_config(self) -> bool:
"""Save current configuration to YAML file."""
try:
for key_group in self._raw_config.get('keys', []):
config_id = key_group.get('config_id')
if config_id and config_id in self._config_enabled_map:
key_group['enabled'] = self._config_enabled_map[config_id]
# Save endpoint enable status
for (cfg_id, idx), enabled in self._endpoint_enabled_map.items():
for cfg in self._raw_config.get('model_configs', []):
if cfg.get('id') == cfg_id:
endpoints = cfg.get('api_endpoints', [])
if idx < len(endpoints):
if isinstance(endpoints[idx], dict):
endpoints[idx]['enabled'] = enabled
else:
# Convert string to dict with url and enabled
endpoints[idx] = {'url': endpoints[idx], 'enabled': enabled}
with open(self.config_path, 'w', encoding='utf-8') as f:
yaml.dump(self._raw_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
return True
except Exception as e:
print(f"Error saving config: {e}")
return False
def reload_keys(self) -> bool:
"""Reload keys from config file. Returns True if successful."""
try:
self.load_keys()
return True
except Exception as e:
print(f"Error reloading keys: {e}")
return False
def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
"""Enable or disable all keys for a config_id and save to file."""
if config_id not in self._config_ids:
return False
self._config_enabled_map[config_id] = enabled
updated = False
for key in self._keys.values():
if hasattr(key, 'config_id') and key.config_id == config_id:
key.enabled = enabled
if enabled:
key.status = KeyStatus.ACTIVE
else:
key.status = KeyStatus.DISABLED
updated = True
if updated:
self._save_config()
return updated
def set_key_enabled(self, key_id: str, enabled: bool) -> bool:
"""Enable or disable a specific key."""
if key_id not in self._keys:
return False
key = self._keys[key_id]
key.enabled = enabled
if enabled:
key.status = KeyStatus.ACTIVE
else:
key.status = KeyStatus.DISABLED
return True
def get_config_ids(self) -> Set[str]:
"""Get all config IDs."""
return self._config_ids.copy()
def get_keys_by_config(self, config_id: str) -> List[APIKey]:
"""Get all keys for a specific config_id."""
return [k for k in self._keys.values() if hasattr(k, 'config_id') and k.config_id == config_id]
def get_key(self, key_id: str) -> Optional[APIKey]:
return self._keys.get(key_id)
def get_candidate_keys(self, model_name: Optional[str] = None) -> List[APIKey]:
"""Returns all keys for a given model, regardless of status. If model_name is None, returns all keys."""
if not model_name:
return self.get_all_keys()
return self._keys_by_model.get(model_name, [])
def get_all_keys(self) -> List[APIKey]:
return list(self._keys.values())
def get_last_load_time(self) -> float:
return self._last_load_time
def get_config_enabled(self, config_id: str) -> bool:
"""Get enabled status for a config_id."""
return self._config_enabled_map.get(config_id, True)
def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
"""Get all endpoint info for a config_id."""
endpoints = []
for (cfg_id, idx), url in self._endpoint_urls.items():
if cfg_id == config_id:
enabled = self._endpoint_enabled_map.get((cfg_id, idx), True)
endpoint_hash = hash(url) & 0xFFFF
key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}"
endpoints.append({
'idx': idx,
'url': url,
'enabled': enabled,
'key_id': key_id
})
return sorted(endpoints, key=lambda x: x['idx'])
def set_endpoint_enabled(self, config_id: str, endpoint_idx: int, enabled: bool) -> bool:
"""Enable or disable a specific endpoint and save to file."""
if (config_id, endpoint_idx) not in self._endpoint_enabled_map:
return False
self._endpoint_enabled_map[(config_id, endpoint_idx)] = enabled
# Update the corresponding key's status
url = self._endpoint_urls.get((config_id, endpoint_idx), '')
endpoint_hash = hash(url) & 0xFFFF
key_id = f"{config_id}:endpoint:{endpoint_idx}:{endpoint_hash}"
if key_id in self._keys:
key = self._keys[key_id]
key.enabled = enabled
if enabled:
key.status = KeyStatus.ACTIVE
else:
key.status = KeyStatus.DISABLED
self._save_config()
return True
def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
"""Get the key_id for a specific endpoint."""
url = self._endpoint_urls.get((config_id, endpoint_idx))
if url:
endpoint_hash = hash(url) & 0xFFFF
return f"{config_id}:endpoint:{endpoint_idx}:{endpoint_hash}"
return None

311
app/core/providers.py Executable file
View File

@ -0,0 +1,311 @@
import re
from typing import Dict, Any, List, Optional, Tuple, Union
from app.models import APIKey
class BaseProvider:
def __init__(self, api_key: APIKey):
self.api_key = api_key
def get_headers(self) -> Dict[str, str]:
raise NotImplementedError
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
raise NotImplementedError
def get_url(self) -> str:
raise NotImplementedError
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Parses provider response into a standardized format.
Returns:
{
"content": str,
"reasoning_content": Optional[str],
"usage": {
"input_tokens": int,
"output_tokens": int,
"total_tokens": int
}
}
"""
raise NotImplementedError
class OpenAIProvider(BaseProvider):
def get_headers(self) -> Dict[str, str]:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key.key}"
}
def get_url(self) -> str:
base = self.api_key.api_base or "https://api.openai.com/v1"
return f"{base.rstrip('/')}/chat/completions"
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
data = {"messages": messages, "model": self.api_key.model_name, "stream": False}
# 1. Start with Key's default params
final_params = self.api_key.params.copy()
# 2. Update with request-level params (filter None)
valid_request_params = {k: v for k, v in params.items() if v is not None}
final_params.update(valid_request_params)
# 3. Add to payload
data.update(final_params)
# 4. Handle extra_config from Key (e.g., Qwen extra_body)
if self.api_key.extra_config:
if "extra_body" in self.api_key.extra_config:
data.update(self.api_key.extra_config["extra_body"])
return data
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
try:
choice = data["choices"][0]
text = choice["message"]["content"]
reasoning = choice["message"].get("reasoning_content") or \
choice.get("reasoning_content")
usage = data.get("usage", {})
return {
"content": text,
"reasoning_content": reasoning,
"usage": {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0)
}
}
except (KeyError, IndexError):
return {"content": "", "usage": {}}
class AzureProvider(OpenAIProvider):
def get_headers(self) -> Dict[str, str]:
return {
"Content-Type": "application/json",
"api-key": self.api_key.key
}
def get_url(self) -> str:
# Expected api_base: https://{resource}.openai.azure.com/openai/deployments/{deployment}
# We need to append api-version
base = self.api_key.api_base or ""
sep = "&" if "?" in base else "?"
# Get api-version from extra_config or default
version = "2025-01-01-preview"
if self.api_key.extra_config and "api_version" in self.api_key.extra_config:
version = self.api_key.extra_config["api_version"]
return f"{base}{sep}api-version={version}"
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
payload = super().get_payload(messages, params)
# Azure usually doesn't need 'model' in body if it's in URL (deployment)
# But keeping it usually doesn't hurt, unless configured to remove.
# For simplicity, we keep it or remove it if we want strict adherence to typical azure usage.
# Let's remove it to be safe as deployment is in URL.
payload.pop("model", None)
return payload
class GeminiProvider(BaseProvider):
def get_headers(self) -> Dict[str, str]:
return {"Content-Type": "application/json"}
def get_url(self) -> str:
# api_base example: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
base = self.api_key.api_base or f"https://generativelanguage.googleapis.com/v1beta/models/{self.api_key.model_name}:generateContent"
sep = "&" if "?" in base else "?"
return f"{base}{sep}key={self.api_key.key}"
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
gemini_contents = []
system_instruction = None
for msg in messages:
role = "model" if msg["role"] == "assistant" else msg["role"]
content = msg["content"]
if role == "system":
text_val = content if isinstance(content, str) else content[0]["text"]
system_instruction = {"parts": [{"text": text_val}]}
continue
parts = []
content_list = content if isinstance(content, list) else [{"type": "text", "text": content}]
for item in content_list:
if item.get("type") == "text":
parts.append({"text": item["text"]})
elif item.get("type") == "image_url":
url = item.get("image_url", {}).get("url", "")
match = re.match(r'data:(.*?);base64,(.*)', url)
if match:
parts.append({
"inline_data": {
"mime_type": match.group(1),
"data": match.group(2)
}
})
if parts:
gemini_contents.append({"role": role, "parts": parts})
generation_config = {}
# 1. Start with Key's default params (mapped to Gemini format if possible, or assume config.py had correct format)
# Note: config.py used "maxOutputTokens" (Gemini style) directly in 'params'.
# We should merge self.api_key.params into generation_config
generation_config.update(self.api_key.params)
# 2. Update with request params (mapped from OpenAI style)
if "max_tokens" in params:
generation_config["maxOutputTokens"] = params["max_tokens"]
if "temperature" in params:
generation_config["temperature"] = params["temperature"]
if "top_p" in params:
generation_config["topP"] = params["top_p"]
# 3. Handle thinking config from Key or Request
# Check Key's extra_config first
if self.api_key.extra_config and "thinking_config" in self.api_key.extra_config:
generation_config["thinkingConfig"] = self.api_key.extra_config["thinking_config"]
# Override with request param if present
if "thinking_config" in params:
generation_config["thinkingConfig"] = params["thinking_config"]
payload = {
"contents": gemini_contents,
"generationConfig": generation_config
}
if system_instruction:
payload["system_instruction"] = system_instruction
return payload
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
try:
candidate = data["candidates"][0]
parts_text = []
if "content" in candidate and "parts" in candidate["content"]:
for part in candidate["content"]["parts"]:
if "text" in part:
parts_text.append(part["text"])
text = "\n".join(parts_text) if parts_text else ""
reasoning = None
if "thinking_process" in candidate:
reasoning = candidate["thinking_process"]
usage_metadata = data.get("usageMetadata", {})
input_tokens = usage_metadata.get("promptTokenCount", 0)
output_tokens = usage_metadata.get("candidatesTokenCount", 0)
total_tokens = input_tokens + output_tokens
return {
"content": text,
"reasoning_content": reasoning,
"usage": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens
}
}
except (KeyError, IndexError):
return {"content": "", "usage": {}}
class ClaudeProvider(BaseProvider):
def get_headers(self) -> Dict[str, str]:
headers = {
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
"x-api-key": self.api_key.key
}
# Check extra_config for version
if self.api_key.extra_config and "anthropic_version" in self.api_key.extra_config:
headers["anthropic-version"] = self.api_key.extra_config["anthropic_version"]
return headers
def get_url(self) -> str:
base = self.api_key.api_base or "https://api.anthropic.com/v1/messages"
return base
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
claude_msgs = []
system_prompt = ""
for msg in messages:
role = "assistant" if msg["role"] == "assistant" else msg["role"]
content = msg["content"]
if role == "system":
system_prompt = content if isinstance(content, str) else content[0]["text"]
continue
# Simplified content handling for text only for now, can expand if needed
claude_msgs.append({"role": role, "content": content})
# Start with Key's default params
payload = {
"model": self.api_key.model_name,
"messages": claude_msgs,
"max_tokens": 1024 # Default fallback
}
# Merge Key params
payload.update(self.api_key.params)
# Merge request params
if "max_tokens" in params:
payload["max_tokens"] = params["max_tokens"]
if "temperature" in params:
payload["temperature"] = params["temperature"]
if system_prompt:
payload["system"] = system_prompt
return payload
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
try:
content_list = data.get("content", [])
text = ""
for item in content_list:
if item.get("type") == "text":
text += item.get("text", "")
usage = data.get("usage", {})
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
return {
"content": text,
"usage": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens
}
}
except Exception:
return {"content": "", "usage": {}}
def get_provider_handler(api_key: APIKey) -> BaseProvider:
providers = {
"google": GeminiProvider,
"gemini": GeminiProvider,
"openai": OpenAIProvider,
"azure": AzureProvider,
"claude": ClaudeProvider,
"anthropic": ClaudeProvider,
"zhipu": OpenAIProvider # Zhipu is mostly OpenAI compatible
}
provider_class = providers.get(api_key.provider, OpenAIProvider)
return provider_class(api_key)

1016
app/core/stats_exporter.py Executable file

File diff suppressed because it is too large Load Diff

207
app/core/stats_store.py Executable file
View File

@ -0,0 +1,207 @@
import sqlite3
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
class StatsStore:
DAILY_RETENTION_DAYS = 7
def __init__(self, db_path: str | Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
def _init_db(self):
with self._connect() as conn:
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS latest_summary (
id INTEGER PRIMARY KEY CHECK (id = 1),
updated_at INTEGER NOT NULL,
total_keys INTEGER NOT NULL,
active_keys INTEGER NOT NULL,
cooldown_keys INTEGER NOT NULL,
disabled_keys INTEGER NOT NULL,
total_concurrency INTEGER NOT NULL,
total_rpm INTEGER NOT NULL,
total_tpm INTEGER NOT NULL,
total_rpd INTEGER NOT NULL,
total_yesterday_rpd INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS latest_key_stats (
key_id TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL,
model_name TEXT NOT NULL,
provider TEXT NOT NULL,
status TEXT NOT NULL,
enabled INTEGER NOT NULL,
config_id TEXT,
owner TEXT,
current_concurrency INTEGER NOT NULL,
current_rpm INTEGER NOT NULL,
current_tpm INTEGER NOT NULL,
current_rpd INTEGER NOT NULL,
cooldown_remaining REAL NOT NULL,
rpm_limit INTEGER NOT NULL,
tpm_limit INTEGER NOT NULL,
rpd_limit INTEGER NOT NULL,
max_concurrency INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS daily_usage (
day TEXT PRIMARY KEY,
label TEXT NOT NULL,
total_rpd INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
"""
)
def save_export(self, stats: list[dict[str, Any]], usage_summary: dict[str, Any], daily_usage: list[dict[str, Any]]):
now = int(datetime.now(timezone.utc).timestamp())
cutoff = (datetime.now(timezone.utc) - timedelta(days=self.DAILY_RETENTION_DAYS)).strftime("%Y%m%d")
with self._connect() as conn:
conn.execute(
"""
INSERT INTO latest_summary (
id, updated_at, total_keys, active_keys, cooldown_keys, disabled_keys,
total_concurrency, total_rpm, total_tpm, total_rpd, total_yesterday_rpd
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
updated_at = excluded.updated_at,
total_keys = excluded.total_keys,
active_keys = excluded.active_keys,
cooldown_keys = excluded.cooldown_keys,
disabled_keys = excluded.disabled_keys,
total_concurrency = excluded.total_concurrency,
total_rpm = excluded.total_rpm,
total_tpm = excluded.total_tpm,
total_rpd = excluded.total_rpd,
total_yesterday_rpd = excluded.total_yesterday_rpd
""",
(
1,
now,
usage_summary["total_keys"],
usage_summary["active_keys"],
usage_summary["cooldown_keys"],
usage_summary["disabled_keys"],
usage_summary["total_concurrency"],
usage_summary["total_rpm"],
usage_summary["total_tpm"],
usage_summary["total_rpd"],
usage_summary.get("total_yesterday_rpd", 0),
),
)
conn.execute("DELETE FROM latest_key_stats")
conn.executemany(
"""
INSERT INTO latest_key_stats (
key_id, updated_at, model_name, provider, status, enabled, config_id, owner,
current_concurrency, current_rpm, current_tpm, current_rpd, cooldown_remaining,
rpm_limit, tpm_limit, rpd_limit, max_concurrency
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
[
(
item["id"],
now,
item["model_name"],
item["provider"],
item["status"],
1 if item["enabled"] else 0,
item.get("config_id"),
item.get("owner"),
item["usage"]["current_concurrency"],
item["usage"]["current_rpm"],
item["usage"]["current_tpm"],
item["usage"]["current_rpd"],
item["cooldown_remaining"],
item["limits"]["rpm"],
item["limits"]["tpm"],
item["limits"]["rpd"],
item["limits"]["max_concurrency"],
)
for item in stats
],
)
conn.executemany(
"""
INSERT INTO daily_usage (day, label, total_rpd, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(day) DO UPDATE SET
label = excluded.label,
total_rpd = MAX(daily_usage.total_rpd, excluded.total_rpd),
updated_at = excluded.updated_at
""",
[
(
item["day"],
item["label"],
item["total_rpd"],
now,
)
for item in daily_usage
],
)
conn.execute("DELETE FROM daily_usage WHERE day < ?", (cutoff,))
def get_daily_usage(self, days: int = 7) -> list[dict[str, Any]]:
if days <= 0:
return []
with self._connect() as conn:
rows = conn.execute(
"""
SELECT day, label, total_rpd
FROM daily_usage
ORDER BY day DESC
LIMIT ?
""",
(days,),
).fetchall()
rows = list(reversed(rows))
return [
{
"day": row["day"],
"label": row["label"],
"total_rpd": row["total_rpd"],
}
for row in rows
]
def get_latest_summary(self) -> dict[str, Any] | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM latest_summary WHERE id = 1").fetchone()
if not row:
return None
return dict(row)
def get_latest_key_stats(self) -> list[dict[str, Any]]:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT *
FROM latest_key_stats
ORDER BY provider, model_name, key_id
"""
).fetchall()
return [dict(row) for row in rows]
def delete_day(self, day: str):
with self._connect() as conn:
conn.execute("DELETE FROM daily_usage WHERE day = ?", (day,))

36
app/models.py Executable file
View File

@ -0,0 +1,36 @@
from enum import Enum
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
class KeyStatus(str, Enum):
ACTIVE = "active"
COOLDOWN = "cooldown"
DISABLED = "disabled"
class APIKey(BaseModel):
id: str = Field(..., description="Unique identifier for the key")
key: str = Field(..., description="The actual API Key string")
provider: str = Field(..., description="Provider identifier (e.g., google, zhipu)")
api_base: Optional[str] = Field(None, description="Base URL for the API")
model_name: str = Field(..., description="The physical model name")
params: Dict[str, Any] = Field(default_factory=dict, description="Default model parameters (e.g., temperature)")
extra_config: Dict[str, Any] = Field(default_factory=dict, description="Provider specific configuration")
enabled: bool = Field(default=True, description="Whether the key is enabled in configuration")
status: KeyStatus = Field(default=KeyStatus.ACTIVE, description="Current status of the key")
rpm_limit: int = Field(..., description="Requests Per Minute limit")
tpm_limit: int = Field(..., description="Tokens Per Minute limit")
rpd_limit: int = Field(default=0, description="Requests Per Day limit (0 = unlimited)")
max_concurrency: int = Field(..., description="Maximum concurrent requests allowed")
config_id: Optional[str] = Field(None, description="Config ID this key belongs to")
owner: Optional[str] = Field(None, description="Owner of this key")
# Runtime stats (Not loaded from config usually, but part of the object in memory)
current_concurrency: int = Field(default=0, description="Current concurrent requests")
current_rpm: int = Field(default=0, description="Current RPM usage")
current_tpm: int = Field(default=0, description="Current TPM usage")
current_rpd: int = Field(default=0, description="Current RPD usage")
rate_limit_penalty: int = Field(default=0, description="429 demotion order. 0 means never rate limited.")
endpoint_idx: Optional[int] = Field(default=None, description="Endpoint index for local llama load balancing")
class Config:
use_enum_values = True

10
app/runtime.py Executable file
View File

@ -0,0 +1,10 @@
import os
from app.core.dispatcher import Dispatcher
from app.core.key_manager import KeyManager
from app.core.stats_exporter import StatsExporter
key_manager = KeyManager(os.getenv("KEY_CONFIG_PATH", "config/keys.yaml"))
dispatcher = Dispatcher(key_manager)
stats_exporter = StatsExporter(dispatcher, output_dir="stats", interval=10)

18
app/utils/redis_client.py Executable file
View File

@ -0,0 +1,18 @@
import redis.asyncio as redis
import os
class RedisClient:
_instance = None
@classmethod
def get_instance(cls):
if cls._instance is None:
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
cls._instance = redis.from_url(redis_url, decode_responses=True)
return cls._instance
@classmethod
async def close(cls):
if cls._instance:
await cls._instance.close()
cls._instance = None

41
config/keys_example.yaml Executable file
View File

@ -0,0 +1,41 @@
model_configs:
- id: nvidia-glm-cfg
provider: nvidia
model_name: z-ai/glm-5.1
api_base: https://integrate.api.nvidia.com/v1
rpm_limit: 40
tpm_limit: 100000
max_concurrency: 8
params:
temperature: 1
top_p: 1
max_tokens: 16384
stream: false
chat_template_kwargs:
enable_thinking: false
clear_thinking: false
- id: llama-local-cfg
provider: openai
model_name: llama-local
api_base: http://192.168.2.101:8848/v1
api_endpoints:
- url: http://192.168.2.101:8848/v1
enabled: true
- url: http://192.168.1.51:8848/v1
enabled: true
- url: http://192.168.1.61:8848/v1
enabled: true
rpm_limit: 999999
tpm_limit: 999999999
max_concurrency: 999
params:
temperature: 1.0
max_tokens: 6144
chat_template_kwargs:
enable_thinking: false
keys:
- config_id: nvidia-glm-cfg
enabled: true
keys:
- key: nvapi-key-1
owner: name

41
main.py Executable file
View File

@ -0,0 +1,41 @@
import os
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from app.api.routes import router
from app.runtime import stats_exporter
from app.utils.redis_client import RedisClient
app = FastAPI(title="Smart KeyPool Gateway")
app.include_router(router)
@app.get("/", response_class=HTMLResponse)
async def dashboard_home():
return await stats_exporter.render_dashboard()
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard_page():
return await stats_exporter.render_dashboard()
@app.on_event("startup")
async def startup_event():
redis = RedisClient.get_instance()
try:
await redis.ping()
print("Connected to Redis")
except Exception as e:
print(f"Failed to connect to Redis: {e}")
await stats_exporter.start()
print("Stats exporter started - dashboard available at / or /dashboard")
@app.on_event("shutdown")
async def shutdown_event():
await stats_exporter.stop()
await RedisClient.close()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8888")))

6
requirements.txt Executable file
View File

@ -0,0 +1,6 @@
fastapi
uvicorn
redis
pydantic
pyyaml
httpx

284
test_keys.py Executable file
View File

@ -0,0 +1,284 @@
#!/usr/bin/env python3
"""
测试所有启用的API密钥是否正常工作并发版本
"""
import asyncio
import sys
import httpx
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
from concurrent.futures import ThreadPoolExecutor
# 添加app目录到路径
sys.path.insert(0, 'd:\\DPI\\keypool')
from app.core.key_manager import KeyManager
from app.core.providers import get_provider_handler
from app.models import APIKey
class TestStatus(Enum):
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
@dataclass
class TestResult:
key_id: str
provider: str
owner: Optional[str]
status: TestStatus
message: str
response_time: float = 0.0
details: Optional[Dict] = None
async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
"""测试单个API密钥"""
import time
start_time = time.time()
owner = getattr(api_key, 'owner', None)
# 如果密钥被禁用,跳过测试
if not api_key.enabled:
return TestResult(
key_id=api_key.id,
provider=api_key.provider,
owner=owner,
status=TestStatus.SKIPPED,
message="密钥已禁用 (enabled: false)",
response_time=0.0
)
try:
provider = get_provider_handler(api_key)
# 准备测试消息
test_messages = [
{"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
]
# 获取请求参数
headers = provider.get_headers()
url = provider.get_url()
payload = provider.get_payload(test_messages, {})
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, headers=headers, json=payload)
response_time = time.time() - start_time
if response.status_code == 200:
data = response.json()
parsed = provider.parse_response(data)
content = parsed.get("content", "")
return TestResult(
key_id=api_key.id,
provider=api_key.provider,
owner=owner,
status=TestStatus.SUCCESS,
message=f"测试成功,响应时间: {response_time:.2f}s",
response_time=response_time,
details={
"content_preview": content[:100] + "..." if len(content) > 100 else content,
"usage": parsed.get("usage", {})
}
)
else:
error_text = response.text[:500] if response.text else "No error details"
return TestResult(
key_id=api_key.id,
provider=api_key.provider,
owner=owner,
status=TestStatus.FAILED,
message=f"HTTP错误 {response.status_code}: {error_text}",
response_time=time.time() - start_time
)
except httpx.TimeoutException:
return TestResult(
key_id=api_key.id,
provider=api_key.provider,
owner=owner,
status=TestStatus.FAILED,
message=f"请求超时 (>{timeout}s)",
response_time=time.time() - start_time
)
except httpx.ConnectError as e:
return TestResult(
key_id=api_key.id,
provider=api_key.provider,
owner=owner,
status=TestStatus.FAILED,
message=f"连接错误: {str(e)}",
response_time=time.time() - start_time
)
except Exception as e:
return TestResult(
key_id=api_key.id,
provider=api_key.provider,
owner=owner,
status=TestStatus.FAILED,
message=f"异常: {type(e).__name__}: {str(e)}",
response_time=time.time() - start_time
)
async def test_all_keys_concurrent(config_path: str = "config/keys.yaml", max_concurrency: int = 10):
"""并发测试所有启用的密钥"""
print("=" * 80)
print("API密钥测试工具 (并发模式)")
print("=" * 80)
print()
# 加载密钥管理器
try:
km = KeyManager(config_path)
all_keys = km.get_all_keys()
print(f"共加载 {len(all_keys)} 个密钥配置")
except Exception as e:
print(f"加载配置失败: {e}")
return
# 筛选启用的密钥
enabled_keys = [k for k in all_keys if k.enabled]
print(f"其中 {len(enabled_keys)} 个密钥已启用")
print(f"并发数: {max_concurrency}")
print()
if not enabled_keys:
print("没有启用的密钥需要测试")
return
# 使用信号量控制并发
semaphore = asyncio.Semaphore(max_concurrency)
async def test_with_semaphore(key: APIKey) -> TestResult:
async with semaphore:
return await test_single_key(key)
# 创建所有测试任务
tasks = [test_with_semaphore(key) for key in enabled_keys]
# 并发执行所有测试
print("开始并发测试...")
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理可能的异常结果
processed_results: List[TestResult] = []
for i, result in enumerate(results):
if isinstance(result, Exception):
key = enabled_keys[i]
owner = getattr(key, 'owner', None)
processed_results.append(TestResult(
key_id=key.id,
provider=key.provider,
owner=owner,
status=TestStatus.FAILED,
message=f"测试任务异常: {type(result).__name__}: {str(result)}",
response_time=0.0
))
else:
processed_results.append(result)
results = processed_results
# 按提供商分组显示结果
results_by_provider: Dict[str, List[TestResult]] = {}
for r in results:
p = r.provider
if p not in results_by_provider:
results_by_provider[p] = []
results_by_provider[p].append(r)
# 打印详细结果
for provider, provider_results in sorted(results_by_provider.items()):
print(f"\n{'='*80}")
print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
print(f"{'='*80}")
for r in provider_results:
status_icon = "" if r.status == TestStatus.SUCCESS else "" if r.status == TestStatus.FAILED else ""
owner_str = f" ({r.owner})" if r.owner else ""
print(f"\n [{status_icon}] {r.key_id}{owner_str}")
print(f" 状态: {r.message}")
if r.details and r.status == TestStatus.SUCCESS:
print(f" 回复: {r.details['content_preview']}")
# 打印汇总报告
print("\n" + "=" * 80)
print("测试汇总报告")
print("=" * 80)
success_count = sum(1 for r in results if r.status == TestStatus.SUCCESS)
failed_count = sum(1 for r in results if r.status == TestStatus.FAILED)
skipped_count = sum(1 for r in results if r.status == TestStatus.SKIPPED)
print(f"\n总计: {len(results)} 个密钥")
print(f" ✓ 成功: {success_count}")
print(f" ✗ 失败: {failed_count}")
print(f" ○ 跳过: {skipped_count}")
# 按提供商统计
print("\n按提供商统计:")
provider_stats: Dict[str, Dict[str, int]] = {}
for r in results:
p = r.provider
if p not in provider_stats:
provider_stats[p] = {"success": 0, "failed": 0, "skipped": 0}
if r.status == TestStatus.SUCCESS:
provider_stats[p]["success"] += 1
elif r.status == TestStatus.FAILED:
provider_stats[p]["failed"] += 1
else:
provider_stats[p]["skipped"] += 1
for provider, stats in sorted(provider_stats.items()):
total = stats["success"] + stats["failed"] + stats["skipped"]
success_rate = (stats["success"] / total * 100) if total > 0 else 0
print(f" {provider:15s}: {stats['success']:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {stats['failed']} 失败, {stats['skipped']} 跳过")
# 失败的详细信息
failed_results = [r for r in results if r.status == TestStatus.FAILED]
if failed_results:
print("\n" + "-" * 80)
print("失败的密钥详情:")
print("-" * 80)
for r in failed_results:
print(f"\n [{r.provider}] {r.key_id}")
if r.owner:
print(f" 拥有者: {r.owner}")
print(f" 错误: {r.message}")
# 成功的响应时间统计
success_results = [r for r in results if r.status == TestStatus.SUCCESS and r.response_time > 0]
if success_results:
avg_time = sum(r.response_time for r in success_results) / len(success_results)
min_time = min(r.response_time for r in success_results)
max_time = max(r.response_time for r in success_results)
print("\n" + "-" * 80)
print("响应时间统计 (成功请求):")
print("-" * 80)
print(f" 平均: {avg_time:.2f}s")
print(f" 最快: {min_time:.2f}s")
print(f" 最慢: {max_time:.2f}s")
print("\n" + "=" * 80)
print("测试完成")
print("=" * 80)
return results
if __name__ == "__main__":
import os
# 切换到脚本所在目录
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
# 运行测试默认并发数10
asyncio.run(test_all_keys_concurrent(max_concurrency=11))