init
This commit is contained in:
commit
fb7f06ed03
3
.gitignore
vendored
Executable file
3
.gitignore
vendored
Executable file
@ -0,0 +1,3 @@
|
||||
__pycache__/
|
||||
test/*
|
||||
keys.yaml
|
||||
227
app/api/routes.py
Executable file
227
app/api/routes.py
Executable file
@ -0,0 +1,227 @@
|
||||
from fastapi import APIRouter, HTTPException, Request, BackgroundTasks
|
||||
from app.core.providers import get_provider_handler
|
||||
from app.runtime import key_manager, dispatcher, stats_exporter
|
||||
import httpx
|
||||
import time
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/v1/stats")
|
||||
async def get_stats():
|
||||
"""
|
||||
Returns usage statistics for all API keys.
|
||||
"""
|
||||
stats = await dispatcher.get_all_key_stats()
|
||||
return {"data": stats}
|
||||
|
||||
@router.get("/v1/usage")
|
||||
async def get_usage():
|
||||
"""
|
||||
Returns overall usage summary.
|
||||
"""
|
||||
summary = await dispatcher.get_usage_summary()
|
||||
return {"data": summary}
|
||||
|
||||
@router.post("/v1/stats/reset/rpd")
|
||||
async def reset_today_rpd():
|
||||
"""
|
||||
Reset today's RPD counters in Redis and refresh dashboard output.
|
||||
"""
|
||||
day = dispatcher._get_day_str()
|
||||
deleted_keys = await dispatcher.reset_today_rpd_usage()
|
||||
stats_exporter.clear_daily_usage_day(day)
|
||||
await stats_exporter.export_once()
|
||||
return {
|
||||
"success": True,
|
||||
"day": day,
|
||||
"deleted_keys": deleted_keys,
|
||||
"message": f"Reset today's RPD counters for {deleted_keys} keys",
|
||||
}
|
||||
|
||||
@router.post("/v1/keys/config/{config_id}/enable")
|
||||
async def enable_config(config_id: str):
|
||||
"""
|
||||
Enable all keys for a specific config_id (auto-saves to config file).
|
||||
"""
|
||||
success = key_manager.set_config_enabled(config_id, True)
|
||||
if success:
|
||||
return {"success": True, "message": f"Config {config_id} enabled and saved"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Config {config_id} not found")
|
||||
|
||||
@router.post("/v1/keys/config/{config_id}/disable")
|
||||
async def disable_config(config_id: str):
|
||||
"""
|
||||
Disable all keys for a specific config_id (auto-saves to config file).
|
||||
"""
|
||||
success = key_manager.set_config_enabled(config_id, False)
|
||||
if success:
|
||||
return {"success": True, "message": f"Config {config_id} disabled and saved"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Config {config_id} not found")
|
||||
|
||||
@router.post("/v1/keys/{key_id}/enable")
|
||||
async def enable_key(key_id: str):
|
||||
"""
|
||||
Enable a specific key.
|
||||
"""
|
||||
success = key_manager.set_key_enabled(key_id, True)
|
||||
if success:
|
||||
return {"success": True, "message": f"Key {key_id} enabled"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Key {key_id} not found")
|
||||
|
||||
@router.post("/v1/keys/{key_id}/disable")
|
||||
async def disable_key(key_id: str):
|
||||
"""
|
||||
Disable a specific key.
|
||||
"""
|
||||
success = key_manager.set_key_enabled(key_id, False)
|
||||
if success:
|
||||
return {"success": True, "message": f"Key {key_id} disabled"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Key {key_id} not found")
|
||||
|
||||
|
||||
@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/enable")
|
||||
async def enable_endpoint(config_id: str, endpoint_idx: int):
|
||||
"""
|
||||
Enable a specific endpoint for a config (auto-saves to config file).
|
||||
"""
|
||||
success = key_manager.set_endpoint_enabled(config_id, endpoint_idx, True)
|
||||
if success:
|
||||
return {"success": True, "message": f"Endpoint {endpoint_idx} of {config_id} enabled and saved"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_idx} of {config_id} not found")
|
||||
|
||||
|
||||
@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/disable")
|
||||
async def disable_endpoint(config_id: str, endpoint_idx: int):
|
||||
"""
|
||||
Disable a specific endpoint for a config (auto-saves to config file).
|
||||
"""
|
||||
success = key_manager.set_endpoint_enabled(config_id, endpoint_idx, False)
|
||||
if success:
|
||||
return {"success": True, "message": f"Endpoint {endpoint_idx} of {config_id} disabled and saved"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_idx} of {config_id} not found")
|
||||
|
||||
@router.get("/v1/configs")
|
||||
async def get_configs():
|
||||
"""
|
||||
Returns all config IDs.
|
||||
"""
|
||||
configs = []
|
||||
for config_id in key_manager.get_config_ids():
|
||||
keys = key_manager.get_keys_by_config(config_id)
|
||||
if keys:
|
||||
configs.append({
|
||||
"config_id": config_id,
|
||||
"model_name": keys[0].model_name,
|
||||
"provider": keys[0].provider,
|
||||
"key_count": len(keys),
|
||||
"enabled": keys[0].enabled
|
||||
})
|
||||
return {"data": configs}
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Unified entry point.
|
||||
Expects JSON body with "model" field.
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except:
|
||||
logger.error(f"[Invalid JSON] detail={traceback.format_exc()}...")
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||
|
||||
model_name = body.get("model")
|
||||
|
||||
key, error_reason = await dispatcher.select_key(model_name)
|
||||
if not key:
|
||||
logger.error(f"[Key Not Found] model={model_name} | error={error_reason}...")
|
||||
raise HTTPException(status_code=503, detail=error_reason)
|
||||
|
||||
if not model_name:
|
||||
model_name = key.model_name
|
||||
|
||||
acquired = await dispatcher.acquire_lease(key)
|
||||
if not acquired:
|
||||
logger.error(f"[System busy (Concurrency Limit)] model={model_name} | provider={key.provider} | key={key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}...")
|
||||
raise HTTPException(status_code=503, detail="System busy (Concurrency Limit)")
|
||||
|
||||
try:
|
||||
handler = get_provider_handler(key)
|
||||
url = handler.get_url()
|
||||
headers = handler.get_headers()
|
||||
|
||||
params = {k: v for k, v in body.items() if k not in ["model", "messages", "stream"]}
|
||||
|
||||
payload = handler.get_payload(body.get("messages", []), params)
|
||||
|
||||
timeout = httpx.Timeout(300.0, connect=10.0)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
resp = await client.post(url, headers=headers, json=payload)
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_detail = resp.text[:500] if len(resp.text) > 500 else resp.text
|
||||
logger.error(
|
||||
f"[Request Failed] model={model_name} | key={key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'} | "
|
||||
f"api_base={key.api_base} | "
|
||||
f"status={resp.status_code} | response={error_detail}"
|
||||
)
|
||||
|
||||
if resp.status_code == 429:
|
||||
error_message = ""
|
||||
try:
|
||||
error_data = resp.json()
|
||||
if isinstance(error_data, dict):
|
||||
error_message = error_data.get("error", {}).get("message", "") or error_data.get("message", "")
|
||||
except:
|
||||
error_message = resp.text
|
||||
await dispatcher.report_failure(key, is_rate_limit=True, error_message=error_message)
|
||||
|
||||
raise HTTPException(status_code=resp.status_code, detail=f"Provider Error: {error_detail}")
|
||||
|
||||
data = resp.json()
|
||||
parsed_result = handler.parse_response(data)
|
||||
|
||||
total_tokens = parsed_result["usage"].get("total_tokens", 0)
|
||||
background_tasks.add_task(dispatcher.record_usage, key, total_tokens)
|
||||
|
||||
response_data = {
|
||||
"id": f"chatcmpl-{int(time.time())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": parsed_result["content"],
|
||||
"reasoning_content": parsed_result.get("reasoning_content")
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": parsed_result["usage"]
|
||||
}
|
||||
|
||||
return response_data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
_key_info = f"{key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}" if key else 'N/A'
|
||||
logger.exception(
|
||||
f"[Internal Error] model={model_name} | "
|
||||
f"key={_key_info} | api_base={getattr(key, 'api_base', 'N/A') if key else 'N/A'} | error={str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
await dispatcher.release_lease(key)
|
||||
556
app/core/dispatcher.py
Executable file
556
app/core/dispatcher.py
Executable file
@ -0,0 +1,556 @@
|
||||
import time
|
||||
import asyncio
|
||||
import fnmatch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from app.models import APIKey, KeyStatus
|
||||
from app.core.key_manager import KeyManager
|
||||
from app.utils.redis_client import RedisClient
|
||||
|
||||
class Dispatcher:
|
||||
DAILY_RPD_TTL_SECONDS = 90 * 86400
|
||||
RATE_LIMIT_PENALTY_SEQ_KEY = "keypool:priority:429:seq"
|
||||
PROVIDER_PRIORITIES = {
|
||||
"modelscope": 0,
|
||||
"nvidia": 0,
|
||||
"siliconflow": 0,
|
||||
"llama-local": 0,
|
||||
}
|
||||
|
||||
def __init__(self, key_manager: KeyManager):
|
||||
self.key_manager = key_manager
|
||||
self.redis = RedisClient.get_instance()
|
||||
|
||||
@staticmethod
|
||||
def _get_day_str(offset_days: int = 0) -> str:
|
||||
current_day = datetime.now(timezone.utc) + timedelta(days=offset_days)
|
||||
return current_day.strftime("%Y%m%d")
|
||||
|
||||
@staticmethod
|
||||
def _get_rate_limit_penalty_key(key_id: str) -> str:
|
||||
return f"keypool:priority:429:{key_id}"
|
||||
|
||||
@classmethod
|
||||
def _get_provider_priority(cls, provider: str) -> int:
|
||||
return cls.PROVIDER_PRIORITIES.get(provider.lower(), len(cls.PROVIDER_PRIORITIES))
|
||||
|
||||
async def _delete_matching_keys(self, pattern: str, batch_size: int = 200) -> int:
|
||||
if hasattr(self.redis, "scan_iter"):
|
||||
batch = []
|
||||
deleted = 0
|
||||
async for key in self.redis.scan_iter(match=pattern, count=batch_size):
|
||||
batch.append(key)
|
||||
if len(batch) >= batch_size:
|
||||
deleted += await self.redis.delete(*batch)
|
||||
batch.clear()
|
||||
if batch:
|
||||
deleted += await self.redis.delete(*batch)
|
||||
return deleted
|
||||
|
||||
if not hasattr(self.redis, "values"):
|
||||
raise RuntimeError("Redis client does not support scan_iter")
|
||||
|
||||
keys = [key for key in self.redis.values if fnmatch.fnmatch(key, pattern)]
|
||||
return await self.redis.delete(*keys) if keys else 0
|
||||
|
||||
async def reset_today_rpd_usage(self) -> int:
|
||||
day = self._get_day_str()
|
||||
return await self._delete_matching_keys(f"keypool:usage:rpd:*:{day}")
|
||||
|
||||
async def get_key_stats(self, key: APIKey):
|
||||
"""Fetches runtime stats from Redis for a single key."""
|
||||
concurrency_key = f"keypool:concurrency:{key.id}"
|
||||
penalty_key = self._get_rate_limit_penalty_key(key.id)
|
||||
|
||||
current_day = self._get_day_str()
|
||||
rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
|
||||
|
||||
# Pipeline for efficiency
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.get(concurrency_key)
|
||||
pipe.get(penalty_key)
|
||||
pipe.get(rpd_key)
|
||||
results = await pipe.execute()
|
||||
|
||||
current_concurrency = int(results[0]) if results[0] else 0
|
||||
rate_limit_penalty = int(results[1]) if results[1] else 0
|
||||
current_rpd = int(results[2]) if results[2] else 0
|
||||
|
||||
key.current_concurrency = current_concurrency
|
||||
key.rate_limit_penalty = rate_limit_penalty
|
||||
key.current_rpd = current_rpd
|
||||
return key
|
||||
|
||||
async def select_key(self, model_name: Optional[str] = None, timeout: float = 30.0) -> tuple[Optional[APIKey], str]:
|
||||
"""
|
||||
Selects the best available key for the given model.
|
||||
If model_name is None, selects from ALL available keys.
|
||||
|
||||
If no keys are available immediately, it will wait (queue) up to `timeout` seconds.
|
||||
|
||||
Strategy:
|
||||
1. Try to find an available key immediately.
|
||||
2. If none found, wait for a short interval and retry (polling/queue simulation).
|
||||
In a production system, we might use a real Redis list/pubsub for queuing,
|
||||
but for now, we'll use a local async loop with exponential backoff or simple polling.
|
||||
|
||||
Returns: (key, error_reason)
|
||||
|
||||
Note: If model_name is explicitly specified (force_select=True), the status check
|
||||
for DISABLED keys is bypassed, allowing access to disabled models.
|
||||
"""
|
||||
start_time = time.time()
|
||||
# If model_name is explicitly specified, force select (bypass enabled check)
|
||||
force_select = bool(model_name)
|
||||
|
||||
while True:
|
||||
# 1. Try to select a key
|
||||
key, error = await self._try_select_key_once(model_name, force_select=force_select)
|
||||
if key:
|
||||
return key, ""
|
||||
|
||||
# 2. Check timeout
|
||||
if time.time() - start_time > timeout:
|
||||
return None, f"Timeout waiting for available key: {error}"
|
||||
|
||||
# 3. Wait before retrying (simple polling).
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def _try_select_key_once(self, model_name: Optional[str] = None, force_select: bool = False) -> tuple[Optional[APIKey], str]:
|
||||
"""
|
||||
Internal method to try selecting a key once without waiting.
|
||||
Strategy:
|
||||
1. Filter by status and hard limits.
|
||||
2. Prefer keys that never hit 429.
|
||||
3. Prefer providers by static rank: modelscope > nvidia > siliconflow.
|
||||
4. For keys that did hit 429, use the demotion order as a queue.
|
||||
5. Use load as the final tie breaker inside the same priority tier.
|
||||
|
||||
Args:
|
||||
model_name: Optional model name to filter by
|
||||
force_select: If True, bypass the status check (allows accessing disabled models)
|
||||
|
||||
Returns: (key, error_reason)
|
||||
"""
|
||||
candidates = self.key_manager.get_candidate_keys(model_name)
|
||||
if not candidates:
|
||||
return None, "No keys configured for this model" if model_name else "No keys configured in the pool"
|
||||
|
||||
# Prepare keys for mget
|
||||
concurrency_keys = [f"keypool:concurrency:{k.id}" for k in candidates]
|
||||
penalty_keys = [self._get_rate_limit_penalty_key(k.id) for k in candidates]
|
||||
|
||||
now = time.time()
|
||||
current_minute = int(now / 60)
|
||||
prev_minute = current_minute - 1
|
||||
|
||||
rpm_curr_keys = [f"keypool:usage:rpm:{k.id}:{current_minute}" for k in candidates]
|
||||
rpm_prev_keys = [f"keypool:usage:rpm:{k.id}:{prev_minute}" for k in candidates]
|
||||
|
||||
tpm_curr_keys = [f"keypool:usage:tpm:{k.id}:{current_minute}" for k in candidates]
|
||||
tpm_prev_keys = [f"keypool:usage:tpm:{k.id}:{prev_minute}" for k in candidates]
|
||||
|
||||
current_day = self._get_day_str()
|
||||
rpd_keys = [f"keypool:usage:rpd:{k.id}:{current_day}" for k in candidates]
|
||||
|
||||
all_keys = concurrency_keys + penalty_keys + rpm_curr_keys + rpm_prev_keys + tpm_curr_keys + tpm_prev_keys + rpd_keys
|
||||
if not all_keys:
|
||||
return None, "No keys available (empty stats)"
|
||||
|
||||
values = await self.redis.mget(all_keys)
|
||||
|
||||
n = len(candidates)
|
||||
concurrency_values = values[:n]
|
||||
penalty_values = values[n:2*n]
|
||||
rpm_curr_values = values[2*n:3*n]
|
||||
rpm_prev_values = values[3*n:4*n]
|
||||
tpm_curr_values = values[4*n:5*n]
|
||||
tpm_prev_values = values[5*n:6*n]
|
||||
rpd_values = values[6*n:7*n]
|
||||
|
||||
# Calculate window weight
|
||||
# If we are at second 0 of the minute, weight of prev is 1.0 (actually 0.99...)
|
||||
# If we are at second 59, weight of prev is ~0
|
||||
# Formula: weight = (60 - seconds_elapsed) / 60
|
||||
seconds_elapsed = now % 60
|
||||
prev_weight = (60 - seconds_elapsed) / 60
|
||||
|
||||
available_keys = []
|
||||
rejection_reasons = []
|
||||
|
||||
for i, key in enumerate(candidates):
|
||||
current_conc = int(concurrency_values[i]) if concurrency_values[i] else 0
|
||||
rate_limit_penalty = int(penalty_values[i]) if penalty_values[i] else 0
|
||||
|
||||
r_curr = int(rpm_curr_values[i]) if rpm_curr_values[i] else 0
|
||||
r_prev = int(rpm_prev_values[i]) if rpm_prev_values[i] else 0
|
||||
|
||||
t_curr = int(tpm_curr_values[i]) if tpm_curr_values[i] else 0
|
||||
t_prev = int(tpm_prev_values[i]) if tpm_prev_values[i] else 0
|
||||
|
||||
current_rpd = int(rpd_values[i]) if rpd_values[i] else 0
|
||||
|
||||
current_rpm = int(r_curr + r_prev * prev_weight)
|
||||
current_tpm = int(t_curr + t_prev * prev_weight)
|
||||
|
||||
key.current_concurrency = current_conc
|
||||
key.rate_limit_penalty = rate_limit_penalty
|
||||
key.current_rpm = current_rpm
|
||||
key.current_tpm = current_tpm
|
||||
key.current_rpd = current_rpd
|
||||
|
||||
if key.status != KeyStatus.ACTIVE:
|
||||
# If force_select is True (explicit model specified), bypass status check
|
||||
if not force_select:
|
||||
rejection_reasons.append(f"{key.id}({key.owner}): Inactive (status={key.status})")
|
||||
continue
|
||||
|
||||
if key.current_concurrency >= key.max_concurrency:
|
||||
rejection_reasons.append(f"{key.id}({key.owner}): Concurrency limit ({key.current_concurrency}/{key.max_concurrency})")
|
||||
continue
|
||||
if current_rpm >= key.rpm_limit:
|
||||
rejection_reasons.append(f"{key.id}({key.owner}): RPM limit ({current_rpm}/{key.rpm_limit})")
|
||||
continue
|
||||
if current_tpm >= key.tpm_limit:
|
||||
rejection_reasons.append(f"{key.id}({key.owner}): TPM limit ({current_tpm}/{key.tpm_limit})")
|
||||
continue
|
||||
if key.rpd_limit > 0 and current_rpd >= key.rpd_limit:
|
||||
rejection_reasons.append(f"{key.id}({key.owner}): RPD limit ({current_rpd}/{key.rpd_limit})")
|
||||
continue
|
||||
|
||||
available_keys.append(key)
|
||||
|
||||
if not available_keys:
|
||||
reason = "; ".join(rejection_reasons) if rejection_reasons else "No available keys found"
|
||||
return None, f"All keys unavailable: {reason}"
|
||||
|
||||
def calculate_load_score(k: APIKey) -> float:
|
||||
conc_ratio = k.current_concurrency / k.max_concurrency if k.max_concurrency > 0 else 0
|
||||
rpm_ratio = k.current_rpm / k.rpm_limit if k.rpm_limit > 0 else 0
|
||||
tpm_ratio = k.current_tpm / k.tpm_limit if k.tpm_limit > 0 else 0
|
||||
rpd_ratio = k.current_rpd / k.rpd_limit if k.rpd_limit > 0 else 0
|
||||
return max(conc_ratio, rpm_ratio, tpm_ratio, rpd_ratio)
|
||||
|
||||
def selection_key(k: APIKey) -> tuple[int, int, int, float]:
|
||||
provider_priority = self._get_provider_priority(k.provider)
|
||||
if k.rate_limit_penalty > 0:
|
||||
return (1, k.rate_limit_penalty, provider_priority, calculate_load_score(k))
|
||||
return (0, provider_priority, 0, calculate_load_score(k))
|
||||
|
||||
best_key = min(available_keys, key=selection_key)
|
||||
|
||||
return best_key, ""
|
||||
|
||||
async def acquire_lease(self, key: APIKey) -> bool:
|
||||
"""
|
||||
Increments concurrency, RPM and RPD for the key.
|
||||
Double check limit in Redis to be safe (race conditions).
|
||||
"""
|
||||
key_concurrency_key = f"keypool:concurrency:{key.id}"
|
||||
|
||||
now = time.time()
|
||||
current_minute = int(now / 60)
|
||||
prev_minute = current_minute - 1
|
||||
current_day = self._get_day_str()
|
||||
|
||||
key_rpm_key = f"keypool:usage:rpm:{key.id}:{current_minute}"
|
||||
key_rpm_prev_key = f"keypool:usage:rpm:{key.id}:{prev_minute}"
|
||||
key_rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.incr(key_concurrency_key)
|
||||
pipe.incr(key_rpm_key)
|
||||
pipe.expire(key_rpm_key, 65)
|
||||
pipe.get(key_rpm_prev_key)
|
||||
pipe.incr(key_rpd_key)
|
||||
# Keep daily counters for two days so the dashboard can read yesterday's RPD.
|
||||
pipe.expire(key_rpd_key, self.DAILY_RPD_TTL_SECONDS)
|
||||
|
||||
results = await pipe.execute()
|
||||
curr_conc = results[0]
|
||||
curr_rpm_val = results[1]
|
||||
prev_rpm_val = int(results[3]) if results[3] else 0
|
||||
curr_rpd_val = results[4]
|
||||
|
||||
seconds_elapsed = now % 60
|
||||
prev_weight = (60 - seconds_elapsed) / 60
|
||||
estimated_rpm = curr_rpm_val + (prev_rpm_val * prev_weight)
|
||||
|
||||
if curr_conc > key.max_concurrency:
|
||||
await self.redis.decr(key_concurrency_key)
|
||||
await self.redis.decr(key_rpm_key)
|
||||
await self.redis.decr(key_rpd_key)
|
||||
return False
|
||||
|
||||
if estimated_rpm > key.rpm_limit:
|
||||
await self.redis.decr(key_concurrency_key)
|
||||
await self.redis.decr(key_rpm_key)
|
||||
await self.redis.decr(key_rpd_key)
|
||||
return False
|
||||
|
||||
if key.rpd_limit > 0 and curr_rpd_val > key.rpd_limit:
|
||||
await self.redis.decr(key_concurrency_key)
|
||||
await self.redis.decr(key_rpm_key)
|
||||
await self.redis.decr(key_rpd_key)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def release_lease(self, key: APIKey):
|
||||
"""Decrements concurrency."""
|
||||
key_concurrency_key = f"keypool:concurrency:{key.id}"
|
||||
await self.redis.decr(key_concurrency_key)
|
||||
|
||||
async def record_usage(self, key: APIKey, tokens: int):
|
||||
"""Records Token usage (TPM)."""
|
||||
if tokens <= 0:
|
||||
return
|
||||
|
||||
current_minute = int(time.time() / 60)
|
||||
key_tpm_key = f"keypool:usage:tpm:{key.id}:{current_minute}"
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.incrby(key_tpm_key, tokens)
|
||||
pipe.expire(key_tpm_key, 65)
|
||||
await pipe.execute()
|
||||
|
||||
|
||||
async def report_failure(self, key: APIKey, is_rate_limit: bool = False, error_message: str = ""):
|
||||
"""
|
||||
Reports a failure. A 429 moves the key to the back of the priority queue.
|
||||
"""
|
||||
del error_message
|
||||
if is_rate_limit:
|
||||
seq = await self.redis.incr(self.RATE_LIMIT_PENALTY_SEQ_KEY)
|
||||
await self.redis.set(self._get_rate_limit_penalty_key(key.id), seq)
|
||||
|
||||
async def get_all_key_stats(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieves statistics for all keys, including status, concurrency, and usage limits.
|
||||
"""
|
||||
keys = self.key_manager.get_all_keys()
|
||||
if not keys:
|
||||
return []
|
||||
|
||||
# Prepare keys for batch fetching from Redis
|
||||
now = time.time()
|
||||
current_minute = int(now / 60)
|
||||
prev_minute = current_minute - 1
|
||||
current_day = self._get_day_str()
|
||||
|
||||
redis_keys = []
|
||||
for key in keys:
|
||||
redis_keys.extend([
|
||||
f"keypool:concurrency:{key.id}",
|
||||
self._get_rate_limit_penalty_key(key.id),
|
||||
f"keypool:usage:rpm:{key.id}:{current_minute}",
|
||||
f"keypool:usage:rpm:{key.id}:{prev_minute}",
|
||||
f"keypool:usage:tpm:{key.id}:{current_minute}",
|
||||
f"keypool:usage:tpm:{key.id}:{prev_minute}",
|
||||
f"keypool:usage:rpd:{key.id}:{current_day}"
|
||||
])
|
||||
|
||||
if not redis_keys:
|
||||
return []
|
||||
|
||||
# Fetch all values in a single round-trip
|
||||
values = await self.redis.mget(redis_keys)
|
||||
|
||||
stats = []
|
||||
|
||||
# Process results in chunks of 7
|
||||
seconds_elapsed = now % 60
|
||||
prev_weight = (60 - seconds_elapsed) / 60
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
base_idx = i * 7
|
||||
|
||||
curr_conc = int(values[base_idx]) if values[base_idx] else 0
|
||||
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
|
||||
|
||||
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
|
||||
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
|
||||
|
||||
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
|
||||
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
|
||||
|
||||
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
|
||||
|
||||
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
|
||||
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
|
||||
|
||||
# Determine effective status
|
||||
status = key.status
|
||||
if key.status == KeyStatus.ACTIVE and rate_limit_penalty > 0:
|
||||
status = KeyStatus.COOLDOWN
|
||||
|
||||
key_stat = {
|
||||
"id": key.id,
|
||||
"model_name": key.model_name,
|
||||
"provider": key.provider,
|
||||
"status": status,
|
||||
"enabled": key.enabled,
|
||||
"config_id": getattr(key, 'config_id', None),
|
||||
"owner": getattr(key, 'owner', None),
|
||||
"endpoint_idx": getattr(key, 'endpoint_idx', None),
|
||||
"api_base": getattr(key, 'api_base', None),
|
||||
"limits": {
|
||||
"rpm": key.rpm_limit,
|
||||
"tpm": key.tpm_limit,
|
||||
"rpd": key.rpd_limit,
|
||||
"max_concurrency": key.max_concurrency
|
||||
},
|
||||
"usage": {
|
||||
"current_concurrency": curr_conc,
|
||||
"current_rpm": estimated_rpm,
|
||||
"current_tpm": estimated_tpm,
|
||||
"current_rpd": rpd_curr
|
||||
},
|
||||
"cooldown_remaining": 0
|
||||
}
|
||||
stats.append(key_stat)
|
||||
|
||||
return stats
|
||||
|
||||
async def get_usage_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get overall usage summary across all keys.
|
||||
Returns total RPM, TPM, RPD, concurrency, and active keys count.
|
||||
"""
|
||||
keys = self.key_manager.get_all_keys()
|
||||
if not keys:
|
||||
return {
|
||||
"total_rpm": 0,
|
||||
"total_tpm": 0,
|
||||
"total_rpd": 0,
|
||||
"total_yesterday_rpd": 0,
|
||||
"total_concurrency": 0,
|
||||
"active_keys": 0,
|
||||
"cooldown_keys": 0,
|
||||
"disabled_keys": 0,
|
||||
"total_keys": 0
|
||||
}
|
||||
|
||||
now = time.time()
|
||||
current_minute = int(now / 60)
|
||||
prev_minute = current_minute - 1
|
||||
current_day = self._get_day_str()
|
||||
yesterday_day = self._get_day_str(-1)
|
||||
|
||||
redis_keys = []
|
||||
for key in keys:
|
||||
redis_keys.extend([
|
||||
f"keypool:concurrency:{key.id}",
|
||||
self._get_rate_limit_penalty_key(key.id),
|
||||
f"keypool:usage:rpm:{key.id}:{current_minute}",
|
||||
f"keypool:usage:rpm:{key.id}:{prev_minute}",
|
||||
f"keypool:usage:tpm:{key.id}:{current_minute}",
|
||||
f"keypool:usage:tpm:{key.id}:{prev_minute}",
|
||||
f"keypool:usage:rpd:{key.id}:{current_day}",
|
||||
f"keypool:usage:rpd:{key.id}:{yesterday_day}"
|
||||
])
|
||||
|
||||
if not redis_keys:
|
||||
return {
|
||||
"total_rpm": 0,
|
||||
"total_tpm": 0,
|
||||
"total_rpd": 0,
|
||||
"total_yesterday_rpd": 0,
|
||||
"total_concurrency": 0,
|
||||
"active_keys": 0,
|
||||
"cooldown_keys": 0,
|
||||
"disabled_keys": 0,
|
||||
"total_keys": 0
|
||||
}
|
||||
|
||||
values = await self.redis.mget(redis_keys)
|
||||
|
||||
seconds_elapsed = now % 60
|
||||
prev_weight = (60 - seconds_elapsed) / 60
|
||||
|
||||
total_rpm = 0
|
||||
total_tpm = 0
|
||||
total_rpd = 0
|
||||
total_yesterday_rpd = 0
|
||||
total_concurrency = 0
|
||||
active_keys = 0
|
||||
cooldown_keys = 0
|
||||
disabled_keys = 0
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
base_idx = i * 8
|
||||
|
||||
curr_conc = int(values[base_idx]) if values[base_idx] else 0
|
||||
rate_limit_penalty = int(values[base_idx + 1]) if values[base_idx + 1] else 0
|
||||
|
||||
rpm_curr = int(values[base_idx + 2]) if values[base_idx + 2] else 0
|
||||
rpm_prev = int(values[base_idx + 3]) if values[base_idx + 3] else 0
|
||||
tpm_curr = int(values[base_idx + 4]) if values[base_idx + 4] else 0
|
||||
tpm_prev = int(values[base_idx + 5]) if values[base_idx + 5] else 0
|
||||
rpd_curr = int(values[base_idx + 6]) if values[base_idx + 6] else 0
|
||||
rpd_yesterday = int(values[base_idx + 7]) if values[base_idx + 7] else 0
|
||||
|
||||
estimated_rpm = int(rpm_curr + rpm_prev * prev_weight)
|
||||
estimated_tpm = int(tpm_curr + tpm_prev * prev_weight)
|
||||
|
||||
total_rpm += estimated_rpm
|
||||
total_tpm += estimated_tpm
|
||||
total_rpd += rpd_curr
|
||||
total_yesterday_rpd += rpd_yesterday
|
||||
total_concurrency += curr_conc
|
||||
|
||||
if key.status == KeyStatus.ACTIVE and rate_limit_penalty == 0:
|
||||
active_keys += 1
|
||||
elif key.status == KeyStatus.DISABLED:
|
||||
disabled_keys += 1
|
||||
elif rate_limit_penalty > 0:
|
||||
cooldown_keys += 1
|
||||
else:
|
||||
active_keys += 1
|
||||
|
||||
return {
|
||||
"total_rpm": total_rpm,
|
||||
"total_tpm": total_tpm,
|
||||
"total_rpd": total_rpd,
|
||||
"total_yesterday_rpd": total_yesterday_rpd,
|
||||
"total_concurrency": total_concurrency,
|
||||
"active_keys": active_keys,
|
||||
"cooldown_keys": cooldown_keys,
|
||||
"disabled_keys": disabled_keys,
|
||||
"total_keys": len(keys)
|
||||
}
|
||||
|
||||
async def get_daily_usage_history(self, days: int = 7) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Aggregate total RPD usage for the most recent N days.
|
||||
Uses the same UTC day buckets as runtime RPD limiting.
|
||||
"""
|
||||
keys = self.key_manager.get_all_keys()
|
||||
if not keys or days <= 0:
|
||||
return []
|
||||
|
||||
day_labels = [self._get_day_str(-offset) for offset in range(days - 1, -1, -1)]
|
||||
|
||||
redis_keys = []
|
||||
for day_str in day_labels:
|
||||
for key in keys:
|
||||
redis_keys.append(f"keypool:usage:rpd:{key.id}:{day_str}")
|
||||
|
||||
values = await self.redis.mget(redis_keys)
|
||||
|
||||
history = []
|
||||
key_count = len(keys)
|
||||
for day_index, day_str in enumerate(day_labels):
|
||||
base_idx = day_index * key_count
|
||||
day_total = 0
|
||||
for key_offset in range(key_count):
|
||||
value = values[base_idx + key_offset]
|
||||
day_total += int(value) if value else 0
|
||||
|
||||
history.append({
|
||||
"day": day_str,
|
||||
"label": f"{day_str[4:6]}-{day_str[6:8]}",
|
||||
"total_rpd": day_total,
|
||||
})
|
||||
|
||||
return history
|
||||
294
app/core/key_manager.py
Executable file
294
app/core/key_manager.py
Executable file
@ -0,0 +1,294 @@
|
||||
import yaml
|
||||
import time
|
||||
from typing import List, Dict, Optional, Set, Any, Tuple
|
||||
from app.models import APIKey, KeyStatus
|
||||
|
||||
class KeyManager:
|
||||
def __init__(self, config_path: str):
|
||||
self.config_path = config_path
|
||||
self._keys: Dict[str, APIKey] = {}
|
||||
self._keys_by_model: Dict[str, List[APIKey]] = {}
|
||||
self._config_ids: Set[str] = set()
|
||||
self._last_load_time: float = 0
|
||||
self._raw_config: Dict[str, Any] = {}
|
||||
self._config_enabled_map: Dict[str, bool] = {}
|
||||
# Track endpoint enable status: {(config_id, idx): enabled}
|
||||
self._endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
|
||||
# Track endpoint URLs: {(config_id, idx): url}
|
||||
self._endpoint_urls: Dict[Tuple[str, int], str] = {}
|
||||
self.load_keys()
|
||||
|
||||
def load_keys(self):
|
||||
"""Loads keys from the YAML configuration file."""
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
self._raw_config = data
|
||||
self._keys = {}
|
||||
self._keys_by_model = {}
|
||||
self._config_ids = set()
|
||||
self._config_enabled_map = {}
|
||||
|
||||
configs = {}
|
||||
for cfg in data.get('model_configs', []):
|
||||
if 'id' in cfg:
|
||||
configs[cfg['id']] = cfg
|
||||
|
||||
for key_group in data.get('keys', []):
|
||||
config_id = key_group.get('config_id')
|
||||
group_enabled = key_group.get('enabled', True)
|
||||
keys_list = key_group.get('keys', [])
|
||||
|
||||
if config_id:
|
||||
self._config_ids.add(config_id)
|
||||
self._config_enabled_map[config_id] = group_enabled
|
||||
|
||||
if config_id and config_id in configs:
|
||||
cfg = configs[config_id]
|
||||
cfg_copy = cfg.copy()
|
||||
cfg_copy.pop('id', None)
|
||||
|
||||
# Check if this config has multiple endpoints (for local llama.cpp load balancing)
|
||||
api_endpoints = cfg_copy.get('api_endpoints', [])
|
||||
|
||||
if api_endpoints and len(api_endpoints) > 0:
|
||||
# Generate virtual keys for each endpoint
|
||||
for idx, endpoint_item in enumerate(api_endpoints):
|
||||
# Handle both old format (string URL) and new format (dict with url/enabled)
|
||||
if isinstance(endpoint_item, dict):
|
||||
endpoint_url = endpoint_item.get('url', '')
|
||||
endpoint_enabled = endpoint_item.get('enabled', True)
|
||||
else:
|
||||
endpoint_url = endpoint_item
|
||||
endpoint_enabled = True
|
||||
|
||||
# Store endpoint info for dashboard
|
||||
self._endpoint_enabled_map[(config_id, idx)] = endpoint_enabled
|
||||
self._endpoint_urls[(config_id, idx)] = endpoint_url
|
||||
|
||||
final_data = {}
|
||||
final_data.update(cfg_copy)
|
||||
final_data.pop('api_endpoints', None) # Remove the list from individual key
|
||||
final_data['config_id'] = config_id
|
||||
# Endpoint is enabled only if both group and endpoint itself are enabled
|
||||
final_data['enabled'] = group_enabled and endpoint_enabled
|
||||
final_data['api_base'] = endpoint_url # Use individual endpoint
|
||||
final_data['key'] = f"local-{config_id}-{idx}" # Virtual key for local service
|
||||
final_data['endpoint_idx'] = idx # Track endpoint index
|
||||
|
||||
# Virtual key ID that includes endpoint info for tracking
|
||||
endpoint_hash = hash(endpoint_url) & 0xFFFF
|
||||
key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}"
|
||||
final_data['id'] = key_id
|
||||
|
||||
if not group_enabled or not endpoint_enabled:
|
||||
final_data['status'] = KeyStatus.DISABLED
|
||||
|
||||
try:
|
||||
key = APIKey(**final_data)
|
||||
self._keys[key.id] = key
|
||||
|
||||
if key.model_name not in self._keys_by_model:
|
||||
self._keys_by_model[key.model_name] = []
|
||||
self._keys_by_model[key.model_name].append(key)
|
||||
except Exception as e:
|
||||
print(f"Error loading virtual key {key_id} for endpoint {endpoint_url}: {e}")
|
||||
else:
|
||||
# Standard key loading (original logic)
|
||||
for key_data in keys_list:
|
||||
final_data = {}
|
||||
final_data.update(cfg_copy)
|
||||
final_data['config_id'] = config_id
|
||||
final_data['enabled'] = group_enabled
|
||||
|
||||
if key_data.get('owner'):
|
||||
final_data['owner'] = key_data['owner']
|
||||
|
||||
final_data['key'] = key_data['key']
|
||||
|
||||
if not group_enabled:
|
||||
final_data['status'] = KeyStatus.DISABLED
|
||||
|
||||
key_id = f"{config_id}:{hash(key_data['key']) & 0xFFFFFFFF}"
|
||||
final_data['id'] = key_id
|
||||
|
||||
try:
|
||||
key = APIKey(**final_data)
|
||||
self._keys[key.id] = key
|
||||
|
||||
if key.model_name not in self._keys_by_model:
|
||||
self._keys_by_model[key.model_name] = []
|
||||
self._keys_by_model[key.model_name].append(key)
|
||||
except Exception as e:
|
||||
print(f"Error loading key {key_id}: {e}")
|
||||
else:
|
||||
for key_data in keys_list:
|
||||
final_data = {}
|
||||
final_data.update(key_group)
|
||||
final_data.pop('keys', None)
|
||||
final_data.update(key_data)
|
||||
|
||||
if 'id' not in final_data:
|
||||
final_data['id'] = key_data.get('key', 'unknown')[:16]
|
||||
|
||||
if not final_data.get('enabled', True):
|
||||
final_data['status'] = KeyStatus.DISABLED
|
||||
|
||||
try:
|
||||
key = APIKey(**final_data)
|
||||
self._keys[key.id] = key
|
||||
|
||||
if key.model_name not in self._keys_by_model:
|
||||
self._keys_by_model[key.model_name] = []
|
||||
self._keys_by_model[key.model_name].append(key)
|
||||
except Exception as e:
|
||||
print(f"Error loading key {final_data.get('id', 'unknown')}: {e}")
|
||||
|
||||
self._last_load_time = time.time()
|
||||
|
||||
def _save_config(self) -> bool:
|
||||
"""Save current configuration to YAML file."""
|
||||
try:
|
||||
for key_group in self._raw_config.get('keys', []):
|
||||
config_id = key_group.get('config_id')
|
||||
if config_id and config_id in self._config_enabled_map:
|
||||
key_group['enabled'] = self._config_enabled_map[config_id]
|
||||
|
||||
# Save endpoint enable status
|
||||
for (cfg_id, idx), enabled in self._endpoint_enabled_map.items():
|
||||
for cfg in self._raw_config.get('model_configs', []):
|
||||
if cfg.get('id') == cfg_id:
|
||||
endpoints = cfg.get('api_endpoints', [])
|
||||
if idx < len(endpoints):
|
||||
if isinstance(endpoints[idx], dict):
|
||||
endpoints[idx]['enabled'] = enabled
|
||||
else:
|
||||
# Convert string to dict with url and enabled
|
||||
endpoints[idx] = {'url': endpoints[idx], 'enabled': enabled}
|
||||
|
||||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(self._raw_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error saving config: {e}")
|
||||
return False
|
||||
|
||||
def reload_keys(self) -> bool:
|
||||
"""Reload keys from config file. Returns True if successful."""
|
||||
try:
|
||||
self.load_keys()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error reloading keys: {e}")
|
||||
return False
|
||||
|
||||
def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
|
||||
"""Enable or disable all keys for a config_id and save to file."""
|
||||
if config_id not in self._config_ids:
|
||||
return False
|
||||
|
||||
self._config_enabled_map[config_id] = enabled
|
||||
|
||||
updated = False
|
||||
for key in self._keys.values():
|
||||
if hasattr(key, 'config_id') and key.config_id == config_id:
|
||||
key.enabled = enabled
|
||||
if enabled:
|
||||
key.status = KeyStatus.ACTIVE
|
||||
else:
|
||||
key.status = KeyStatus.DISABLED
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
self._save_config()
|
||||
|
||||
return updated
|
||||
|
||||
def set_key_enabled(self, key_id: str, enabled: bool) -> bool:
|
||||
"""Enable or disable a specific key."""
|
||||
if key_id not in self._keys:
|
||||
return False
|
||||
|
||||
key = self._keys[key_id]
|
||||
key.enabled = enabled
|
||||
if enabled:
|
||||
key.status = KeyStatus.ACTIVE
|
||||
else:
|
||||
key.status = KeyStatus.DISABLED
|
||||
|
||||
return True
|
||||
|
||||
def get_config_ids(self) -> Set[str]:
|
||||
"""Get all config IDs."""
|
||||
return self._config_ids.copy()
|
||||
|
||||
def get_keys_by_config(self, config_id: str) -> List[APIKey]:
|
||||
"""Get all keys for a specific config_id."""
|
||||
return [k for k in self._keys.values() if hasattr(k, 'config_id') and k.config_id == config_id]
|
||||
|
||||
def get_key(self, key_id: str) -> Optional[APIKey]:
|
||||
return self._keys.get(key_id)
|
||||
|
||||
def get_candidate_keys(self, model_name: Optional[str] = None) -> List[APIKey]:
|
||||
"""Returns all keys for a given model, regardless of status. If model_name is None, returns all keys."""
|
||||
if not model_name:
|
||||
return self.get_all_keys()
|
||||
return self._keys_by_model.get(model_name, [])
|
||||
|
||||
def get_all_keys(self) -> List[APIKey]:
|
||||
return list(self._keys.values())
|
||||
|
||||
def get_last_load_time(self) -> float:
|
||||
return self._last_load_time
|
||||
|
||||
def get_config_enabled(self, config_id: str) -> bool:
|
||||
"""Get enabled status for a config_id."""
|
||||
return self._config_enabled_map.get(config_id, True)
|
||||
|
||||
def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all endpoint info for a config_id."""
|
||||
endpoints = []
|
||||
for (cfg_id, idx), url in self._endpoint_urls.items():
|
||||
if cfg_id == config_id:
|
||||
enabled = self._endpoint_enabled_map.get((cfg_id, idx), True)
|
||||
endpoint_hash = hash(url) & 0xFFFF
|
||||
key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}"
|
||||
endpoints.append({
|
||||
'idx': idx,
|
||||
'url': url,
|
||||
'enabled': enabled,
|
||||
'key_id': key_id
|
||||
})
|
||||
return sorted(endpoints, key=lambda x: x['idx'])
|
||||
|
||||
def set_endpoint_enabled(self, config_id: str, endpoint_idx: int, enabled: bool) -> bool:
|
||||
"""Enable or disable a specific endpoint and save to file."""
|
||||
if (config_id, endpoint_idx) not in self._endpoint_enabled_map:
|
||||
return False
|
||||
|
||||
self._endpoint_enabled_map[(config_id, endpoint_idx)] = enabled
|
||||
|
||||
# Update the corresponding key's status
|
||||
url = self._endpoint_urls.get((config_id, endpoint_idx), '')
|
||||
endpoint_hash = hash(url) & 0xFFFF
|
||||
key_id = f"{config_id}:endpoint:{endpoint_idx}:{endpoint_hash}"
|
||||
|
||||
if key_id in self._keys:
|
||||
key = self._keys[key_id]
|
||||
key.enabled = enabled
|
||||
if enabled:
|
||||
key.status = KeyStatus.ACTIVE
|
||||
else:
|
||||
key.status = KeyStatus.DISABLED
|
||||
|
||||
self._save_config()
|
||||
return True
|
||||
|
||||
def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
|
||||
"""Get the key_id for a specific endpoint."""
|
||||
url = self._endpoint_urls.get((config_id, endpoint_idx))
|
||||
if url:
|
||||
endpoint_hash = hash(url) & 0xFFFF
|
||||
return f"{config_id}:endpoint:{endpoint_idx}:{endpoint_hash}"
|
||||
return None
|
||||
311
app/core/providers.py
Executable file
311
app/core/providers.py
Executable file
@ -0,0 +1,311 @@
|
||||
import re
|
||||
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||
from app.models import APIKey
|
||||
|
||||
class BaseProvider:
|
||||
def __init__(self, api_key: APIKey):
|
||||
self.api_key = api_key
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_url(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Parses provider response into a standardized format.
|
||||
Returns:
|
||||
{
|
||||
"content": str,
|
||||
"reasoning_content": Optional[str],
|
||||
"usage": {
|
||||
"input_tokens": int,
|
||||
"output_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key.key}"
|
||||
}
|
||||
|
||||
def get_url(self) -> str:
|
||||
base = self.api_key.api_base or "https://api.openai.com/v1"
|
||||
return f"{base.rstrip('/')}/chat/completions"
|
||||
|
||||
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = {"messages": messages, "model": self.api_key.model_name, "stream": False}
|
||||
|
||||
# 1. Start with Key's default params
|
||||
final_params = self.api_key.params.copy()
|
||||
|
||||
# 2. Update with request-level params (filter None)
|
||||
valid_request_params = {k: v for k, v in params.items() if v is not None}
|
||||
final_params.update(valid_request_params)
|
||||
|
||||
# 3. Add to payload
|
||||
data.update(final_params)
|
||||
|
||||
# 4. Handle extra_config from Key (e.g., Qwen extra_body)
|
||||
if self.api_key.extra_config:
|
||||
if "extra_body" in self.api_key.extra_config:
|
||||
data.update(self.api_key.extra_config["extra_body"])
|
||||
|
||||
return data
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
choice = data["choices"][0]
|
||||
text = choice["message"]["content"]
|
||||
reasoning = choice["message"].get("reasoning_content") or \
|
||||
choice.get("reasoning_content")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
return {
|
||||
"content": text,
|
||||
"reasoning_content": reasoning,
|
||||
"usage": {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0)
|
||||
}
|
||||
}
|
||||
except (KeyError, IndexError):
|
||||
return {"content": "", "usage": {}}
|
||||
|
||||
class AzureProvider(OpenAIProvider):
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key.key
|
||||
}
|
||||
|
||||
def get_url(self) -> str:
|
||||
# Expected api_base: https://{resource}.openai.azure.com/openai/deployments/{deployment}
|
||||
# We need to append api-version
|
||||
base = self.api_key.api_base or ""
|
||||
sep = "&" if "?" in base else "?"
|
||||
|
||||
# Get api-version from extra_config or default
|
||||
version = "2025-01-01-preview"
|
||||
if self.api_key.extra_config and "api_version" in self.api_key.extra_config:
|
||||
version = self.api_key.extra_config["api_version"]
|
||||
|
||||
return f"{base}{sep}api-version={version}"
|
||||
|
||||
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
payload = super().get_payload(messages, params)
|
||||
# Azure usually doesn't need 'model' in body if it's in URL (deployment)
|
||||
# But keeping it usually doesn't hurt, unless configured to remove.
|
||||
# For simplicity, we keep it or remove it if we want strict adherence to typical azure usage.
|
||||
# Let's remove it to be safe as deployment is in URL.
|
||||
payload.pop("model", None)
|
||||
return payload
|
||||
|
||||
class GeminiProvider(BaseProvider):
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
def get_url(self) -> str:
|
||||
# api_base example: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
|
||||
base = self.api_key.api_base or f"https://generativelanguage.googleapis.com/v1beta/models/{self.api_key.model_name}:generateContent"
|
||||
sep = "&" if "?" in base else "?"
|
||||
return f"{base}{sep}key={self.api_key.key}"
|
||||
|
||||
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
gemini_contents = []
|
||||
system_instruction = None
|
||||
|
||||
for msg in messages:
|
||||
role = "model" if msg["role"] == "assistant" else msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
if role == "system":
|
||||
text_val = content if isinstance(content, str) else content[0]["text"]
|
||||
system_instruction = {"parts": [{"text": text_val}]}
|
||||
continue
|
||||
|
||||
parts = []
|
||||
content_list = content if isinstance(content, list) else [{"type": "text", "text": content}]
|
||||
|
||||
for item in content_list:
|
||||
if item.get("type") == "text":
|
||||
parts.append({"text": item["text"]})
|
||||
elif item.get("type") == "image_url":
|
||||
url = item.get("image_url", {}).get("url", "")
|
||||
match = re.match(r'data:(.*?);base64,(.*)', url)
|
||||
if match:
|
||||
parts.append({
|
||||
"inline_data": {
|
||||
"mime_type": match.group(1),
|
||||
"data": match.group(2)
|
||||
}
|
||||
})
|
||||
|
||||
if parts:
|
||||
gemini_contents.append({"role": role, "parts": parts})
|
||||
|
||||
generation_config = {}
|
||||
|
||||
# 1. Start with Key's default params (mapped to Gemini format if possible, or assume config.py had correct format)
|
||||
# Note: config.py used "maxOutputTokens" (Gemini style) directly in 'params'.
|
||||
# We should merge self.api_key.params into generation_config
|
||||
generation_config.update(self.api_key.params)
|
||||
|
||||
# 2. Update with request params (mapped from OpenAI style)
|
||||
if "max_tokens" in params:
|
||||
generation_config["maxOutputTokens"] = params["max_tokens"]
|
||||
if "temperature" in params:
|
||||
generation_config["temperature"] = params["temperature"]
|
||||
if "top_p" in params:
|
||||
generation_config["topP"] = params["top_p"]
|
||||
|
||||
# 3. Handle thinking config from Key or Request
|
||||
# Check Key's extra_config first
|
||||
if self.api_key.extra_config and "thinking_config" in self.api_key.extra_config:
|
||||
generation_config["thinkingConfig"] = self.api_key.extra_config["thinking_config"]
|
||||
|
||||
# Override with request param if present
|
||||
if "thinking_config" in params:
|
||||
generation_config["thinkingConfig"] = params["thinking_config"]
|
||||
|
||||
payload = {
|
||||
"contents": gemini_contents,
|
||||
"generationConfig": generation_config
|
||||
}
|
||||
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = system_instruction
|
||||
|
||||
return payload
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
candidate = data["candidates"][0]
|
||||
parts_text = []
|
||||
if "content" in candidate and "parts" in candidate["content"]:
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
parts_text.append(part["text"])
|
||||
|
||||
text = "\n".join(parts_text) if parts_text else ""
|
||||
|
||||
reasoning = None
|
||||
if "thinking_process" in candidate:
|
||||
reasoning = candidate["thinking_process"]
|
||||
|
||||
usage_metadata = data.get("usageMetadata", {})
|
||||
input_tokens = usage_metadata.get("promptTokenCount", 0)
|
||||
output_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"reasoning_content": reasoning,
|
||||
"usage": {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
except (KeyError, IndexError):
|
||||
return {"content": "", "usage": {}}
|
||||
|
||||
class ClaudeProvider(BaseProvider):
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"x-api-key": self.api_key.key
|
||||
}
|
||||
|
||||
# Check extra_config for version
|
||||
if self.api_key.extra_config and "anthropic_version" in self.api_key.extra_config:
|
||||
headers["anthropic-version"] = self.api_key.extra_config["anthropic_version"]
|
||||
|
||||
return headers
|
||||
|
||||
def get_url(self) -> str:
|
||||
base = self.api_key.api_base or "https://api.anthropic.com/v1/messages"
|
||||
return base
|
||||
|
||||
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
claude_msgs = []
|
||||
system_prompt = ""
|
||||
|
||||
for msg in messages:
|
||||
role = "assistant" if msg["role"] == "assistant" else msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else content[0]["text"]
|
||||
continue
|
||||
|
||||
# Simplified content handling for text only for now, can expand if needed
|
||||
claude_msgs.append({"role": role, "content": content})
|
||||
|
||||
# Start with Key's default params
|
||||
payload = {
|
||||
"model": self.api_key.model_name,
|
||||
"messages": claude_msgs,
|
||||
"max_tokens": 1024 # Default fallback
|
||||
}
|
||||
|
||||
# Merge Key params
|
||||
payload.update(self.api_key.params)
|
||||
|
||||
# Merge request params
|
||||
if "max_tokens" in params:
|
||||
payload["max_tokens"] = params["max_tokens"]
|
||||
if "temperature" in params:
|
||||
payload["temperature"] = params["temperature"]
|
||||
|
||||
if system_prompt:
|
||||
payload["system"] = system_prompt
|
||||
|
||||
return payload
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
content_list = data.get("content", [])
|
||||
text = ""
|
||||
for item in content_list:
|
||||
if item.get("type") == "text":
|
||||
text += item.get("text", "")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("output_tokens", 0)
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"usage": {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens
|
||||
}
|
||||
}
|
||||
except Exception:
|
||||
return {"content": "", "usage": {}}
|
||||
|
||||
def get_provider_handler(api_key: APIKey) -> BaseProvider:
|
||||
providers = {
|
||||
"google": GeminiProvider,
|
||||
"gemini": GeminiProvider,
|
||||
"openai": OpenAIProvider,
|
||||
"azure": AzureProvider,
|
||||
"claude": ClaudeProvider,
|
||||
"anthropic": ClaudeProvider,
|
||||
"zhipu": OpenAIProvider # Zhipu is mostly OpenAI compatible
|
||||
}
|
||||
provider_class = providers.get(api_key.provider, OpenAIProvider)
|
||||
return provider_class(api_key)
|
||||
1016
app/core/stats_exporter.py
Executable file
1016
app/core/stats_exporter.py
Executable file
File diff suppressed because it is too large
Load Diff
207
app/core/stats_store.py
Executable file
207
app/core/stats_store.py
Executable file
@ -0,0 +1,207 @@
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StatsStore:
|
||||
DAILY_RETENTION_DAYS = 7
|
||||
|
||||
def __init__(self, db_path: str | Path):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _init_db(self):
|
||||
with self._connect() as conn:
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS latest_summary (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
updated_at INTEGER NOT NULL,
|
||||
total_keys INTEGER NOT NULL,
|
||||
active_keys INTEGER NOT NULL,
|
||||
cooldown_keys INTEGER NOT NULL,
|
||||
disabled_keys INTEGER NOT NULL,
|
||||
total_concurrency INTEGER NOT NULL,
|
||||
total_rpm INTEGER NOT NULL,
|
||||
total_tpm INTEGER NOT NULL,
|
||||
total_rpd INTEGER NOT NULL,
|
||||
total_yesterday_rpd INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS latest_key_stats (
|
||||
key_id TEXT PRIMARY KEY,
|
||||
updated_at INTEGER NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL,
|
||||
config_id TEXT,
|
||||
owner TEXT,
|
||||
current_concurrency INTEGER NOT NULL,
|
||||
current_rpm INTEGER NOT NULL,
|
||||
current_tpm INTEGER NOT NULL,
|
||||
current_rpd INTEGER NOT NULL,
|
||||
cooldown_remaining REAL NOT NULL,
|
||||
rpm_limit INTEGER NOT NULL,
|
||||
tpm_limit INTEGER NOT NULL,
|
||||
rpd_limit INTEGER NOT NULL,
|
||||
max_concurrency INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS daily_usage (
|
||||
day TEXT PRIMARY KEY,
|
||||
label TEXT NOT NULL,
|
||||
total_rpd INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def save_export(self, stats: list[dict[str, Any]], usage_summary: dict[str, Any], daily_usage: list[dict[str, Any]]):
|
||||
now = int(datetime.now(timezone.utc).timestamp())
|
||||
cutoff = (datetime.now(timezone.utc) - timedelta(days=self.DAILY_RETENTION_DAYS)).strftime("%Y%m%d")
|
||||
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO latest_summary (
|
||||
id, updated_at, total_keys, active_keys, cooldown_keys, disabled_keys,
|
||||
total_concurrency, total_rpm, total_tpm, total_rpd, total_yesterday_rpd
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
updated_at = excluded.updated_at,
|
||||
total_keys = excluded.total_keys,
|
||||
active_keys = excluded.active_keys,
|
||||
cooldown_keys = excluded.cooldown_keys,
|
||||
disabled_keys = excluded.disabled_keys,
|
||||
total_concurrency = excluded.total_concurrency,
|
||||
total_rpm = excluded.total_rpm,
|
||||
total_tpm = excluded.total_tpm,
|
||||
total_rpd = excluded.total_rpd,
|
||||
total_yesterday_rpd = excluded.total_yesterday_rpd
|
||||
""",
|
||||
(
|
||||
1,
|
||||
now,
|
||||
usage_summary["total_keys"],
|
||||
usage_summary["active_keys"],
|
||||
usage_summary["cooldown_keys"],
|
||||
usage_summary["disabled_keys"],
|
||||
usage_summary["total_concurrency"],
|
||||
usage_summary["total_rpm"],
|
||||
usage_summary["total_tpm"],
|
||||
usage_summary["total_rpd"],
|
||||
usage_summary.get("total_yesterday_rpd", 0),
|
||||
),
|
||||
)
|
||||
|
||||
conn.execute("DELETE FROM latest_key_stats")
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO latest_key_stats (
|
||||
key_id, updated_at, model_name, provider, status, enabled, config_id, owner,
|
||||
current_concurrency, current_rpm, current_tpm, current_rpd, cooldown_remaining,
|
||||
rpm_limit, tpm_limit, rpd_limit, max_concurrency
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(
|
||||
item["id"],
|
||||
now,
|
||||
item["model_name"],
|
||||
item["provider"],
|
||||
item["status"],
|
||||
1 if item["enabled"] else 0,
|
||||
item.get("config_id"),
|
||||
item.get("owner"),
|
||||
item["usage"]["current_concurrency"],
|
||||
item["usage"]["current_rpm"],
|
||||
item["usage"]["current_tpm"],
|
||||
item["usage"]["current_rpd"],
|
||||
item["cooldown_remaining"],
|
||||
item["limits"]["rpm"],
|
||||
item["limits"]["tpm"],
|
||||
item["limits"]["rpd"],
|
||||
item["limits"]["max_concurrency"],
|
||||
)
|
||||
for item in stats
|
||||
],
|
||||
)
|
||||
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO daily_usage (day, label, total_rpd, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(day) DO UPDATE SET
|
||||
label = excluded.label,
|
||||
total_rpd = MAX(daily_usage.total_rpd, excluded.total_rpd),
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
[
|
||||
(
|
||||
item["day"],
|
||||
item["label"],
|
||||
item["total_rpd"],
|
||||
now,
|
||||
)
|
||||
for item in daily_usage
|
||||
],
|
||||
)
|
||||
conn.execute("DELETE FROM daily_usage WHERE day < ?", (cutoff,))
|
||||
|
||||
def get_daily_usage(self, days: int = 7) -> list[dict[str, Any]]:
|
||||
if days <= 0:
|
||||
return []
|
||||
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT day, label, total_rpd
|
||||
FROM daily_usage
|
||||
ORDER BY day DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(days,),
|
||||
).fetchall()
|
||||
|
||||
rows = list(reversed(rows))
|
||||
return [
|
||||
{
|
||||
"day": row["day"],
|
||||
"label": row["label"],
|
||||
"total_rpd": row["total_rpd"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def get_latest_summary(self) -> dict[str, Any] | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT * FROM latest_summary WHERE id = 1").fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return dict(row)
|
||||
|
||||
def get_latest_key_stats(self) -> list[dict[str, Any]]:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM latest_key_stats
|
||||
ORDER BY provider, model_name, key_id
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def delete_day(self, day: str):
|
||||
with self._connect() as conn:
|
||||
conn.execute("DELETE FROM daily_usage WHERE day = ?", (day,))
|
||||
36
app/models.py
Executable file
36
app/models.py
Executable file
@ -0,0 +1,36 @@
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
class KeyStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
COOLDOWN = "cooldown"
|
||||
DISABLED = "disabled"
|
||||
|
||||
class APIKey(BaseModel):
|
||||
id: str = Field(..., description="Unique identifier for the key")
|
||||
key: str = Field(..., description="The actual API Key string")
|
||||
provider: str = Field(..., description="Provider identifier (e.g., google, zhipu)")
|
||||
api_base: Optional[str] = Field(None, description="Base URL for the API")
|
||||
model_name: str = Field(..., description="The physical model name")
|
||||
params: Dict[str, Any] = Field(default_factory=dict, description="Default model parameters (e.g., temperature)")
|
||||
extra_config: Dict[str, Any] = Field(default_factory=dict, description="Provider specific configuration")
|
||||
enabled: bool = Field(default=True, description="Whether the key is enabled in configuration")
|
||||
status: KeyStatus = Field(default=KeyStatus.ACTIVE, description="Current status of the key")
|
||||
rpm_limit: int = Field(..., description="Requests Per Minute limit")
|
||||
tpm_limit: int = Field(..., description="Tokens Per Minute limit")
|
||||
rpd_limit: int = Field(default=0, description="Requests Per Day limit (0 = unlimited)")
|
||||
max_concurrency: int = Field(..., description="Maximum concurrent requests allowed")
|
||||
config_id: Optional[str] = Field(None, description="Config ID this key belongs to")
|
||||
owner: Optional[str] = Field(None, description="Owner of this key")
|
||||
|
||||
# Runtime stats (Not loaded from config usually, but part of the object in memory)
|
||||
current_concurrency: int = Field(default=0, description="Current concurrent requests")
|
||||
current_rpm: int = Field(default=0, description="Current RPM usage")
|
||||
current_tpm: int = Field(default=0, description="Current TPM usage")
|
||||
current_rpd: int = Field(default=0, description="Current RPD usage")
|
||||
rate_limit_penalty: int = Field(default=0, description="429 demotion order. 0 means never rate limited.")
|
||||
endpoint_idx: Optional[int] = Field(default=None, description="Endpoint index for local llama load balancing")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
10
app/runtime.py
Executable file
10
app/runtime.py
Executable file
@ -0,0 +1,10 @@
|
||||
import os
|
||||
|
||||
from app.core.dispatcher import Dispatcher
|
||||
from app.core.key_manager import KeyManager
|
||||
from app.core.stats_exporter import StatsExporter
|
||||
|
||||
|
||||
key_manager = KeyManager(os.getenv("KEY_CONFIG_PATH", "config/keys.yaml"))
|
||||
dispatcher = Dispatcher(key_manager)
|
||||
stats_exporter = StatsExporter(dispatcher, output_dir="stats", interval=10)
|
||||
18
app/utils/redis_client.py
Executable file
18
app/utils/redis_client.py
Executable file
@ -0,0 +1,18 @@
|
||||
import redis.asyncio as redis
|
||||
import os
|
||||
|
||||
class RedisClient:
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
|
||||
cls._instance = redis.from_url(redis_url, decode_responses=True)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
async def close(cls):
|
||||
if cls._instance:
|
||||
await cls._instance.close()
|
||||
cls._instance = None
|
||||
41
config/keys_example.yaml
Executable file
41
config/keys_example.yaml
Executable file
@ -0,0 +1,41 @@
|
||||
model_configs:
|
||||
- id: nvidia-glm-cfg
|
||||
provider: nvidia
|
||||
model_name: z-ai/glm-5.1
|
||||
api_base: https://integrate.api.nvidia.com/v1
|
||||
rpm_limit: 40
|
||||
tpm_limit: 100000
|
||||
max_concurrency: 8
|
||||
params:
|
||||
temperature: 1
|
||||
top_p: 1
|
||||
max_tokens: 16384
|
||||
stream: false
|
||||
chat_template_kwargs:
|
||||
enable_thinking: false
|
||||
clear_thinking: false
|
||||
- id: llama-local-cfg
|
||||
provider: openai
|
||||
model_name: llama-local
|
||||
api_base: http://192.168.2.101:8848/v1
|
||||
api_endpoints:
|
||||
- url: http://192.168.2.101:8848/v1
|
||||
enabled: true
|
||||
- url: http://192.168.1.51:8848/v1
|
||||
enabled: true
|
||||
- url: http://192.168.1.61:8848/v1
|
||||
enabled: true
|
||||
rpm_limit: 999999
|
||||
tpm_limit: 999999999
|
||||
max_concurrency: 999
|
||||
params:
|
||||
temperature: 1.0
|
||||
max_tokens: 6144
|
||||
chat_template_kwargs:
|
||||
enable_thinking: false
|
||||
keys:
|
||||
- config_id: nvidia-glm-cfg
|
||||
enabled: true
|
||||
keys:
|
||||
- key: nvapi-key-1
|
||||
owner: name
|
||||
41
main.py
Executable file
41
main.py
Executable file
@ -0,0 +1,41 @@
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.api.routes import router
|
||||
from app.runtime import stats_exporter
|
||||
from app.utils.redis_client import RedisClient
|
||||
|
||||
app = FastAPI(title="Smart KeyPool Gateway")
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def dashboard_home():
|
||||
return await stats_exporter.render_dashboard()
|
||||
|
||||
@app.get("/dashboard", response_class=HTMLResponse)
|
||||
async def dashboard_page():
|
||||
return await stats_exporter.render_dashboard()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
redis = RedisClient.get_instance()
|
||||
try:
|
||||
await redis.ping()
|
||||
print("Connected to Redis")
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to Redis: {e}")
|
||||
|
||||
await stats_exporter.start()
|
||||
print("Stats exporter started - dashboard available at / or /dashboard")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
await stats_exporter.stop()
|
||||
await RedisClient.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8888")))
|
||||
6
requirements.txt
Executable file
6
requirements.txt
Executable file
@ -0,0 +1,6 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
redis
|
||||
pydantic
|
||||
pyyaml
|
||||
httpx
|
||||
284
test_keys.py
Executable file
284
test_keys.py
Executable file
@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试所有启用的API密钥是否正常工作(并发版本)
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import httpx
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# 添加app目录到路径
|
||||
sys.path.insert(0, 'd:\\DPI\\keypool')
|
||||
|
||||
from app.core.key_manager import KeyManager
|
||||
from app.core.providers import get_provider_handler
|
||||
from app.models import APIKey
|
||||
|
||||
|
||||
class TestStatus(Enum):
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
key_id: str
|
||||
provider: str
|
||||
owner: Optional[str]
|
||||
status: TestStatus
|
||||
message: str
|
||||
response_time: float = 0.0
|
||||
details: Optional[Dict] = None
|
||||
|
||||
|
||||
async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
|
||||
"""测试单个API密钥"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
owner = getattr(api_key, 'owner', None)
|
||||
|
||||
# 如果密钥被禁用,跳过测试
|
||||
if not api_key.enabled:
|
||||
return TestResult(
|
||||
key_id=api_key.id,
|
||||
provider=api_key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.SKIPPED,
|
||||
message="密钥已禁用 (enabled: false)",
|
||||
response_time=0.0
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_provider_handler(api_key)
|
||||
|
||||
# 准备测试消息
|
||||
test_messages = [
|
||||
{"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
|
||||
]
|
||||
|
||||
# 获取请求参数
|
||||
headers = provider.get_headers()
|
||||
url = provider.get_url()
|
||||
payload = provider.get_payload(test_messages, {})
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
parsed = provider.parse_response(data)
|
||||
content = parsed.get("content", "")
|
||||
|
||||
return TestResult(
|
||||
key_id=api_key.id,
|
||||
provider=api_key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.SUCCESS,
|
||||
message=f"测试成功,响应时间: {response_time:.2f}s",
|
||||
response_time=response_time,
|
||||
details={
|
||||
"content_preview": content[:100] + "..." if len(content) > 100 else content,
|
||||
"usage": parsed.get("usage", {})
|
||||
}
|
||||
)
|
||||
else:
|
||||
error_text = response.text[:500] if response.text else "No error details"
|
||||
return TestResult(
|
||||
key_id=api_key.id,
|
||||
provider=api_key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.FAILED,
|
||||
message=f"HTTP错误 {response.status_code}: {error_text}",
|
||||
response_time=time.time() - start_time
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return TestResult(
|
||||
key_id=api_key.id,
|
||||
provider=api_key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.FAILED,
|
||||
message=f"请求超时 (>{timeout}s)",
|
||||
response_time=time.time() - start_time
|
||||
)
|
||||
except httpx.ConnectError as e:
|
||||
return TestResult(
|
||||
key_id=api_key.id,
|
||||
provider=api_key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.FAILED,
|
||||
message=f"连接错误: {str(e)}",
|
||||
response_time=time.time() - start_time
|
||||
)
|
||||
except Exception as e:
|
||||
return TestResult(
|
||||
key_id=api_key.id,
|
||||
provider=api_key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.FAILED,
|
||||
message=f"异常: {type(e).__name__}: {str(e)}",
|
||||
response_time=time.time() - start_time
|
||||
)
|
||||
|
||||
|
||||
async def test_all_keys_concurrent(config_path: str = "config/keys.yaml", max_concurrency: int = 10):
|
||||
"""并发测试所有启用的密钥"""
|
||||
print("=" * 80)
|
||||
print("API密钥测试工具 (并发模式)")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# 加载密钥管理器
|
||||
try:
|
||||
km = KeyManager(config_path)
|
||||
all_keys = km.get_all_keys()
|
||||
print(f"共加载 {len(all_keys)} 个密钥配置")
|
||||
except Exception as e:
|
||||
print(f"加载配置失败: {e}")
|
||||
return
|
||||
|
||||
# 筛选启用的密钥
|
||||
enabled_keys = [k for k in all_keys if k.enabled]
|
||||
print(f"其中 {len(enabled_keys)} 个密钥已启用")
|
||||
print(f"并发数: {max_concurrency}")
|
||||
print()
|
||||
|
||||
if not enabled_keys:
|
||||
print("没有启用的密钥需要测试")
|
||||
return
|
||||
|
||||
# 使用信号量控制并发
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def test_with_semaphore(key: APIKey) -> TestResult:
|
||||
async with semaphore:
|
||||
return await test_single_key(key)
|
||||
|
||||
# 创建所有测试任务
|
||||
tasks = [test_with_semaphore(key) for key in enabled_keys]
|
||||
|
||||
# 并发执行所有测试
|
||||
print("开始并发测试...")
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理可能的异常结果
|
||||
processed_results: List[TestResult] = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
key = enabled_keys[i]
|
||||
owner = getattr(key, 'owner', None)
|
||||
processed_results.append(TestResult(
|
||||
key_id=key.id,
|
||||
provider=key.provider,
|
||||
owner=owner,
|
||||
status=TestStatus.FAILED,
|
||||
message=f"测试任务异常: {type(result).__name__}: {str(result)}",
|
||||
response_time=0.0
|
||||
))
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
results = processed_results
|
||||
|
||||
# 按提供商分组显示结果
|
||||
results_by_provider: Dict[str, List[TestResult]] = {}
|
||||
for r in results:
|
||||
p = r.provider
|
||||
if p not in results_by_provider:
|
||||
results_by_provider[p] = []
|
||||
results_by_provider[p].append(r)
|
||||
|
||||
# 打印详细结果
|
||||
for provider, provider_results in sorted(results_by_provider.items()):
|
||||
print(f"\n{'='*80}")
|
||||
print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
|
||||
print(f"{'='*80}")
|
||||
|
||||
for r in provider_results:
|
||||
status_icon = "✓" if r.status == TestStatus.SUCCESS else "✗" if r.status == TestStatus.FAILED else "○"
|
||||
owner_str = f" ({r.owner})" if r.owner else ""
|
||||
print(f"\n [{status_icon}] {r.key_id}{owner_str}")
|
||||
print(f" 状态: {r.message}")
|
||||
if r.details and r.status == TestStatus.SUCCESS:
|
||||
print(f" 回复: {r.details['content_preview']}")
|
||||
|
||||
# 打印汇总报告
|
||||
print("\n" + "=" * 80)
|
||||
print("测试汇总报告")
|
||||
print("=" * 80)
|
||||
|
||||
success_count = sum(1 for r in results if r.status == TestStatus.SUCCESS)
|
||||
failed_count = sum(1 for r in results if r.status == TestStatus.FAILED)
|
||||
skipped_count = sum(1 for r in results if r.status == TestStatus.SKIPPED)
|
||||
|
||||
print(f"\n总计: {len(results)} 个密钥")
|
||||
print(f" ✓ 成功: {success_count}")
|
||||
print(f" ✗ 失败: {failed_count}")
|
||||
print(f" ○ 跳过: {skipped_count}")
|
||||
|
||||
# 按提供商统计
|
||||
print("\n按提供商统计:")
|
||||
provider_stats: Dict[str, Dict[str, int]] = {}
|
||||
for r in results:
|
||||
p = r.provider
|
||||
if p not in provider_stats:
|
||||
provider_stats[p] = {"success": 0, "failed": 0, "skipped": 0}
|
||||
if r.status == TestStatus.SUCCESS:
|
||||
provider_stats[p]["success"] += 1
|
||||
elif r.status == TestStatus.FAILED:
|
||||
provider_stats[p]["failed"] += 1
|
||||
else:
|
||||
provider_stats[p]["skipped"] += 1
|
||||
|
||||
for provider, stats in sorted(provider_stats.items()):
|
||||
total = stats["success"] + stats["failed"] + stats["skipped"]
|
||||
success_rate = (stats["success"] / total * 100) if total > 0 else 0
|
||||
print(f" {provider:15s}: {stats['success']:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {stats['failed']} 失败, {stats['skipped']} 跳过")
|
||||
|
||||
# 失败的详细信息
|
||||
failed_results = [r for r in results if r.status == TestStatus.FAILED]
|
||||
if failed_results:
|
||||
print("\n" + "-" * 80)
|
||||
print("失败的密钥详情:")
|
||||
print("-" * 80)
|
||||
for r in failed_results:
|
||||
print(f"\n [{r.provider}] {r.key_id}")
|
||||
if r.owner:
|
||||
print(f" 拥有者: {r.owner}")
|
||||
print(f" 错误: {r.message}")
|
||||
|
||||
# 成功的响应时间统计
|
||||
success_results = [r for r in results if r.status == TestStatus.SUCCESS and r.response_time > 0]
|
||||
if success_results:
|
||||
avg_time = sum(r.response_time for r in success_results) / len(success_results)
|
||||
min_time = min(r.response_time for r in success_results)
|
||||
max_time = max(r.response_time for r in success_results)
|
||||
print("\n" + "-" * 80)
|
||||
print("响应时间统计 (成功请求):")
|
||||
print("-" * 80)
|
||||
print(f" 平均: {avg_time:.2f}s")
|
||||
print(f" 最快: {min_time:.2f}s")
|
||||
print(f" 最慢: {max_time:.2f}s")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("测试完成")
|
||||
print("=" * 80)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
# 切换到脚本所在目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
|
||||
# 运行测试(默认并发数10)
|
||||
asyncio.run(test_all_keys_concurrent(max_concurrency=11))
|
||||
Loading…
Reference in New Issue
Block a user