init
This commit is contained in:
commit
fb7f06ed03
3
.gitignore
vendored
Executable file
3
.gitignore
vendored
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
__pycache__/
|
||||||
|
test/*
|
||||||
|
keys.yaml
|
||||||
227
app/api/routes.py
Executable file
227
app/api/routes.py
Executable file
@ -0,0 +1,227 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request, BackgroundTasks
|
||||||
|
from app.core.providers import get_provider_handler
|
||||||
|
from app.runtime import key_manager, dispatcher, stats_exporter
|
||||||
|
import httpx
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/v1/stats")
|
||||||
|
async def get_stats():
|
||||||
|
"""
|
||||||
|
Returns usage statistics for all API keys.
|
||||||
|
"""
|
||||||
|
stats = await dispatcher.get_all_key_stats()
|
||||||
|
return {"data": stats}
|
||||||
|
|
||||||
|
@router.get("/v1/usage")
|
||||||
|
async def get_usage():
|
||||||
|
"""
|
||||||
|
Returns overall usage summary.
|
||||||
|
"""
|
||||||
|
summary = await dispatcher.get_usage_summary()
|
||||||
|
return {"data": summary}
|
||||||
|
|
||||||
|
@router.post("/v1/stats/reset/rpd")
|
||||||
|
async def reset_today_rpd():
|
||||||
|
"""
|
||||||
|
Reset today's RPD counters in Redis and refresh dashboard output.
|
||||||
|
"""
|
||||||
|
day = dispatcher._get_day_str()
|
||||||
|
deleted_keys = await dispatcher.reset_today_rpd_usage()
|
||||||
|
stats_exporter.clear_daily_usage_day(day)
|
||||||
|
await stats_exporter.export_once()
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"day": day,
|
||||||
|
"deleted_keys": deleted_keys,
|
||||||
|
"message": f"Reset today's RPD counters for {deleted_keys} keys",
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.post("/v1/keys/config/{config_id}/enable")
|
||||||
|
async def enable_config(config_id: str):
|
||||||
|
"""
|
||||||
|
Enable all keys for a specific config_id (auto-saves to config file).
|
||||||
|
"""
|
||||||
|
success = key_manager.set_config_enabled(config_id, True)
|
||||||
|
if success:
|
||||||
|
return {"success": True, "message": f"Config {config_id} enabled and saved"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Config {config_id} not found")
|
||||||
|
|
||||||
|
@router.post("/v1/keys/config/{config_id}/disable")
|
||||||
|
async def disable_config(config_id: str):
|
||||||
|
"""
|
||||||
|
Disable all keys for a specific config_id (auto-saves to config file).
|
||||||
|
"""
|
||||||
|
success = key_manager.set_config_enabled(config_id, False)
|
||||||
|
if success:
|
||||||
|
return {"success": True, "message": f"Config {config_id} disabled and saved"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Config {config_id} not found")
|
||||||
|
|
||||||
|
@router.post("/v1/keys/{key_id}/enable")
|
||||||
|
async def enable_key(key_id: str):
|
||||||
|
"""
|
||||||
|
Enable a specific key.
|
||||||
|
"""
|
||||||
|
success = key_manager.set_key_enabled(key_id, True)
|
||||||
|
if success:
|
||||||
|
return {"success": True, "message": f"Key {key_id} enabled"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Key {key_id} not found")
|
||||||
|
|
||||||
|
@router.post("/v1/keys/{key_id}/disable")
|
||||||
|
async def disable_key(key_id: str):
|
||||||
|
"""
|
||||||
|
Disable a specific key.
|
||||||
|
"""
|
||||||
|
success = key_manager.set_key_enabled(key_id, False)
|
||||||
|
if success:
|
||||||
|
return {"success": True, "message": f"Key {key_id} disabled"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Key {key_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/enable")
|
||||||
|
async def enable_endpoint(config_id: str, endpoint_idx: int):
|
||||||
|
"""
|
||||||
|
Enable a specific endpoint for a config (auto-saves to config file).
|
||||||
|
"""
|
||||||
|
success = key_manager.set_endpoint_enabled(config_id, endpoint_idx, True)
|
||||||
|
if success:
|
||||||
|
return {"success": True, "message": f"Endpoint {endpoint_idx} of {config_id} enabled and saved"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_idx} of {config_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/disable")
|
||||||
|
async def disable_endpoint(config_id: str, endpoint_idx: int):
|
||||||
|
"""
|
||||||
|
Disable a specific endpoint for a config (auto-saves to config file).
|
||||||
|
"""
|
||||||
|
success = key_manager.set_endpoint_enabled(config_id, endpoint_idx, False)
|
||||||
|
if success:
|
||||||
|
return {"success": True, "message": f"Endpoint {endpoint_idx} of {config_id} disabled and saved"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_idx} of {config_id} not found")
|
||||||
|
|
||||||
|
@router.get("/v1/configs")
|
||||||
|
async def get_configs():
|
||||||
|
"""
|
||||||
|
Returns all config IDs.
|
||||||
|
"""
|
||||||
|
configs = []
|
||||||
|
for config_id in key_manager.get_config_ids():
|
||||||
|
keys = key_manager.get_keys_by_config(config_id)
|
||||||
|
if keys:
|
||||||
|
configs.append({
|
||||||
|
"config_id": config_id,
|
||||||
|
"model_name": keys[0].model_name,
|
||||||
|
"provider": keys[0].provider,
|
||||||
|
"key_count": len(keys),
|
||||||
|
"enabled": keys[0].enabled
|
||||||
|
})
|
||||||
|
return {"data": configs}
|
||||||
|
|
||||||
|
@router.post("/v1/chat/completions")
|
||||||
|
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
Unified entry point.
|
||||||
|
Expects JSON body with "model" field.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except:
|
||||||
|
logger.error(f"[Invalid JSON] detail={traceback.format_exc()}...")
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||||
|
|
||||||
|
model_name = body.get("model")
|
||||||
|
|
||||||
|
key, error_reason = await dispatcher.select_key(model_name)
|
||||||
|
if not key:
|
||||||
|
logger.error(f"[Key Not Found] model={model_name} | error={error_reason}...")
|
||||||
|
raise HTTPException(status_code=503, detail=error_reason)
|
||||||
|
|
||||||
|
if not model_name:
|
||||||
|
model_name = key.model_name
|
||||||
|
|
||||||
|
acquired = await dispatcher.acquire_lease(key)
|
||||||
|
if not acquired:
|
||||||
|
logger.error(f"[System busy (Concurrency Limit)] model={model_name} | provider={key.provider} | key={key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}...")
|
||||||
|
raise HTTPException(status_code=503, detail="System busy (Concurrency Limit)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
handler = get_provider_handler(key)
|
||||||
|
url = handler.get_url()
|
||||||
|
headers = handler.get_headers()
|
||||||
|
|
||||||
|
params = {k: v for k, v in body.items() if k not in ["model", "messages", "stream"]}
|
||||||
|
|
||||||
|
payload = handler.get_payload(body.get("messages", []), params)
|
||||||
|
|
||||||
|
timeout = httpx.Timeout(300.0, connect=10.0)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
|
resp = await client.post(url, headers=headers, json=payload)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
error_detail = resp.text[:500] if len(resp.text) > 500 else resp.text
|
||||||
|
logger.error(
|
||||||
|
f"[Request Failed] model={model_name} | key={key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'} | "
|
||||||
|
f"api_base={key.api_base} | "
|
||||||
|
f"status={resp.status_code} | response={error_detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 429:
|
||||||
|
error_message = ""
|
||||||
|
try:
|
||||||
|
error_data = resp.json()
|
||||||
|
if isinstance(error_data, dict):
|
||||||
|
error_message = error_data.get("error", {}).get("message", "") or error_data.get("message", "")
|
||||||
|
except:
|
||||||
|
error_message = resp.text
|
||||||
|
await dispatcher.report_failure(key, is_rate_limit=True, error_message=error_message)
|
||||||
|
|
||||||
|
raise HTTPException(status_code=resp.status_code, detail=f"Provider Error: {error_detail}")
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
parsed_result = handler.parse_response(data)
|
||||||
|
|
||||||
|
total_tokens = parsed_result["usage"].get("total_tokens", 0)
|
||||||
|
background_tasks.add_task(dispatcher.record_usage, key, total_tokens)
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"id": f"chatcmpl-{int(time.time())}",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": parsed_result["content"],
|
||||||
|
"reasoning_content": parsed_result.get("reasoning_content")
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": parsed_result["usage"]
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
_key_info = f"{key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}" if key else 'N/A'
|
||||||
|
logger.exception(
|
||||||
|
f"[Internal Error] model={model_name} | "
|
||||||
|
f"key={_key_info} | api_base={getattr(key, 'api_base', 'N/A') if key else 'N/A'} | error={str(e)}"
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
finally:
|
||||||
|
await dispatcher.release_lease(key)
|
||||||
556
app/core/dispatcher.py
Executable file
556
app/core/dispatcher.py
Executable file
@ -0,0 +1,556 @@
|
|||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import fnmatch
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from app.models import APIKey, KeyStatus
|
||||||
|
from app.core.key_manager import KeyManager
|
||||||
|
from app.utils.redis_client import RedisClient
|
||||||
|
|
||||||
|
class Dispatcher:
|
||||||
|
DAILY_RPD_TTL_SECONDS = 90 * 86400
|
||||||
|
RATE_LIMIT_PENALTY_SEQ_KEY = "keypool:priority:429:seq"
|
||||||
|
PROVIDER_PRIORITIES = {
|
||||||
|
"modelscope": 0,
|
||||||
|
"nvidia": 0,
|
||||||
|
"siliconflow": 0,
|
||||||
|
"llama-local": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, key_manager: KeyManager):
|
||||||
|
self.key_manager = key_manager
|
||||||
|
self.redis = RedisClient.get_instance()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_day_str(offset_days: int = 0) -> str:
|
||||||
|
current_day = datetime.now(timezone.utc) + timedelta(days=offset_days)
|
||||||
|
return current_day.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_rate_limit_penalty_key(key_id: str) -> str:
|
||||||
|
return f"keypool:priority:429:{key_id}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_provider_priority(cls, provider: str) -> int:
|
||||||
|
return cls.PROVIDER_PRIORITIES.get(provider.lower(), len(cls.PROVIDER_PRIORITIES))
|
||||||
|
|
||||||
|
async def _delete_matching_keys(self, pattern: str, batch_size: int = 200) -> int:
|
||||||
|
if hasattr(self.redis, "scan_iter"):
|
||||||
|
batch = []
|
||||||
|
deleted = 0
|
||||||
|
async for key in self.redis.scan_iter(match=pattern, count=batch_size):
|
||||||
|
batch.append(key)
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
deleted += await self.redis.delete(*batch)
|
||||||
|
batch.clear()
|
||||||
|
if batch:
|
||||||
|
deleted += await self.redis.delete(*batch)
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
if not hasattr(self.redis, "values"):
|
||||||
|
raise RuntimeError("Redis client does not support scan_iter")
|
||||||
|
|
||||||
|
keys = [key for key in self.redis.values if fnmatch.fnmatch(key, pattern)]
|
||||||
|
return await self.redis.delete(*keys) if keys else 0
|
||||||
|
|
||||||
|
async def reset_today_rpd_usage(self) -> int:
|
||||||
|
day = self._get_day_str()
|
||||||
|
return await self._delete_matching_keys(f"keypool:usage:rpd:*:{day}")
|
||||||
|
|
||||||
|
async def get_key_stats(self, key: APIKey):
|
||||||
|
"""Fetches runtime stats from Redis for a single key."""
|
||||||
|
concurrency_key = f"keypool:concurrency:{key.id}"
|
||||||
|
penalty_key = self._get_rate_limit_penalty_key(key.id)
|
||||||
|
|
||||||
|
current_day = self._get_day_str()
|
||||||
|
rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
|
||||||
|
|
||||||
|
# Pipeline for efficiency
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.get(concurrency_key)
|
||||||
|
pipe.get(penalty_key)
|
||||||
|
pipe.get(rpd_key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
current_concurrency = int(results[0]) if results[0] else 0
|
||||||
|
rate_limit_penalty = int(results[1]) if results[1] else 0
|
||||||
|
current_rpd = int(results[2]) if results[2] else 0
|
||||||
|
|
||||||
|
key.current_concurrency = current_concurrency
|
||||||
|
key.rate_limit_penalty = rate_limit_penalty
|
||||||
|
key.current_rpd = current_rpd
|
||||||
|
return key
|
||||||
|
|
||||||
|
async def select_key(self, model_name: Optional[str] = None, timeout: float = 30.0) -> tuple[Optional[APIKey], str]:
|
||||||
|
"""
|
||||||
|
Selects the best available key for the given model.
|
||||||
|
If model_name is None, selects from ALL available keys.
|
||||||
|
|
||||||
|
If no keys are available immediately, it will wait (queue) up to `timeout` seconds.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
1. Try to find an available key immediately.
|
||||||
|
2. If none found, wait for a short interval and retry (polling/queue simulation).
|
||||||
|
In a production system, we might use a real Redis list/pubsub for queuing,
|
||||||
|
but for now, we'll use a local async loop with exponential backoff or simple polling.
|
||||||
|
|
||||||
|
Returns: (key, error_reason)
|
||||||
|
|
||||||
|
Note: If model_name is explicitly specified (force_select=True), the status check
|
||||||
|
for DISABLED keys is bypassed, allowing access to disabled models.
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
# If model_name is explicitly specified, force select (bypass enabled check)
|
||||||
|
force_select = bool(model_name)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# 1. Try to select a key
|
||||||
|
key, error = await self._try_select_key_once(model_name, force_select=force_select)
|
||||||
|
if key:
|
||||||
|
return key, ""
|
||||||
|
|
||||||
|
# 2. Check timeout
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
return None, f"Timeout waiting for available key: {error}"
|
||||||
|
|
||||||
|
# 3. Wait before retrying (simple polling).
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
async def _try_select_key_once(self, model_name: Optional[str] = None, force_select: bool = False) -> tuple[Optional[APIKey], str]:
|
||||||
|
"""
|
||||||
|
Internal method to try selecting a key once without waiting.
|
||||||
|
Strategy:
|
||||||
|
1. Filter by status and hard limits.
|
||||||
|
2. Prefer keys that never hit 429.
|
||||||
|
3. Prefer providers by static rank: modelscope > nvidia > siliconflow.
|
||||||
|
4. For keys that did hit 429, use the demotion order as a queue.
|
||||||
|
5. Use load as the final tie breaker inside the same priority tier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Optional model name to filter by
|
||||||
|
force_select: If True, bypass the status check (allows accessing disabled models)
|
||||||
|
|
||||||
|
Returns: (key, error_reason)
|
||||||
|
"""
|
||||||
|
candidates = self.key_manager.get_candidate_keys(model_name)
|
||||||
|
if not candidates:
|
||||||
|
return None, "No keys configured for this model" if model_name else "No keys configured in the pool"
|
||||||
|
|
||||||
|
# Prepare keys for mget
|
||||||
|
concurrency_keys = [f"keypool:concurrency:{k.id}" for k in candidates]
|
||||||
|
penalty_keys = [self._get_rate_limit_penalty_key(k.id) for k in candidates]
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
current_minute = int(now / 60)
|
||||||
|
prev_minute = current_minute - 1
|
||||||
|
|
||||||
|
rpm_curr_keys = [f"keypool:usage:rpm:{k.id}:{current_minute}" for k in candidates]
|
||||||
|
rpm_prev_keys = [f"keypool:usage:rpm:{k.id}:{prev_minute}" for k in candidates]
|
||||||
|
|
||||||
|
tpm_curr_keys = [f"keypool:usage:tpm:{k.id}:{current_minute}" for k in candidates]
|
||||||
|
tpm_prev_keys = [f"keypool:usage:tpm:{k.id}:{prev_minute}" for k in candidates]
|
||||||
|
|
||||||
|
current_day = self._get_day_str()
|
||||||
|
rpd_keys = [f"keypool:usage:rpd:{k.id}:{current_day}" for k in candidates]
|
||||||
|
|
||||||
|
all_keys = concurrency_keys + penalty_keys + rpm_curr_keys + rpm_prev_keys + tpm_curr_keys + tpm_prev_keys + rpd_keys
|
||||||
|
if not all_keys:
|
||||||
|
return None, "No keys available (empty stats)"
|
||||||
|
|
||||||
|
values = await self.redis.mget(all_keys)
|
||||||
|
|
||||||
|
n = len(candidates)
|
||||||
|
concurrency_values = values[:n]
|
||||||
|
penalty_values = values[n:2*n]
|
||||||
|
rpm_curr_values = values[2*n:3*n]
|
||||||
|
rpm_prev_values = values[3*n:4*n]
|
||||||
|
tpm_curr_values = values[4*n:5*n]
|
||||||
|
tpm_prev_values = values[5*n:6*n]
|
||||||
|
rpd_values = values[6*n:7*n]
|
||||||
|
|
||||||
|
# Calculate window weight
|
||||||
|
# If we are at second 0 of the minute, weight of prev is 1.0 (actually 0.99...)
|
||||||
|
# If we are at second 59, weight of prev is ~0
|
||||||
|
# Formula: weight = (60 - seconds_elapsed) / 60
|
||||||
|
seconds_elapsed = now % 60
|
||||||
|
prev_weight = (60 - seconds_elapsed) / 60
|
||||||
|
|
||||||
|
available_keys = []
|
||||||
|
rejection_reasons = []
|
||||||
|
|
||||||
|
for i, key in enumerate(candidates):
|
||||||
|
current_conc = int(concurrency_values[i]) if concurrency_values[i] else 0
|
||||||
|
rate_limit_penalty = int(penalty_values[i]) if penalty_values[i] else 0
|
||||||
|
|
||||||
|
r_curr = int(rpm_curr_values[i]) if rpm_curr_values[i] else 0
|
||||||
|
r_prev = int(rpm_prev_values[i]) if rpm_prev_values[i] else 0
|
||||||
|
|
||||||
|
t_curr = int(tpm_curr_values[i]) if tpm_curr_values[i] else 0
|
||||||
|
t_prev = int(tpm_prev_values[i]) if tpm_prev_values[i] else 0
|
||||||
|
|
||||||
|
current_rpd = int(rpd_values[i]) if rpd_values[i] else 0
|
||||||
|
|
||||||
|
current_rpm = int(r_curr + r_prev * prev_weight)
|
||||||
|
current_tpm = int(t_curr + t_prev * prev_weight)
|
||||||
|
|
||||||
|
key.current_concurrency = current_conc
|
||||||
|
key.rate_limit_penalty = rate_limit_penalty
|
||||||
|
key.current_rpm = current_rpm
|
||||||
|
key.current_tpm = current_tpm
|
||||||
|
key.current_rpd = current_rpd
|
||||||
|
|
||||||
|
if key.status != KeyStatus.ACTIVE:
|
||||||
|
# If force_select is True (explicit model specified), bypass status check
|
||||||
|
if not force_select:
|
||||||
|
rejection_reasons.append(f"{key.id}({key.owner}): Inactive (status={key.status})")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key.current_concurrency >= key.max_concurrency:
|
||||||
|
rejection_reasons.append(f"{key.id}({key.owner}): Concurrency limit ({key.current_concurrency}/{key.max_concurrency})")
|
||||||
|
continue
|
||||||
|
if current_rpm >= key.rpm_limit:
|
||||||
|
rejection_reasons.append(f"{key.id}({key.owner}): RPM limit ({current_rpm}/{key.rpm_limit})")
|
||||||
|
continue
|
||||||
|
if current_tpm >= key.tpm_limit:
|
||||||
|
rejection_reasons.append(f"{key.id}({key.owner}): TPM limit ({current_tpm}/{key.tpm_limit})")
|
||||||
|
continue
|
||||||
|
if key.rpd_limit > 0 and current_rpd >= key.rpd_limit:
|
||||||
|
rejection_reasons.append(f"{key.id}({key.owner}): RPD limit ({current_rpd}/{key.rpd_limit})")
|
||||||
|
continue
|
||||||
|
|
||||||
|
available_keys.append(key)
|
||||||
|
|
||||||
|
if not available_keys:
|
||||||
|
reason = "; ".join(rejection_reasons) if rejection_reasons else "No available keys found"
|
||||||
|
return None, f"All keys unavailable: {reason}"
|
||||||
|
|
||||||
|
def calculate_load_score(k: APIKey) -> float:
|
||||||
|
conc_ratio = k.current_concurrency / k.max_concurrency if k.max_concurrency > 0 else 0
|
||||||
|
rpm_ratio = k.current_rpm / k.rpm_limit if k.rpm_limit > 0 else 0
|
||||||
|
tpm_ratio = k.current_tpm / k.tpm_limit if k.tpm_limit > 0 else 0
|
||||||
|
rpd_ratio = k.current_rpd / k.rpd_limit if k.rpd_limit > 0 else 0
|
||||||
|
return max(conc_ratio, rpm_ratio, tpm_ratio, rpd_ratio)
|
||||||
|
|
||||||
|
def selection_key(k: APIKey) -> tuple[int, int, int, float]:
|
||||||
|
provider_priority = self._get_provider_priority(k.provider)
|
||||||
|
if k.rate_limit_penalty > 0:
|
||||||
|
return (1, k.rate_limit_penalty, provider_priority, calculate_load_score(k))
|
||||||
|
return (0, provider_priority, 0, calculate_load_score(k))
|
||||||
|
|
||||||
|
best_key = min(available_keys, key=selection_key)
|
||||||
|
|
||||||
|
return best_key, ""
|
||||||
|
|
||||||
|
async def acquire_lease(self, key: APIKey) -> bool:
|
||||||
|
"""
|
||||||
|
Increments concurrency, RPM and RPD for the key.
|
||||||
|
Double check limit in Redis to be safe (race conditions).
|
||||||
|
"""
|
||||||
|
key_concurrency_key = f"keypool:concurrency:{key.id}"
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
current_minute = int(now / 60)
|
||||||
|
prev_minute = current_minute - 1
|
||||||
|
current_day = self._get_day_str()
|
||||||
|
|
||||||
|
key_rpm_key = f"keypool:usage:rpm:{key.id}:{current_minute}"
|
||||||
|
key_rpm_prev_key = f"keypool:usage:rpm:{key.id}:{prev_minute}"
|
||||||
|
key_rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.incr(key_concurrency_key)
|
||||||
|
pipe.incr(key_rpm_key)
|
||||||
|
pipe.expire(key_rpm_key, 65)
|
||||||
|
pipe.get(key_rpm_prev_key)
|
||||||
|
pipe.incr(key_rpd_key)
|
||||||
|
# Keep daily counters for two days so the dashboard can read yesterday's RPD.
|
||||||
|
pipe.expire(key_rpd_key, self.DAILY_RPD_TTL_SECONDS)
|
||||||
|
|
||||||
|
results = await pipe.execute()
|
||||||
|
curr_conc = results[0]
|
||||||
|
curr_rpm_val = results[1]
|
||||||
|
prev_rpm_val = int(results[3]) if results[3] else 0
|
||||||
|
curr_rpd_val = results[4]
|
||||||
|
|
||||||
|
seconds_elapsed = now % 60
|
||||||
|
prev_weight = (60 - seconds_elapsed) / 60
|
||||||
|
estimated_rpm = curr_rpm_val + (prev_rpm_val * prev_weight)
|
||||||
|
|
||||||
|
if curr_conc > key.max_concurrency:
|
||||||
|
await self.redis.decr(key_concurrency_key)
|
||||||
|
await self.redis.decr(key_rpm_key)
|
||||||
|
await self.redis.decr(key_rpd_key)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if estimated_rpm > key.rpm_limit:
|
||||||
|
await self.redis.decr(key_concurrency_key)
|
||||||
|
await self.redis.decr(key_rpm_key)
|
||||||
|
await self.redis.decr(key_rpd_key)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if key.rpd_limit > 0 and curr_rpd_val > key.rpd_limit:
|
||||||
|
await self.redis.decr(key_concurrency_key)
|
||||||
|
await self.redis.decr(key_rpm_key)
|
||||||
|
await self.redis.decr(key_rpd_key)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def release_lease(self, key: APIKey):
|
||||||
|
"""Decrements concurrency."""
|
||||||
|
key_concurrency_key = f"keypool:concurrency:{key.id}"
|
||||||
|
await self.redis.decr(key_concurrency_key)
|
||||||
|
|
||||||
|
async def record_usage(self, key: APIKey, tokens: int):
|
||||||
|
"""Records Token usage (TPM)."""
|
||||||
|
if tokens <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_minute = int(time.time() / 60)
|
||||||
|
key_tpm_key = f"keypool:usage:tpm:{key.id}:{current_minute}"
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.incrby(key_tpm_key, tokens)
|
||||||
|
pipe.expire(key_tpm_key, 65)
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
|
||||||
|
async def report_failure(self, key: APIKey, is_rate_limit: bool = False, error_message: str = ""):
|
||||||
|
"""
|
||||||
|
Reports a failure. A 429 moves the key to the back of the priority queue.
|
||||||
|
"""
|
||||||
|
del error_message
|
||||||
|
if is_rate_limit:
|
||||||
|
seq = await self.redis.incr(self.RATE_LIMIT_PENALTY_SEQ_KEY)
|
||||||
|
await self.redis.set(self._get_rate_limit_penalty_key(key.id), seq)
|
||||||
|
|
||||||
|
async def get_all_key_stats(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieves statistics for all keys, including status, concurrency, and usage limits.
|
||||||
|
"""
|
||||||
|
keys = self.key_manager.get_all_keys()
|
||||||
|
if not keys:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Prepare keys for batch fetching from Redis
|
||||||
|
now = time.time()
|
||||||
|
current_minute = int(now / 60)
|
||||||
|
prev_minute = current_minute - 1
|
||||||
|
current_day = self._get_day_str()
|
||||||
|
|
||||||
|
redis_keys = []
|
||||||
|
for key in keys:
|
||||||
|
redis_keys.extend([
|
||||||
|
f"keypool:concurrency:{key.id}",
|
||||||
|
self._get_rate_limit_penalty_key(key.id),
|
||||||
|
f"keypool:usage:rpm:{key.id}:{current_minute}",
|
||||||
|
f"keypool:usage:rpm:{key.id}:{prev_minute}",
|
||||||
|
f"keypool:usage:tpm:{key.id}:{current_minute}",
|
||||||
|
f"keypool:usage:tpm:{key.id}:{prev_minute}",
|
||||||
|
f"keypool:usage:rpd:{key.id}:{current_day}"
|
||||||
|
])
|
||||||
|
|
||||||
|
if not redis_keys:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch all values in a single round-trip
|
||||||
|
values = await self.redis.mget(redis_keys)
|
||||||
|
|
||||||
|
stats = []
|
||||||
|
|
||||||
|
# Process results in chunks of 7
|
||||||
|
seconds_elapsed = now % 60
|
||||||
|
prev_weight = (60 - seconds_elapsed) / 60
|
||||||
|
|
||||||
|
for i, key in enumerate(keys):
|
||||||
|
base_idx = i * 7
|
||||||
|
|
||||||
|
curr_conc = int(values[base_idx]) if values[base_idx] else 0
|
||||||
|
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
|
||||||
|
|
||||||
|
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
|
||||||
|
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
|
||||||
|
|
||||||
|
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
|
||||||
|
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
|
||||||
|
|
||||||
|
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
|
||||||
|
|
||||||
|
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
|
||||||
|
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
|
||||||
|
|
||||||
|
# Determine effective status
|
||||||
|
status = key.status
|
||||||
|
if key.status == KeyStatus.ACTIVE and rate_limit_penalty > 0:
|
||||||
|
status = KeyStatus.COOLDOWN
|
||||||
|
|
||||||
|
key_stat = {
|
||||||
|
"id": key.id,
|
||||||
|
"model_name": key.model_name,
|
||||||
|
"provider": key.provider,
|
||||||
|
"status": status,
|
||||||
|
"enabled": key.enabled,
|
||||||
|
"config_id": getattr(key, 'config_id', None),
|
||||||
|
"owner": getattr(key, 'owner', None),
|
||||||
|
"endpoint_idx": getattr(key, 'endpoint_idx', None),
|
||||||
|
"api_base": getattr(key, 'api_base', None),
|
||||||
|
"limits": {
|
||||||
|
"rpm": key.rpm_limit,
|
||||||
|
"tpm": key.tpm_limit,
|
||||||
|
"rpd": key.rpd_limit,
|
||||||
|
"max_concurrency": key.max_concurrency
|
||||||
|
},
|
||||||
|
"usage": {
|
||||||
|
"current_concurrency": curr_conc,
|
||||||
|
"current_rpm": estimated_rpm,
|
||||||
|
"current_tpm": estimated_tpm,
|
||||||
|
"current_rpd": rpd_curr
|
||||||
|
},
|
||||||
|
"cooldown_remaining": 0
|
||||||
|
}
|
||||||
|
stats.append(key_stat)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
async def get_usage_summary(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get overall usage summary across all keys.
|
||||||
|
Returns total RPM, TPM, RPD, concurrency, and active keys count.
|
||||||
|
"""
|
||||||
|
keys = self.key_manager.get_all_keys()
|
||||||
|
if not keys:
|
||||||
|
return {
|
||||||
|
"total_rpm": 0,
|
||||||
|
"total_tpm": 0,
|
||||||
|
"total_rpd": 0,
|
||||||
|
"total_yesterday_rpd": 0,
|
||||||
|
"total_concurrency": 0,
|
||||||
|
"active_keys": 0,
|
||||||
|
"cooldown_keys": 0,
|
||||||
|
"disabled_keys": 0,
|
||||||
|
"total_keys": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
current_minute = int(now / 60)
|
||||||
|
prev_minute = current_minute - 1
|
||||||
|
current_day = self._get_day_str()
|
||||||
|
yesterday_day = self._get_day_str(-1)
|
||||||
|
|
||||||
|
redis_keys = []
|
||||||
|
for key in keys:
|
||||||
|
redis_keys.extend([
|
||||||
|
f"keypool:concurrency:{key.id}",
|
||||||
|
self._get_rate_limit_penalty_key(key.id),
|
||||||
|
f"keypool:usage:rpm:{key.id}:{current_minute}",
|
||||||
|
f"keypool:usage:rpm:{key.id}:{prev_minute}",
|
||||||
|
f"keypool:usage:tpm:{key.id}:{current_minute}",
|
||||||
|
f"keypool:usage:tpm:{key.id}:{prev_minute}",
|
||||||
|
f"keypool:usage:rpd:{key.id}:{current_day}",
|
||||||
|
f"keypool:usage:rpd:{key.id}:{yesterday_day}"
|
||||||
|
])
|
||||||
|
|
||||||
|
if not redis_keys:
|
||||||
|
return {
|
||||||
|
"total_rpm": 0,
|
||||||
|
"total_tpm": 0,
|
||||||
|
"total_rpd": 0,
|
||||||
|
"total_yesterday_rpd": 0,
|
||||||
|
"total_concurrency": 0,
|
||||||
|
"active_keys": 0,
|
||||||
|
"cooldown_keys": 0,
|
||||||
|
"disabled_keys": 0,
|
||||||
|
"total_keys": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
values = await self.redis.mget(redis_keys)
|
||||||
|
|
||||||
|
seconds_elapsed = now % 60
|
||||||
|
prev_weight = (60 - seconds_elapsed) / 60
|
||||||
|
|
||||||
|
total_rpm = 0
|
||||||
|
total_tpm = 0
|
||||||
|
total_rpd = 0
|
||||||
|
total_yesterday_rpd = 0
|
||||||
|
total_concurrency = 0
|
||||||
|
active_keys = 0
|
||||||
|
cooldown_keys = 0
|
||||||
|
disabled_keys = 0
|
||||||
|
|
||||||
|
for i, key in enumerate(keys):
|
||||||
|
base_idx = i * 8
|
||||||
|
|
||||||
|
curr_conc = int(values[base_idx]) if values[base_idx] else 0
|
||||||
|
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
|
||||||
|
|
||||||
|
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
|
||||||
|
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
|
||||||
|
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
|
||||||
|
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
|
||||||
|
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
|
||||||
|
rpd_yesterday = int(values[base_idx + 7]) if values[base_idx + 7] else 0
|
||||||
|
|
||||||
|
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
|
||||||
|
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
|
||||||
|
|
||||||
|
total_rpm += estimated_rpm
|
||||||
|
total_tpm += estimated_tpm
|
||||||
|
total_rpd += rpd_curr
|
||||||
|
total_yesterday_rpd += rpd_yesterday
|
||||||
|
total_concurrency += curr_conc
|
||||||
|
|
||||||
|
if key.status == KeyStatus.ACTIVE and rate_limit_penalty == 0:
|
||||||
|
active_keys += 1
|
||||||
|
elif key.status == KeyStatus.DISABLED:
|
||||||
|
disabled_keys += 1
|
||||||
|
elif rate_limit_penalty > 0:
|
||||||
|
cooldown_keys += 1
|
||||||
|
else:
|
||||||
|
active_keys += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_rpm": total_rpm,
|
||||||
|
"total_tpm": total_tpm,
|
||||||
|
"total_rpd": total_rpd,
|
||||||
|
"total_yesterday_rpd": total_yesterday_rpd,
|
||||||
|
"total_concurrency": total_concurrency,
|
||||||
|
"active_keys": active_keys,
|
||||||
|
"cooldown_keys": cooldown_keys,
|
||||||
|
"disabled_keys": disabled_keys,
|
||||||
|
"total_keys": len(keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_daily_usage_history(self, days: int = 7) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Aggregate total RPD usage for the most recent N days.
|
||||||
|
Uses the same UTC day buckets as runtime RPD limiting.
|
||||||
|
"""
|
||||||
|
keys = self.key_manager.get_all_keys()
|
||||||
|
if not keys or days <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
day_labels = [self._get_day_str(-offset) for offset in range(days - 1, -1, -1)]
|
||||||
|
|
||||||
|
redis_keys = []
|
||||||
|
for day_str in day_labels:
|
||||||
|
for key in keys:
|
||||||
|
redis_keys.append(f"keypool:usage:rpd:{key.id}:{day_str}")
|
||||||
|
|
||||||
|
values = await self.redis.mget(redis_keys)
|
||||||
|
|
||||||
|
history = []
|
||||||
|
key_count = len(keys)
|
||||||
|
for day_index, day_str in enumerate(day_labels):
|
||||||
|
base_idx = day_index * key_count
|
||||||
|
day_total = 0
|
||||||
|
for key_offset in range(key_count):
|
||||||
|
value = values[base_idx + key_offset]
|
||||||
|
day_total += int(value) if value else 0
|
||||||
|
|
||||||
|
history.append({
|
||||||
|
"day": day_str,
|
||||||
|
"label": f"{day_str[4:6]}-{day_str[6:8]}",
|
||||||
|
"total_rpd": day_total,
|
||||||
|
})
|
||||||
|
|
||||||
|
return history
|
||||||
294
app/core/key_manager.py
Executable file
294
app/core/key_manager.py
Executable file
@ -0,0 +1,294 @@
|
|||||||
|
import yaml
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Optional, Set, Any, Tuple
|
||||||
|
from app.models import APIKey, KeyStatus
|
||||||
|
|
||||||
|
class KeyManager:
|
||||||
|
def __init__(self, config_path: str):
|
||||||
|
self.config_path = config_path
|
||||||
|
self._keys: Dict[str, APIKey] = {}
|
||||||
|
self._keys_by_model: Dict[str, List[APIKey]] = {}
|
||||||
|
self._config_ids: Set[str] = set()
|
||||||
|
self._last_load_time: float = 0
|
||||||
|
self._raw_config: Dict[str, Any] = {}
|
||||||
|
self._config_enabled_map: Dict[str, bool] = {}
|
||||||
|
# Track endpoint enable status: {(config_id, idx): enabled}
|
||||||
|
self._endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
|
||||||
|
# Track endpoint URLs: {(config_id, idx): url}
|
||||||
|
self._endpoint_urls: Dict[Tuple[str, int], str] = {}
|
||||||
|
self.load_keys()
|
||||||
|
|
||||||
|
def load_keys(self):
|
||||||
|
"""Loads keys from the YAML configuration file."""
|
||||||
|
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
self._raw_config = data
|
||||||
|
self._keys = {}
|
||||||
|
self._keys_by_model = {}
|
||||||
|
self._config_ids = set()
|
||||||
|
self._config_enabled_map = {}
|
||||||
|
|
||||||
|
configs = {}
|
||||||
|
for cfg in data.get('model_configs', []):
|
||||||
|
if 'id' in cfg:
|
||||||
|
configs[cfg['id']] = cfg
|
||||||
|
|
||||||
|
for key_group in data.get('keys', []):
|
||||||
|
config_id = key_group.get('config_id')
|
||||||
|
group_enabled = key_group.get('enabled', True)
|
||||||
|
keys_list = key_group.get('keys', [])
|
||||||
|
|
||||||
|
if config_id:
|
||||||
|
self._config_ids.add(config_id)
|
||||||
|
self._config_enabled_map[config_id] = group_enabled
|
||||||
|
|
||||||
|
if config_id and config_id in configs:
|
||||||
|
cfg = configs[config_id]
|
||||||
|
cfg_copy = cfg.copy()
|
||||||
|
cfg_copy.pop('id', None)
|
||||||
|
|
||||||
|
# Check if this config has multiple endpoints (for local llama.cpp load balancing)
|
||||||
|
api_endpoints = cfg_copy.get('api_endpoints', [])
|
||||||
|
|
||||||
|
if api_endpoints and len(api_endpoints) > 0:
|
||||||
|
# Generate virtual keys for each endpoint
|
||||||
|
for idx, endpoint_item in enumerate(api_endpoints):
|
||||||
|
# Handle both old format (string URL) and new format (dict with url/enabled)
|
||||||
|
if isinstance(endpoint_item, dict):
|
||||||
|
endpoint_url = endpoint_item.get('url', '')
|
||||||
|
endpoint_enabled = endpoint_item.get('enabled', True)
|
||||||
|
else:
|
||||||
|
endpoint_url = endpoint_item
|
||||||
|
endpoint_enabled = True
|
||||||
|
|
||||||
|
# Store endpoint info for dashboard
|
||||||
|
self._endpoint_enabled_map[(config_id, idx)] = endpoint_enabled
|
||||||
|
self._endpoint_urls[(config_id, idx)] = endpoint_url
|
||||||
|
|
||||||
|
final_data = {}
|
||||||
|
final_data.update(cfg_copy)
|
||||||
|
final_data.pop('api_endpoints', None) # Remove the list from individual key
|
||||||
|
final_data['config_id'] = config_id
|
||||||
|
# Endpoint is enabled only if both group and endpoint itself are enabled
|
||||||
|
final_data['enabled'] = group_enabled and endpoint_enabled
|
||||||
|
final_data['api_base'] = endpoint_url # Use individual endpoint
|
||||||
|
final_data['key'] = f"local-{config_id}-{idx}" # Virtual key for local service
|
||||||
|
final_data['endpoint_idx'] = idx # Track endpoint index
|
||||||
|
|
||||||
|
# Virtual key ID that includes endpoint info for tracking
|
||||||
|
endpoint_hash = hash(endpoint_url) & 0xFFFF
|
||||||
|
key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}"
|
||||||
|
final_data['id'] = key_id
|
||||||
|
|
||||||
|
if not group_enabled or not endpoint_enabled:
|
||||||
|
final_data['status'] = KeyStatus.DISABLED
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = APIKey(**final_data)
|
||||||
|
self._keys[key.id] = key
|
||||||
|
|
||||||
|
if key.model_name not in self._keys_by_model:
|
||||||
|
self._keys_by_model[key.model_name] = []
|
||||||
|
self._keys_by_model[key.model_name].append(key)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading virtual key {key_id} for endpoint {endpoint_url}: {e}")
|
||||||
|
else:
|
||||||
|
# Standard key loading (original logic)
|
||||||
|
for key_data in keys_list:
|
||||||
|
final_data = {}
|
||||||
|
final_data.update(cfg_copy)
|
||||||
|
final_data['config_id'] = config_id
|
||||||
|
final_data['enabled'] = group_enabled
|
||||||
|
|
||||||
|
if key_data.get('owner'):
|
||||||
|
final_data['owner'] = key_data['owner']
|
||||||
|
|
||||||
|
final_data['key'] = key_data['key']
|
||||||
|
|
||||||
|
if not group_enabled:
|
||||||
|
final_data['status'] = KeyStatus.DISABLED
|
||||||
|
|
||||||
|
key_id = f"{config_id}:{hash(key_data['key']) & 0xFFFFFFFF}"
|
||||||
|
final_data['id'] = key_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = APIKey(**final_data)
|
||||||
|
self._keys[key.id] = key
|
||||||
|
|
||||||
|
if key.model_name not in self._keys_by_model:
|
||||||
|
self._keys_by_model[key.model_name] = []
|
||||||
|
self._keys_by_model[key.model_name].append(key)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading key {key_id}: {e}")
|
||||||
|
else:
|
||||||
|
for key_data in keys_list:
|
||||||
|
final_data = {}
|
||||||
|
final_data.update(key_group)
|
||||||
|
final_data.pop('keys', None)
|
||||||
|
final_data.update(key_data)
|
||||||
|
|
||||||
|
if 'id' not in final_data:
|
||||||
|
final_data['id'] = key_data.get('key', 'unknown')[:16]
|
||||||
|
|
||||||
|
if not final_data.get('enabled', True):
|
||||||
|
final_data['status'] = KeyStatus.DISABLED
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = APIKey(**final_data)
|
||||||
|
self._keys[key.id] = key
|
||||||
|
|
||||||
|
if key.model_name not in self._keys_by_model:
|
||||||
|
self._keys_by_model[key.model_name] = []
|
||||||
|
self._keys_by_model[key.model_name].append(key)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading key {final_data.get('id', 'unknown')}: {e}")
|
||||||
|
|
||||||
|
self._last_load_time = time.time()
|
||||||
|
|
||||||
|
def _save_config(self) -> bool:
|
||||||
|
"""Save current configuration to YAML file."""
|
||||||
|
try:
|
||||||
|
for key_group in self._raw_config.get('keys', []):
|
||||||
|
config_id = key_group.get('config_id')
|
||||||
|
if config_id and config_id in self._config_enabled_map:
|
||||||
|
key_group['enabled'] = self._config_enabled_map[config_id]
|
||||||
|
|
||||||
|
# Save endpoint enable status
|
||||||
|
for (cfg_id, idx), enabled in self._endpoint_enabled_map.items():
|
||||||
|
for cfg in self._raw_config.get('model_configs', []):
|
||||||
|
if cfg.get('id') == cfg_id:
|
||||||
|
endpoints = cfg.get('api_endpoints', [])
|
||||||
|
if idx < len(endpoints):
|
||||||
|
if isinstance(endpoints[idx], dict):
|
||||||
|
endpoints[idx]['enabled'] = enabled
|
||||||
|
else:
|
||||||
|
# Convert string to dict with url and enabled
|
||||||
|
endpoints[idx] = {'url': endpoints[idx], 'enabled': enabled}
|
||||||
|
|
||||||
|
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||||||
|
yaml.dump(self._raw_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving config: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def reload_keys(self) -> bool:
|
||||||
|
"""Reload keys from config file. Returns True if successful."""
|
||||||
|
try:
|
||||||
|
self.load_keys()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reloading keys: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
|
||||||
|
"""Enable or disable all keys for a config_id and save to file."""
|
||||||
|
if config_id not in self._config_ids:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._config_enabled_map[config_id] = enabled
|
||||||
|
|
||||||
|
updated = False
|
||||||
|
for key in self._keys.values():
|
||||||
|
if hasattr(key, 'config_id') and key.config_id == config_id:
|
||||||
|
key.enabled = enabled
|
||||||
|
if enabled:
|
||||||
|
key.status = KeyStatus.ACTIVE
|
||||||
|
else:
|
||||||
|
key.status = KeyStatus.DISABLED
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
self._save_config()
|
||||||
|
|
||||||
|
return updated
|
||||||
|
|
||||||
|
def set_key_enabled(self, key_id: str, enabled: bool) -> bool:
|
||||||
|
"""Enable or disable a specific key."""
|
||||||
|
if key_id not in self._keys:
|
||||||
|
return False
|
||||||
|
|
||||||
|
key = self._keys[key_id]
|
||||||
|
key.enabled = enabled
|
||||||
|
if enabled:
|
||||||
|
key.status = KeyStatus.ACTIVE
|
||||||
|
else:
|
||||||
|
key.status = KeyStatus.DISABLED
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_config_ids(self) -> Set[str]:
|
||||||
|
"""Get all config IDs."""
|
||||||
|
return self._config_ids.copy()
|
||||||
|
|
||||||
|
def get_keys_by_config(self, config_id: str) -> List[APIKey]:
|
||||||
|
"""Get all keys for a specific config_id."""
|
||||||
|
return [k for k in self._keys.values() if hasattr(k, 'config_id') and k.config_id == config_id]
|
||||||
|
|
||||||
|
def get_key(self, key_id: str) -> Optional[APIKey]:
|
||||||
|
return self._keys.get(key_id)
|
||||||
|
|
||||||
|
def get_candidate_keys(self, model_name: Optional[str] = None) -> List[APIKey]:
|
||||||
|
"""Returns all keys for a given model, regardless of status. If model_name is None, returns all keys."""
|
||||||
|
if not model_name:
|
||||||
|
return self.get_all_keys()
|
||||||
|
return self._keys_by_model.get(model_name, [])
|
||||||
|
|
||||||
|
def get_all_keys(self) -> List[APIKey]:
|
||||||
|
return list(self._keys.values())
|
||||||
|
|
||||||
|
def get_last_load_time(self) -> float:
|
||||||
|
return self._last_load_time
|
||||||
|
|
||||||
|
def get_config_enabled(self, config_id: str) -> bool:
|
||||||
|
"""Get enabled status for a config_id."""
|
||||||
|
return self._config_enabled_map.get(config_id, True)
|
||||||
|
|
||||||
|
def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all endpoint info for a config_id."""
|
||||||
|
endpoints = []
|
||||||
|
for (cfg_id, idx), url in self._endpoint_urls.items():
|
||||||
|
if cfg_id == config_id:
|
||||||
|
enabled = self._endpoint_enabled_map.get((cfg_id, idx), True)
|
||||||
|
endpoint_hash = hash(url) & 0xFFFF
|
||||||
|
key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}"
|
||||||
|
endpoints.append({
|
||||||
|
'idx': idx,
|
||||||
|
'url': url,
|
||||||
|
'enabled': enabled,
|
||||||
|
'key_id': key_id
|
||||||
|
})
|
||||||
|
return sorted(endpoints, key=lambda x: x['idx'])
|
||||||
|
|
||||||
|
def set_endpoint_enabled(self, config_id: str, endpoint_idx: int, enabled: bool) -> bool:
|
||||||
|
"""Enable or disable a specific endpoint and save to file."""
|
||||||
|
if (config_id, endpoint_idx) not in self._endpoint_enabled_map:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._endpoint_enabled_map[(config_id, endpoint_idx)] = enabled
|
||||||
|
|
||||||
|
# Update the corresponding key's status
|
||||||
|
url = self._endpoint_urls.get((config_id, endpoint_idx), '')
|
||||||
|
endpoint_hash = hash(url) & 0xFFFF
|
||||||
|
key_id = f"{config_id}:endpoint:{endpoint_idx}:{endpoint_hash}"
|
||||||
|
|
||||||
|
if key_id in self._keys:
|
||||||
|
key = self._keys[key_id]
|
||||||
|
key.enabled = enabled
|
||||||
|
if enabled:
|
||||||
|
key.status = KeyStatus.ACTIVE
|
||||||
|
else:
|
||||||
|
key.status = KeyStatus.DISABLED
|
||||||
|
|
||||||
|
self._save_config()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
|
||||||
|
"""Get the key_id for a specific endpoint."""
|
||||||
|
url = self._endpoint_urls.get((config_id, endpoint_idx))
|
||||||
|
if url:
|
||||||
|
endpoint_hash = hash(url) & 0xFFFF
|
||||||
|
return f"{config_id}:endpoint:{endpoint_idx}:{endpoint_hash}"
|
||||||
|
return None
|
||||||
311
app/core/providers.py
Executable file
311
app/core/providers.py
Executable file
@ -0,0 +1,311 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||||
|
from app.models import APIKey
|
||||||
|
|
||||||
|
class BaseProvider:
|
||||||
|
def __init__(self, api_key: APIKey):
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_url(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parses provider response into a standardized format.
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"content": str,
|
||||||
|
"reasoning_content": Optional[str],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": int,
|
||||||
|
"output_tokens": int,
|
||||||
|
"total_tokens": int
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseProvider):
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key.key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_url(self) -> str:
|
||||||
|
base = self.api_key.api_base or "https://api.openai.com/v1"
|
||||||
|
return f"{base.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
data = {"messages": messages, "model": self.api_key.model_name, "stream": False}
|
||||||
|
|
||||||
|
# 1. Start with Key's default params
|
||||||
|
final_params = self.api_key.params.copy()
|
||||||
|
|
||||||
|
# 2. Update with request-level params (filter None)
|
||||||
|
valid_request_params = {k: v for k, v in params.items() if v is not None}
|
||||||
|
final_params.update(valid_request_params)
|
||||||
|
|
||||||
|
# 3. Add to payload
|
||||||
|
data.update(final_params)
|
||||||
|
|
||||||
|
# 4. Handle extra_config from Key (e.g., Qwen extra_body)
|
||||||
|
if self.api_key.extra_config:
|
||||||
|
if "extra_body" in self.api_key.extra_config:
|
||||||
|
data.update(self.api_key.extra_config["extra_body"])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
choice = data["choices"][0]
|
||||||
|
text = choice["message"]["content"]
|
||||||
|
reasoning = choice["message"].get("reasoning_content") or \
|
||||||
|
choice.get("reasoning_content")
|
||||||
|
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
return {
|
||||||
|
"content": text,
|
||||||
|
"reasoning_content": reasoning,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": usage.get("prompt_tokens", 0),
|
||||||
|
"output_tokens": usage.get("completion_tokens", 0),
|
||||||
|
"total_tokens": usage.get("total_tokens", 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
return {"content": "", "usage": {}}
|
||||||
|
|
||||||
|
class AzureProvider(OpenAIProvider):
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"api-key": self.api_key.key
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_url(self) -> str:
|
||||||
|
# Expected api_base: https://{resource}.openai.azure.com/openai/deployments/{deployment}
|
||||||
|
# We need to append api-version
|
||||||
|
base = self.api_key.api_base or ""
|
||||||
|
sep = "&" if "?" in base else "?"
|
||||||
|
|
||||||
|
# Get api-version from extra_config or default
|
||||||
|
version = "2025-01-01-preview"
|
||||||
|
if self.api_key.extra_config and "api_version" in self.api_key.extra_config:
|
||||||
|
version = self.api_key.extra_config["api_version"]
|
||||||
|
|
||||||
|
return f"{base}{sep}api-version={version}"
|
||||||
|
|
||||||
|
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
payload = super().get_payload(messages, params)
|
||||||
|
# Azure usually doesn't need 'model' in body if it's in URL (deployment)
|
||||||
|
# But keeping it usually doesn't hurt, unless configured to remove.
|
||||||
|
# For simplicity, we keep it or remove it if we want strict adherence to typical azure usage.
|
||||||
|
# Let's remove it to be safe as deployment is in URL.
|
||||||
|
payload.pop("model", None)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
class GeminiProvider(BaseProvider):
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
return {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
def get_url(self) -> str:
|
||||||
|
# api_base example: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
|
||||||
|
base = self.api_key.api_base or f"https://generativelanguage.googleapis.com/v1beta/models/{self.api_key.model_name}:generateContent"
|
||||||
|
sep = "&" if "?" in base else "?"
|
||||||
|
return f"{base}{sep}key={self.api_key.key}"
|
||||||
|
|
||||||
|
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
gemini_contents = []
|
||||||
|
system_instruction = None
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = "model" if msg["role"] == "assistant" else msg["role"]
|
||||||
|
content = msg["content"]
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
text_val = content if isinstance(content, str) else content[0]["text"]
|
||||||
|
system_instruction = {"parts": [{"text": text_val}]}
|
||||||
|
continue
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
content_list = content if isinstance(content, list) else [{"type": "text", "text": content}]
|
||||||
|
|
||||||
|
for item in content_list:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
parts.append({"text": item["text"]})
|
||||||
|
elif item.get("type") == "image_url":
|
||||||
|
url = item.get("image_url", {}).get("url", "")
|
||||||
|
match = re.match(r'data:(.*?);base64,(.*)', url)
|
||||||
|
if match:
|
||||||
|
parts.append({
|
||||||
|
"inline_data": {
|
||||||
|
"mime_type": match.group(1),
|
||||||
|
"data": match.group(2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
gemini_contents.append({"role": role, "parts": parts})
|
||||||
|
|
||||||
|
generation_config = {}
|
||||||
|
|
||||||
|
# 1. Start with Key's default params (mapped to Gemini format if possible, or assume config.py had correct format)
|
||||||
|
# Note: config.py used "maxOutputTokens" (Gemini style) directly in 'params'.
|
||||||
|
# We should merge self.api_key.params into generation_config
|
||||||
|
generation_config.update(self.api_key.params)
|
||||||
|
|
||||||
|
# 2. Update with request params (mapped from OpenAI style)
|
||||||
|
if "max_tokens" in params:
|
||||||
|
generation_config["maxOutputTokens"] = params["max_tokens"]
|
||||||
|
if "temperature" in params:
|
||||||
|
generation_config["temperature"] = params["temperature"]
|
||||||
|
if "top_p" in params:
|
||||||
|
generation_config["topP"] = params["top_p"]
|
||||||
|
|
||||||
|
# 3. Handle thinking config from Key or Request
|
||||||
|
# Check Key's extra_config first
|
||||||
|
if self.api_key.extra_config and "thinking_config" in self.api_key.extra_config:
|
||||||
|
generation_config["thinkingConfig"] = self.api_key.extra_config["thinking_config"]
|
||||||
|
|
||||||
|
# Override with request param if present
|
||||||
|
if "thinking_config" in params:
|
||||||
|
generation_config["thinkingConfig"] = params["thinking_config"]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": gemini_contents,
|
||||||
|
"generationConfig": generation_config
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_instruction:
|
||||||
|
payload["system_instruction"] = system_instruction
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
candidate = data["candidates"][0]
|
||||||
|
parts_text = []
|
||||||
|
if "content" in candidate and "parts" in candidate["content"]:
|
||||||
|
for part in candidate["content"]["parts"]:
|
||||||
|
if "text" in part:
|
||||||
|
parts_text.append(part["text"])
|
||||||
|
|
||||||
|
text = "\n".join(parts_text) if parts_text else ""
|
||||||
|
|
||||||
|
reasoning = None
|
||||||
|
if "thinking_process" in candidate:
|
||||||
|
reasoning = candidate["thinking_process"]
|
||||||
|
|
||||||
|
usage_metadata = data.get("usageMetadata", {})
|
||||||
|
input_tokens = usage_metadata.get("promptTokenCount", 0)
|
||||||
|
output_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||||
|
total_tokens = input_tokens + output_tokens
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": text,
|
||||||
|
"reasoning_content": reasoning,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
return {"content": "", "usage": {}}
|
||||||
|
|
||||||
|
class ClaudeProvider(BaseProvider):
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"x-api-key": self.api_key.key
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check extra_config for version
|
||||||
|
if self.api_key.extra_config and "anthropic_version" in self.api_key.extra_config:
|
||||||
|
headers["anthropic-version"] = self.api_key.extra_config["anthropic_version"]
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def get_url(self) -> str:
|
||||||
|
base = self.api_key.api_base or "https://api.anthropic.com/v1/messages"
|
||||||
|
return base
|
||||||
|
|
||||||
|
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
claude_msgs = []
|
||||||
|
system_prompt = ""
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = "assistant" if msg["role"] == "assistant" else msg["role"]
|
||||||
|
content = msg["content"]
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
system_prompt = content if isinstance(content, str) else content[0]["text"]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Simplified content handling for text only for now, can expand if needed
|
||||||
|
claude_msgs.append({"role": role, "content": content})
|
||||||
|
|
||||||
|
# Start with Key's default params
|
||||||
|
payload = {
|
||||||
|
"model": self.api_key.model_name,
|
||||||
|
"messages": claude_msgs,
|
||||||
|
"max_tokens": 1024 # Default fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
# Merge Key params
|
||||||
|
payload.update(self.api_key.params)
|
||||||
|
|
||||||
|
# Merge request params
|
||||||
|
if "max_tokens" in params:
|
||||||
|
payload["max_tokens"] = params["max_tokens"]
|
||||||
|
if "temperature" in params:
|
||||||
|
payload["temperature"] = params["temperature"]
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
payload["system"] = system_prompt
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
content_list = data.get("content", [])
|
||||||
|
text = ""
|
||||||
|
for item in content_list:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
text += item.get("text", "")
|
||||||
|
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
input_tokens = usage.get("input_tokens", 0)
|
||||||
|
output_tokens = usage.get("output_tokens", 0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": text,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": input_tokens + output_tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return {"content": "", "usage": {}}
|
||||||
|
|
||||||
|
def get_provider_handler(api_key: APIKey) -> BaseProvider:
|
||||||
|
providers = {
|
||||||
|
"google": GeminiProvider,
|
||||||
|
"gemini": GeminiProvider,
|
||||||
|
"openai": OpenAIProvider,
|
||||||
|
"azure": AzureProvider,
|
||||||
|
"claude": ClaudeProvider,
|
||||||
|
"anthropic": ClaudeProvider,
|
||||||
|
"zhipu": OpenAIProvider # Zhipu is mostly OpenAI compatible
|
||||||
|
}
|
||||||
|
provider_class = providers.get(api_key.provider, OpenAIProvider)
|
||||||
|
return provider_class(api_key)
|
||||||
1016
app/core/stats_exporter.py
Executable file
1016
app/core/stats_exporter.py
Executable file
File diff suppressed because it is too large
Load Diff
207
app/core/stats_store.py
Executable file
207
app/core/stats_store.py
Executable file
@ -0,0 +1,207 @@
|
|||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class StatsStore:
|
||||||
|
DAILY_RETENTION_DAYS = 7
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path):
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executescript(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS latest_summary (
|
||||||
|
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||||
|
updated_at INTEGER NOT NULL,
|
||||||
|
total_keys INTEGER NOT NULL,
|
||||||
|
active_keys INTEGER NOT NULL,
|
||||||
|
cooldown_keys INTEGER NOT NULL,
|
||||||
|
disabled_keys INTEGER NOT NULL,
|
||||||
|
total_concurrency INTEGER NOT NULL,
|
||||||
|
total_rpm INTEGER NOT NULL,
|
||||||
|
total_tpm INTEGER NOT NULL,
|
||||||
|
total_rpd INTEGER NOT NULL,
|
||||||
|
total_yesterday_rpd INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS latest_key_stats (
|
||||||
|
key_id TEXT PRIMARY KEY,
|
||||||
|
updated_at INTEGER NOT NULL,
|
||||||
|
model_name TEXT NOT NULL,
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
enabled INTEGER NOT NULL,
|
||||||
|
config_id TEXT,
|
||||||
|
owner TEXT,
|
||||||
|
current_concurrency INTEGER NOT NULL,
|
||||||
|
current_rpm INTEGER NOT NULL,
|
||||||
|
current_tpm INTEGER NOT NULL,
|
||||||
|
current_rpd INTEGER NOT NULL,
|
||||||
|
cooldown_remaining REAL NOT NULL,
|
||||||
|
rpm_limit INTEGER NOT NULL,
|
||||||
|
tpm_limit INTEGER NOT NULL,
|
||||||
|
rpd_limit INTEGER NOT NULL,
|
||||||
|
max_concurrency INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS daily_usage (
|
||||||
|
day TEXT PRIMARY KEY,
|
||||||
|
label TEXT NOT NULL,
|
||||||
|
total_rpd INTEGER NOT NULL,
|
||||||
|
updated_at INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_export(self, stats: list[dict[str, Any]], usage_summary: dict[str, Any], daily_usage: list[dict[str, Any]]):
|
||||||
|
now = int(datetime.now(timezone.utc).timestamp())
|
||||||
|
cutoff = (datetime.now(timezone.utc) - timedelta(days=self.DAILY_RETENTION_DAYS)).strftime("%Y%m%d")
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO latest_summary (
|
||||||
|
id, updated_at, total_keys, active_keys, cooldown_keys, disabled_keys,
|
||||||
|
total_concurrency, total_rpm, total_tpm, total_rpd, total_yesterday_rpd
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
updated_at = excluded.updated_at,
|
||||||
|
total_keys = excluded.total_keys,
|
||||||
|
active_keys = excluded.active_keys,
|
||||||
|
cooldown_keys = excluded.cooldown_keys,
|
||||||
|
disabled_keys = excluded.disabled_keys,
|
||||||
|
total_concurrency = excluded.total_concurrency,
|
||||||
|
total_rpm = excluded.total_rpm,
|
||||||
|
total_tpm = excluded.total_tpm,
|
||||||
|
total_rpd = excluded.total_rpd,
|
||||||
|
total_yesterday_rpd = excluded.total_yesterday_rpd
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
now,
|
||||||
|
usage_summary["total_keys"],
|
||||||
|
usage_summary["active_keys"],
|
||||||
|
usage_summary["cooldown_keys"],
|
||||||
|
usage_summary["disabled_keys"],
|
||||||
|
usage_summary["total_concurrency"],
|
||||||
|
usage_summary["total_rpm"],
|
||||||
|
usage_summary["total_tpm"],
|
||||||
|
usage_summary["total_rpd"],
|
||||||
|
usage_summary.get("total_yesterday_rpd", 0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.execute("DELETE FROM latest_key_stats")
|
||||||
|
conn.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO latest_key_stats (
|
||||||
|
key_id, updated_at, model_name, provider, status, enabled, config_id, owner,
|
||||||
|
current_concurrency, current_rpm, current_tpm, current_rpd, cooldown_remaining,
|
||||||
|
rpm_limit, tpm_limit, rpd_limit, max_concurrency
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
item["id"],
|
||||||
|
now,
|
||||||
|
item["model_name"],
|
||||||
|
item["provider"],
|
||||||
|
item["status"],
|
||||||
|
1 if item["enabled"] else 0,
|
||||||
|
item.get("config_id"),
|
||||||
|
item.get("owner"),
|
||||||
|
item["usage"]["current_concurrency"],
|
||||||
|
item["usage"]["current_rpm"],
|
||||||
|
item["usage"]["current_tpm"],
|
||||||
|
item["usage"]["current_rpd"],
|
||||||
|
item["cooldown_remaining"],
|
||||||
|
item["limits"]["rpm"],
|
||||||
|
item["limits"]["tpm"],
|
||||||
|
item["limits"]["rpd"],
|
||||||
|
item["limits"]["max_concurrency"],
|
||||||
|
)
|
||||||
|
for item in stats
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO daily_usage (day, label, total_rpd, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
ON CONFLICT(day) DO UPDATE SET
|
||||||
|
label = excluded.label,
|
||||||
|
total_rpd = MAX(daily_usage.total_rpd, excluded.total_rpd),
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
item["day"],
|
||||||
|
item["label"],
|
||||||
|
item["total_rpd"],
|
||||||
|
now,
|
||||||
|
)
|
||||||
|
for item in daily_usage
|
||||||
|
],
|
||||||
|
)
|
||||||
|
conn.execute("DELETE FROM daily_usage WHERE day < ?", (cutoff,))
|
||||||
|
|
||||||
|
def get_daily_usage(self, days: int = 7) -> list[dict[str, Any]]:
|
||||||
|
if days <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT day, label, total_rpd
|
||||||
|
FROM daily_usage
|
||||||
|
ORDER BY day DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(days,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
rows = list(reversed(rows))
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"day": row["day"],
|
||||||
|
"label": row["label"],
|
||||||
|
"total_rpd": row["total_rpd"],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_latest_summary(self) -> dict[str, Any] | None:
|
||||||
|
with self._connect() as conn:
|
||||||
|
row = conn.execute("SELECT * FROM latest_summary WHERE id = 1").fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return dict(row)
|
||||||
|
|
||||||
|
def get_latest_key_stats(self) -> list[dict[str, Any]]:
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM latest_key_stats
|
||||||
|
ORDER BY provider, model_name, key_id
|
||||||
|
"""
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [dict(row) for row in rows]
|
||||||
|
|
||||||
|
def delete_day(self, day: str):
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute("DELETE FROM daily_usage WHERE day = ?", (day,))
|
||||||
36
app/models.py
Executable file
36
app/models.py
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
class KeyStatus(str, Enum):
|
||||||
|
ACTIVE = "active"
|
||||||
|
COOLDOWN = "cooldown"
|
||||||
|
DISABLED = "disabled"
|
||||||
|
|
||||||
|
class APIKey(BaseModel):
|
||||||
|
id: str = Field(..., description="Unique identifier for the key")
|
||||||
|
key: str = Field(..., description="The actual API Key string")
|
||||||
|
provider: str = Field(..., description="Provider identifier (e.g., google, zhipu)")
|
||||||
|
api_base: Optional[str] = Field(None, description="Base URL for the API")
|
||||||
|
model_name: str = Field(..., description="The physical model name")
|
||||||
|
params: Dict[str, Any] = Field(default_factory=dict, description="Default model parameters (e.g., temperature)")
|
||||||
|
extra_config: Dict[str, Any] = Field(default_factory=dict, description="Provider specific configuration")
|
||||||
|
enabled: bool = Field(default=True, description="Whether the key is enabled in configuration")
|
||||||
|
status: KeyStatus = Field(default=KeyStatus.ACTIVE, description="Current status of the key")
|
||||||
|
rpm_limit: int = Field(..., description="Requests Per Minute limit")
|
||||||
|
tpm_limit: int = Field(..., description="Tokens Per Minute limit")
|
||||||
|
rpd_limit: int = Field(default=0, description="Requests Per Day limit (0 = unlimited)")
|
||||||
|
max_concurrency: int = Field(..., description="Maximum concurrent requests allowed")
|
||||||
|
config_id: Optional[str] = Field(None, description="Config ID this key belongs to")
|
||||||
|
owner: Optional[str] = Field(None, description="Owner of this key")
|
||||||
|
|
||||||
|
# Runtime stats (Not loaded from config usually, but part of the object in memory)
|
||||||
|
current_concurrency: int = Field(default=0, description="Current concurrent requests")
|
||||||
|
current_rpm: int = Field(default=0, description="Current RPM usage")
|
||||||
|
current_tpm: int = Field(default=0, description="Current TPM usage")
|
||||||
|
current_rpd: int = Field(default=0, description="Current RPD usage")
|
||||||
|
rate_limit_penalty: int = Field(default=0, description="429 demotion order. 0 means never rate limited.")
|
||||||
|
endpoint_idx: Optional[int] = Field(default=None, description="Endpoint index for local llama load balancing")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
use_enum_values = True
|
||||||
10
app/runtime.py
Executable file
10
app/runtime.py
Executable file
@ -0,0 +1,10 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from app.core.dispatcher import Dispatcher
|
||||||
|
from app.core.key_manager import KeyManager
|
||||||
|
from app.core.stats_exporter import StatsExporter
|
||||||
|
|
||||||
|
|
||||||
|
key_manager = KeyManager(os.getenv("KEY_CONFIG_PATH", "config/keys.yaml"))
|
||||||
|
dispatcher = Dispatcher(key_manager)
|
||||||
|
stats_exporter = StatsExporter(dispatcher, output_dir="stats", interval=10)
|
||||||
18
app/utils/redis_client.py
Executable file
18
app/utils/redis_client.py
Executable file
@ -0,0 +1,18 @@
|
|||||||
|
import redis.asyncio as redis
|
||||||
|
import os
|
||||||
|
|
||||||
|
class RedisClient:
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
|
||||||
|
cls._instance = redis.from_url(redis_url, decode_responses=True)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def close(cls):
|
||||||
|
if cls._instance:
|
||||||
|
await cls._instance.close()
|
||||||
|
cls._instance = None
|
||||||
41
config/keys_example.yaml
Executable file
41
config/keys_example.yaml
Executable file
@ -0,0 +1,41 @@
|
|||||||
|
model_configs:
|
||||||
|
- id: nvidia-glm-cfg
|
||||||
|
provider: nvidia
|
||||||
|
model_name: z-ai/glm-5.1
|
||||||
|
api_base: https://integrate.api.nvidia.com/v1
|
||||||
|
rpm_limit: 40
|
||||||
|
tpm_limit: 100000
|
||||||
|
max_concurrency: 8
|
||||||
|
params:
|
||||||
|
temperature: 1
|
||||||
|
top_p: 1
|
||||||
|
max_tokens: 16384
|
||||||
|
stream: false
|
||||||
|
chat_template_kwargs:
|
||||||
|
enable_thinking: false
|
||||||
|
clear_thinking: false
|
||||||
|
- id: llama-local-cfg
|
||||||
|
provider: openai
|
||||||
|
model_name: llama-local
|
||||||
|
api_base: http://192.168.2.101:8848/v1
|
||||||
|
api_endpoints:
|
||||||
|
- url: http://192.168.2.101:8848/v1
|
||||||
|
enabled: true
|
||||||
|
- url: http://192.168.1.51:8848/v1
|
||||||
|
enabled: true
|
||||||
|
- url: http://192.168.1.61:8848/v1
|
||||||
|
enabled: true
|
||||||
|
rpm_limit: 999999
|
||||||
|
tpm_limit: 999999999
|
||||||
|
max_concurrency: 999
|
||||||
|
params:
|
||||||
|
temperature: 1.0
|
||||||
|
max_tokens: 6144
|
||||||
|
chat_template_kwargs:
|
||||||
|
enable_thinking: false
|
||||||
|
keys:
|
||||||
|
- config_id: nvidia-glm-cfg
|
||||||
|
enabled: true
|
||||||
|
keys:
|
||||||
|
- key: nvapi-key-1
|
||||||
|
owner: name
|
||||||
41
main.py
Executable file
41
main.py
Executable file
@ -0,0 +1,41 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
from app.api.routes import router
|
||||||
|
from app.runtime import stats_exporter
|
||||||
|
from app.utils.redis_client import RedisClient
|
||||||
|
|
||||||
|
app = FastAPI(title="Smart KeyPool Gateway")
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
async def dashboard_home():
|
||||||
|
return await stats_exporter.render_dashboard()
|
||||||
|
|
||||||
|
@app.get("/dashboard", response_class=HTMLResponse)
|
||||||
|
async def dashboard_page():
|
||||||
|
return await stats_exporter.render_dashboard()
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
redis = RedisClient.get_instance()
|
||||||
|
try:
|
||||||
|
await redis.ping()
|
||||||
|
print("Connected to Redis")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to connect to Redis: {e}")
|
||||||
|
|
||||||
|
await stats_exporter.start()
|
||||||
|
print("Stats exporter started - dashboard available at / or /dashboard")
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
await stats_exporter.stop()
|
||||||
|
await RedisClient.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8888")))
|
||||||
6
requirements.txt
Executable file
6
requirements.txt
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
redis
|
||||||
|
pydantic
|
||||||
|
pyyaml
|
||||||
|
httpx
|
||||||
284
test_keys.py
Executable file
284
test_keys.py
Executable file
@ -0,0 +1,284 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试所有启用的API密钥是否正常工作(并发版本)
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import httpx
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
# 添加app目录到路径
|
||||||
|
sys.path.insert(0, 'd:\\DPI\\keypool')
|
||||||
|
|
||||||
|
from app.core.key_manager import KeyManager
|
||||||
|
from app.core.providers import get_provider_handler
|
||||||
|
from app.models import APIKey
|
||||||
|
|
||||||
|
|
||||||
|
class TestStatus(Enum):
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILED = "failed"
|
||||||
|
SKIPPED = "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestResult:
|
||||||
|
key_id: str
|
||||||
|
provider: str
|
||||||
|
owner: Optional[str]
|
||||||
|
status: TestStatus
|
||||||
|
message: str
|
||||||
|
response_time: float = 0.0
|
||||||
|
details: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
|
||||||
|
"""测试单个API密钥"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
owner = getattr(api_key, 'owner', None)
|
||||||
|
|
||||||
|
# 如果密钥被禁用,跳过测试
|
||||||
|
if not api_key.enabled:
|
||||||
|
return TestResult(
|
||||||
|
key_id=api_key.id,
|
||||||
|
provider=api_key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.SKIPPED,
|
||||||
|
message="密钥已禁用 (enabled: false)",
|
||||||
|
response_time=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = get_provider_handler(api_key)
|
||||||
|
|
||||||
|
# 准备测试消息
|
||||||
|
test_messages = [
|
||||||
|
{"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 获取请求参数
|
||||||
|
headers = provider.get_headers()
|
||||||
|
url = provider.get_url()
|
||||||
|
payload = provider.get_payload(test_messages, {})
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
response_time = time.time() - start_time
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
parsed = provider.parse_response(data)
|
||||||
|
content = parsed.get("content", "")
|
||||||
|
|
||||||
|
return TestResult(
|
||||||
|
key_id=api_key.id,
|
||||||
|
provider=api_key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.SUCCESS,
|
||||||
|
message=f"测试成功,响应时间: {response_time:.2f}s",
|
||||||
|
response_time=response_time,
|
||||||
|
details={
|
||||||
|
"content_preview": content[:100] + "..." if len(content) > 100 else content,
|
||||||
|
"usage": parsed.get("usage", {})
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_text = response.text[:500] if response.text else "No error details"
|
||||||
|
return TestResult(
|
||||||
|
key_id=api_key.id,
|
||||||
|
provider=api_key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.FAILED,
|
||||||
|
message=f"HTTP错误 {response.status_code}: {error_text}",
|
||||||
|
response_time=time.time() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return TestResult(
|
||||||
|
key_id=api_key.id,
|
||||||
|
provider=api_key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.FAILED,
|
||||||
|
message=f"请求超时 (>{timeout}s)",
|
||||||
|
response_time=time.time() - start_time
|
||||||
|
)
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
return TestResult(
|
||||||
|
key_id=api_key.id,
|
||||||
|
provider=api_key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.FAILED,
|
||||||
|
message=f"连接错误: {str(e)}",
|
||||||
|
response_time=time.time() - start_time
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return TestResult(
|
||||||
|
key_id=api_key.id,
|
||||||
|
provider=api_key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.FAILED,
|
||||||
|
message=f"异常: {type(e).__name__}: {str(e)}",
|
||||||
|
response_time=time.time() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_all_keys_concurrent(config_path: str = "config/keys.yaml", max_concurrency: int = 10):
|
||||||
|
"""并发测试所有启用的密钥"""
|
||||||
|
print("=" * 80)
|
||||||
|
print("API密钥测试工具 (并发模式)")
|
||||||
|
print("=" * 80)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 加载密钥管理器
|
||||||
|
try:
|
||||||
|
km = KeyManager(config_path)
|
||||||
|
all_keys = km.get_all_keys()
|
||||||
|
print(f"共加载 {len(all_keys)} 个密钥配置")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 筛选启用的密钥
|
||||||
|
enabled_keys = [k for k in all_keys if k.enabled]
|
||||||
|
print(f"其中 {len(enabled_keys)} 个密钥已启用")
|
||||||
|
print(f"并发数: {max_concurrency}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if not enabled_keys:
|
||||||
|
print("没有启用的密钥需要测试")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 使用信号量控制并发
|
||||||
|
semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
async def test_with_semaphore(key: APIKey) -> TestResult:
|
||||||
|
async with semaphore:
|
||||||
|
return await test_single_key(key)
|
||||||
|
|
||||||
|
# 创建所有测试任务
|
||||||
|
tasks = [test_with_semaphore(key) for key in enabled_keys]
|
||||||
|
|
||||||
|
# 并发执行所有测试
|
||||||
|
print("开始并发测试...")
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 处理可能的异常结果
|
||||||
|
processed_results: List[TestResult] = []
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
key = enabled_keys[i]
|
||||||
|
owner = getattr(key, 'owner', None)
|
||||||
|
processed_results.append(TestResult(
|
||||||
|
key_id=key.id,
|
||||||
|
provider=key.provider,
|
||||||
|
owner=owner,
|
||||||
|
status=TestStatus.FAILED,
|
||||||
|
message=f"测试任务异常: {type(result).__name__}: {str(result)}",
|
||||||
|
response_time=0.0
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
processed_results.append(result)
|
||||||
|
|
||||||
|
results = processed_results
|
||||||
|
|
||||||
|
# 按提供商分组显示结果
|
||||||
|
results_by_provider: Dict[str, List[TestResult]] = {}
|
||||||
|
for r in results:
|
||||||
|
p = r.provider
|
||||||
|
if p not in results_by_provider:
|
||||||
|
results_by_provider[p] = []
|
||||||
|
results_by_provider[p].append(r)
|
||||||
|
|
||||||
|
# 打印详细结果
|
||||||
|
for provider, provider_results in sorted(results_by_provider.items()):
|
||||||
|
print(f"\n{'='*80}")
|
||||||
|
print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
|
||||||
|
print(f"{'='*80}")
|
||||||
|
|
||||||
|
for r in provider_results:
|
||||||
|
status_icon = "✓" if r.status == TestStatus.SUCCESS else "✗" if r.status == TestStatus.FAILED else "○"
|
||||||
|
owner_str = f" ({r.owner})" if r.owner else ""
|
||||||
|
print(f"\n [{status_icon}] {r.key_id}{owner_str}")
|
||||||
|
print(f" 状态: {r.message}")
|
||||||
|
if r.details and r.status == TestStatus.SUCCESS:
|
||||||
|
print(f" 回复: {r.details['content_preview']}")
|
||||||
|
|
||||||
|
# 打印汇总报告
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("测试汇总报告")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
success_count = sum(1 for r in results if r.status == TestStatus.SUCCESS)
|
||||||
|
failed_count = sum(1 for r in results if r.status == TestStatus.FAILED)
|
||||||
|
skipped_count = sum(1 for r in results if r.status == TestStatus.SKIPPED)
|
||||||
|
|
||||||
|
print(f"\n总计: {len(results)} 个密钥")
|
||||||
|
print(f" ✓ 成功: {success_count}")
|
||||||
|
print(f" ✗ 失败: {failed_count}")
|
||||||
|
print(f" ○ 跳过: {skipped_count}")
|
||||||
|
|
||||||
|
# 按提供商统计
|
||||||
|
print("\n按提供商统计:")
|
||||||
|
provider_stats: Dict[str, Dict[str, int]] = {}
|
||||||
|
for r in results:
|
||||||
|
p = r.provider
|
||||||
|
if p not in provider_stats:
|
||||||
|
provider_stats[p] = {"success": 0, "failed": 0, "skipped": 0}
|
||||||
|
if r.status == TestStatus.SUCCESS:
|
||||||
|
provider_stats[p]["success"] += 1
|
||||||
|
elif r.status == TestStatus.FAILED:
|
||||||
|
provider_stats[p]["failed"] += 1
|
||||||
|
else:
|
||||||
|
provider_stats[p]["skipped"] += 1
|
||||||
|
|
||||||
|
for provider, stats in sorted(provider_stats.items()):
|
||||||
|
total = stats["success"] + stats["failed"] + stats["skipped"]
|
||||||
|
success_rate = (stats["success"] / total * 100) if total > 0 else 0
|
||||||
|
print(f" {provider:15s}: {stats['success']:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {stats['failed']} 失败, {stats['skipped']} 跳过")
|
||||||
|
|
||||||
|
# 失败的详细信息
|
||||||
|
failed_results = [r for r in results if r.status == TestStatus.FAILED]
|
||||||
|
if failed_results:
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print("失败的密钥详情:")
|
||||||
|
print("-" * 80)
|
||||||
|
for r in failed_results:
|
||||||
|
print(f"\n [{r.provider}] {r.key_id}")
|
||||||
|
if r.owner:
|
||||||
|
print(f" 拥有者: {r.owner}")
|
||||||
|
print(f" 错误: {r.message}")
|
||||||
|
|
||||||
|
# 成功的响应时间统计
|
||||||
|
success_results = [r for r in results if r.status == TestStatus.SUCCESS and r.response_time > 0]
|
||||||
|
if success_results:
|
||||||
|
avg_time = sum(r.response_time for r in success_results) / len(success_results)
|
||||||
|
min_time = min(r.response_time for r in success_results)
|
||||||
|
max_time = max(r.response_time for r in success_results)
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print("响应时间统计 (成功请求):")
|
||||||
|
print("-" * 80)
|
||||||
|
print(f" 平均: {avg_time:.2f}s")
|
||||||
|
print(f" 最快: {min_time:.2f}s")
|
||||||
|
print(f" 最慢: {max_time:.2f}s")
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("测试完成")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 切换到脚本所在目录
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
os.chdir(script_dir)
|
||||||
|
|
||||||
|
# 运行测试(默认并发数10)
|
||||||
|
asyncio.run(test_all_keys_concurrent(max_concurrency=11))
|
||||||
Loading…
Reference in New Issue
Block a user