from fastapi import APIRouter, HTTPException, Request, BackgroundTasks
from app.core.providers import get_provider_handler
from app.runtime import key_manager, dispatcher, stats_exporter
import httpx
import time
import logging
import traceback

logger = logging.getLogger(__name__)

router = APIRouter()

# Body fields consumed by this proxy itself; every other field is forwarded
# to the provider handler as a generation parameter.
_RESERVED_BODY_FIELDS = ("model", "messages", "stream")

# Upper bound on how much of a provider error body is logged / returned.
_MAX_ERROR_DETAIL_CHARS = 500


def _key_label(key) -> str:
    """Return the `id|owner=...|key_prefix=...` label used in log lines.

    NOTE(review): this writes the first 20 characters of the secret key to the
    logs. The format is kept unchanged here; consider shortening or hashing it.
    """
    return f"{key.id}|owner={key.owner or 'N/A'}|key_prefix={key.key[:20] if key.key else 'N/A'}"


def _toggle_response(success: bool, ok_message: str, missing_message: str) -> dict:
    """Build the success payload of a toggle endpoint, or raise HTTP 404."""
    if not success:
        raise HTTPException(status_code=404, detail=missing_message)
    return {"success": True, "message": ok_message}


def _extract_error_message(resp) -> str:
    """Best-effort extraction of the error message from a provider response.

    Reads `error.message`, then `message`, from a JSON object body. Falls back
    to the raw response text when the body is not JSON or has another shape.
    """
    try:
        error_data = resp.json()
        if not isinstance(error_data, dict):
            return ""
        return error_data.get("error", {}).get("message", "") or error_data.get("message", "")
    except Exception:
        # Not JSON, or "error" is not an object: the raw text is the best we have.
        return resp.text


@router.get("/v1/stats")
async def get_stats():
    """
    Returns usage statistics for all API keys.
    """
    stats = await dispatcher.get_all_key_stats()
    return {"data": stats}


@router.get("/v1/usage")
async def get_usage():
    """
    Returns overall usage summary.
    """
    summary = await dispatcher.get_usage_summary()
    return {"data": summary}


@router.post("/v1/stats/reset/rpd")
async def reset_today_rpd():
    """
    Reset today's RPD counters in Redis and refresh dashboard output.
    """
    # Capture the day label first so the response names the day that was reset.
    day = dispatcher._get_day_str()
    deleted_keys = await dispatcher.reset_today_rpd_usage()
    stats_exporter.clear_daily_usage_day(day)
    await stats_exporter.export_once()
    return {
        "success": True,
        "day": day,
        "deleted_keys": deleted_keys,
        "message": f"Reset today's RPD counters for {deleted_keys} keys",
    }


@router.post("/v1/keys/config/{config_id}/enable")
async def enable_config(config_id: str):
    """
    Enable all keys for a specific config_id (auto-saves to config file).
    """
    return _toggle_response(
        key_manager.set_config_enabled(config_id, True),
        f"Config {config_id} enabled and saved",
        f"Config {config_id} not found",
    )


@router.post("/v1/keys/config/{config_id}/disable")
async def disable_config(config_id: str):
    """
    Disable all keys for a specific config_id (auto-saves to config file).
    """
    return _toggle_response(
        key_manager.set_config_enabled(config_id, False),
        f"Config {config_id} disabled and saved",
        f"Config {config_id} not found",
    )


@router.post("/v1/keys/{key_id}/enable")
async def enable_key(key_id: str):
    """
    Enable a specific key.
    """
    return _toggle_response(
        key_manager.set_key_enabled(key_id, True),
        f"Key {key_id} enabled",
        f"Key {key_id} not found",
    )


@router.post("/v1/keys/{key_id}/disable")
async def disable_key(key_id: str):
    """
    Disable a specific key.
    """
    return _toggle_response(
        key_manager.set_key_enabled(key_id, False),
        f"Key {key_id} disabled",
        f"Key {key_id} not found",
    )


@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/enable")
async def enable_endpoint(config_id: str, endpoint_idx: int):
    """
    Enable a specific endpoint for a config (auto-saves to config file).
    """
    return _toggle_response(
        key_manager.set_endpoint_enabled(config_id, endpoint_idx, True),
        f"Endpoint {endpoint_idx} of {config_id} enabled and saved",
        f"Endpoint {endpoint_idx} of {config_id} not found",
    )


@router.post("/v1/keys/config/{config_id}/endpoint/{endpoint_idx}/disable")
async def disable_endpoint(config_id: str, endpoint_idx: int):
    """
    Disable a specific endpoint for a config (auto-saves to config file).
    """
    return _toggle_response(
        key_manager.set_endpoint_enabled(config_id, endpoint_idx, False),
        f"Endpoint {endpoint_idx} of {config_id} disabled and saved",
        f"Endpoint {endpoint_idx} of {config_id} not found",
    )


@router.get("/v1/configs")
async def get_configs():
    """
    Returns one summary entry per config ID that has at least one key.

    Model name, provider and enabled flag are taken from the first key of the
    config.
    """
    configs = []
    for config_id in key_manager.get_config_ids():
        keys = key_manager.get_keys_by_config(config_id)
        if not keys:
            continue
        first = keys[0]
        configs.append({
            "config_id": config_id,
            "model_name": first.model_name,
            "provider": first.provider,
            "key_count": len(keys),
            "enabled": first.enabled
        })
    return {"data": configs}


@router.post("/v1/chat/completions")
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
    """
    Unified entry point.
    Expects a JSON object body with an optional "model" field.

    Flow: pick a key, take a lease on it, forward the request to the provider,
    and translate the provider reply into a chat-completion shaped response.
    The lease is always released, whatever the outcome.

    Raises:
        HTTPException 400: the body is not valid JSON or not a JSON object.
        HTTPException 503: no key is available or the lease was refused.
        HTTPException <provider status>: the provider answered with a non-200.
        HTTPException 500: any other failure while handling the request.
    """
    try:
        body = await request.json()
    except Exception as exc:
        # `except Exception` (not a bare except) so task cancellation is not
        # converted into a 400 response.
        logger.error(f"[Invalid JSON] detail={traceback.format_exc()}...")
        raise HTTPException(status_code=400, detail="Invalid JSON") from exc

    if not isinstance(body, dict):
        # Valid JSON but not an object (e.g. a list): body.get() would crash.
        logger.error(f"[Invalid JSON] detail=body is {type(body).__name__}, expected object...")
        raise HTTPException(status_code=400, detail="Invalid JSON: request body must be an object")

    model_name = body.get("model")

    key, error_reason = await dispatcher.select_key(model_name)
    if not key:
        logger.error(f"[Key Not Found] model={model_name} | error={error_reason}...")
        raise HTTPException(status_code=503, detail=error_reason)

    if not model_name:
        # No model requested: report the model of whichever key was selected.
        model_name = key.model_name

    acquired = await dispatcher.acquire_lease(key)
    if not acquired:
        logger.error(f"[System busy (Concurrency Limit)] model={model_name} | provider={key.provider} | key={_key_label(key)}...")
        raise HTTPException(status_code=503, detail="System busy (Concurrency Limit)")

    try:
        handler = get_provider_handler(key)
        url = handler.get_url()
        headers = handler.get_headers()

        params = {k: v for k, v in body.items() if k not in _RESERVED_BODY_FIELDS}
        payload = handler.get_payload(body.get("messages", []), params)

        timeout = httpx.Timeout(300.0, connect=10.0)

        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, headers=headers, json=payload)

            if resp.status_code != 200:
                error_detail = resp.text[:_MAX_ERROR_DETAIL_CHARS]
                logger.error(
                    f"[Request Failed] model={model_name} | key={_key_label(key)} | "
                    f"api_base={key.api_base} | "
                    f"status={resp.status_code} | response={error_detail}"
                )

                if resp.status_code == 429:
                    # Tell the dispatcher so this key is demoted in selection.
                    await dispatcher.report_failure(
                        key,
                        is_rate_limit=True,
                        error_message=_extract_error_message(resp),
                    )

                raise HTTPException(status_code=resp.status_code, detail=f"Provider Error: {error_detail}")

            data = resp.json()
            parsed_result = handler.parse_response(data)

            # Token accounting happens after the response has been sent.
            total_tokens = parsed_result["usage"].get("total_tokens", 0)
            background_tasks.add_task(dispatcher.record_usage, key, total_tokens)

            created = int(time.time())
            response_data = {
                "id": f"chatcmpl-{created}",
                "object": "chat.completion",
                "created": created,
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": parsed_result["content"],
                        "reasoning_content": parsed_result.get("reasoning_content")
                    },
                    "finish_reason": "stop"
                }],
                "usage": parsed_result["usage"]
            }

            return response_data

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(
            f"[Internal Error] model={model_name} | "
            f"key={_key_label(key)} | api_base={getattr(key, 'api_base', 'N/A')} | error={str(e)}"
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        await dispatcher.release_lease(key)
class Dispatcher:
    """Select API keys and keep their runtime counters in Redis.

    Redis keys used (all prefixed `keypool:`):
      * `concurrency:{key_id}`         - in-flight requests
      * `priority:429:{key_id}`        - sequence number of the key's last 429
      * `usage:rpm:{key_id}:{minute}`  - requests per minute bucket
      * `usage:tpm:{key_id}:{minute}`  - tokens per minute bucket
      * `usage:rpd:{key_id}:{YYYYMMDD}`- requests per UTC day

    RPM/TPM are estimated with a sliding window: the current minute's bucket
    plus the previous bucket weighted by the share of the minute not yet
    elapsed.
    """

    # Daily counters are kept for 90 days so history can be read back.
    DAILY_RPD_TTL_SECONDS = 90 * 86400
    # The previous minute's bucket is read during the whole following minute,
    # so a bucket must outlive two full minutes after its last increment.
    MINUTE_COUNTER_TTL_SECONDS = 120
    RATE_LIMIT_PENALTY_SEQ_KEY = "keypool:priority:429:seq"
    # Lower value = preferred. Providers not listed rank after all listed ones.
    PROVIDER_PRIORITIES = {
        "modelscope": 0,
        "nvidia": 0,
        "siliconflow": 0,
        "llama-local": 0,
    }

    def __init__(self, key_manager: KeyManager):
        self.key_manager = key_manager
        self.redis = RedisClient.get_instance()

    @staticmethod
    def _get_day_str(offset_days: int = 0) -> str:
        """Return the UTC day `offset_days` from today as `YYYYMMDD`."""
        current_day = datetime.now(timezone.utc) + timedelta(days=offset_days)
        return current_day.strftime("%Y%m%d")

    @staticmethod
    def _get_rate_limit_penalty_key(key_id: str) -> str:
        """Return the Redis key that stores the 429 penalty of `key_id`."""
        return f"keypool:priority:429:{key_id}"

    @classmethod
    def _get_provider_priority(cls, provider: str) -> int:
        """Return the static rank of `provider` (case-insensitive lookup)."""
        return cls.PROVIDER_PRIORITIES.get(provider.lower(), len(cls.PROVIDER_PRIORITIES))

    @staticmethod
    def _to_int(value) -> int:
        """Convert a Redis reply to int, treating missing/empty values as 0."""
        return int(value) if value else 0

    @staticmethod
    def _minute_window(now: float) -> tuple[int, float]:
        """Return `(current_minute_bucket, weight_of_previous_bucket)`.

        The weight is ~1.0 at second 0 of the minute and ~0 at second 59.
        """
        return int(now / 60), (60 - now % 60) / 60

    def _counter_keys(self, key_id: str, current_minute: int, current_day: str) -> List[str]:
        """Return the 7 Redis keys describing one API key, in a fixed order.

        Order: concurrency, 429 penalty, rpm current, rpm previous,
        tpm current, tpm previous, rpd today.
        """
        prev_minute = current_minute - 1
        return [
            f"keypool:concurrency:{key_id}",
            self._get_rate_limit_penalty_key(key_id),
            f"keypool:usage:rpm:{key_id}:{current_minute}",
            f"keypool:usage:rpm:{key_id}:{prev_minute}",
            f"keypool:usage:tpm:{key_id}:{current_minute}",
            f"keypool:usage:tpm:{key_id}:{prev_minute}",
            f"keypool:usage:rpd:{key_id}:{current_day}",
        ]

    async def _delete_matching_keys(self, pattern: str, batch_size: int = 200) -> int:
        """Delete every Redis key matching `pattern`; return how many were deleted.

        Uses `scan_iter` when the client has it. Otherwise falls back to
        filtering the client's `values` mapping with `fnmatch`.

        Raises:
            RuntimeError: the client offers neither `scan_iter` nor `values`.
        """
        if hasattr(self.redis, "scan_iter"):
            batch = []
            deleted = 0
            async for key in self.redis.scan_iter(match=pattern, count=batch_size):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += await self.redis.delete(*batch)
                    batch.clear()
            if batch:
                deleted += await self.redis.delete(*batch)
            return deleted

        if not hasattr(self.redis, "values"):
            raise RuntimeError("Redis client does not support scan_iter")

        keys = [key for key in self.redis.values if fnmatch.fnmatch(key, pattern)]
        return await self.redis.delete(*keys) if keys else 0

    async def reset_today_rpd_usage(self) -> int:
        """Delete today's RPD counters for all keys; return the number deleted."""
        day = self._get_day_str()
        return await self._delete_matching_keys(f"keypool:usage:rpd:*:{day}")

    async def get_key_stats(self, key: APIKey):
        """Fetch concurrency, 429 penalty and today's RPD for one key.

        The values are written onto `key` and the same object is returned.
        """
        pipe = self.redis.pipeline()
        pipe.get(f"keypool:concurrency:{key.id}")
        pipe.get(self._get_rate_limit_penalty_key(key.id))
        pipe.get(f"keypool:usage:rpd:{key.id}:{self._get_day_str()}")
        concurrency, penalty, rpd = await pipe.execute()

        key.current_concurrency = self._to_int(concurrency)
        key.rate_limit_penalty = self._to_int(penalty)
        key.current_rpd = self._to_int(rpd)
        return key

    async def select_key(self, model_name: Optional[str] = None, timeout: float = 30.0) -> tuple[Optional[APIKey], str]:
        """
        Selects the best available key for the given model.
        If model_name is None, selects from ALL available keys.

        When no key is available, polls every 0.5 s until one is or until
        `timeout` seconds have passed.

        Returns: (key, error_reason)

        Note: when model_name is given, selection is forced: keys whose status
        is not ACTIVE (e.g. DISABLED) are still eligible.
        """
        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + timeout
        force_select = bool(model_name)

        while True:
            key, error = await self._try_select_key_once(model_name, force_select=force_select)
            if key:
                return key, ""

            if time.monotonic() > deadline:
                return None, f"Timeout waiting for available key: {error}"

            await asyncio.sleep(0.5)

    async def _try_select_key_once(self, model_name: Optional[str] = None, force_select: bool = False) -> tuple[Optional[APIKey], str]:
        """
        Try to select a key once, without waiting.

        Strategy:
        1. Drop keys that are inactive (unless force_select) or at a hard limit
           (concurrency, RPM, TPM, RPD when rpd_limit > 0).
        2. Prefer keys that never hit 429.
        3. Among those, order by PROVIDER_PRIORITIES (lower first). All listed
           providers currently share rank 0, so only unlisted ones rank lower.
        4. Keys that did hit 429 are ordered by their penalty sequence number,
           i.e. the least recently rate-limited comes first.
        5. Load (highest usage/limit ratio) is the final tie breaker.

        Side effect: the fetched counters are written onto each candidate key.

        Args:
            model_name: Optional model name to filter by
            force_select: If True, bypass the status check

        Returns: (key, error_reason)
        """
        candidates = self.key_manager.get_candidate_keys(model_name)
        if not candidates:
            if model_name:
                return None, "No keys configured for this model"
            return None, "No keys configured in the pool"

        now = time.time()
        current_minute, prev_weight = self._minute_window(now)
        current_day = self._get_day_str()

        redis_keys = []
        for key in candidates:
            redis_keys.extend(self._counter_keys(key.id, current_minute, current_day))

        # One round-trip for every counter of every candidate.
        values = await self.redis.mget(redis_keys)

        available_keys = []
        rejection_reasons = []

        for i, key in enumerate(candidates):
            conc, penalty, r_curr, r_prev, t_curr, t_prev, rpd = (
                self._to_int(v) for v in values[i * 7:(i + 1) * 7]
            )

            current_rpm = int(r_curr + r_prev * prev_weight)
            current_tpm = int(t_curr + t_prev * prev_weight)

            key.current_concurrency = conc
            key.rate_limit_penalty = penalty
            key.current_rpm = current_rpm
            key.current_tpm = current_tpm
            key.current_rpd = rpd

            if key.status != KeyStatus.ACTIVE and not force_select:
                rejection_reasons.append(f"{key.id}({key.owner}): Inactive (status={key.status})")
                continue

            if key.current_concurrency >= key.max_concurrency:
                rejection_reasons.append(f"{key.id}({key.owner}): Concurrency limit ({key.current_concurrency}/{key.max_concurrency})")
                continue
            if current_rpm >= key.rpm_limit:
                rejection_reasons.append(f"{key.id}({key.owner}): RPM limit ({current_rpm}/{key.rpm_limit})")
                continue
            if current_tpm >= key.tpm_limit:
                rejection_reasons.append(f"{key.id}({key.owner}): TPM limit ({current_tpm}/{key.tpm_limit})")
                continue
            # rpd_limit <= 0 means "no daily limit".
            if key.rpd_limit > 0 and rpd >= key.rpd_limit:
                rejection_reasons.append(f"{key.id}({key.owner}): RPD limit ({rpd}/{key.rpd_limit})")
                continue

            available_keys.append(key)

        if not available_keys:
            reason = "; ".join(rejection_reasons) if rejection_reasons else "No available keys found"
            return None, f"All keys unavailable: {reason}"

        def calculate_load_score(k: APIKey) -> float:
            # The most saturated dimension decides how loaded the key is.
            conc_ratio = k.current_concurrency / k.max_concurrency if k.max_concurrency > 0 else 0
            rpm_ratio = k.current_rpm / k.rpm_limit if k.rpm_limit > 0 else 0
            tpm_ratio = k.current_tpm / k.tpm_limit if k.tpm_limit > 0 else 0
            rpd_ratio = k.current_rpd / k.rpd_limit if k.rpd_limit > 0 else 0
            return max(conc_ratio, rpm_ratio, tpm_ratio, rpd_ratio)

        def selection_key(k: APIKey) -> tuple[int, int, int, float]:
            provider_priority = self._get_provider_priority(k.provider)
            if k.rate_limit_penalty > 0:
                # Penalised tier: queue order first, provider rank second.
                return (1, k.rate_limit_penalty, provider_priority, calculate_load_score(k))
            return (0, provider_priority, 0, calculate_load_score(k))

        return min(available_keys, key=selection_key), ""

    async def acquire_lease(self, key: APIKey) -> bool:
        """
        Increment concurrency, RPM and RPD for the key and re-check the limits.

        The counters are incremented first and checked afterwards, so two
        callers racing for the last slot cannot both succeed. When any limit
        is exceeded, the three increments are undone and False is returned.
        TPM is not checked here.
        """
        now = time.time()
        current_minute, prev_weight = self._minute_window(now)
        current_day = self._get_day_str()

        concurrency_key = f"keypool:concurrency:{key.id}"
        rpm_key = f"keypool:usage:rpm:{key.id}:{current_minute}"
        rpm_prev_key = f"keypool:usage:rpm:{key.id}:{current_minute - 1}"
        rpd_key = f"keypool:usage:rpd:{key.id}:{current_day}"

        pipe = self.redis.pipeline()
        pipe.incr(concurrency_key)
        pipe.incr(rpm_key)
        pipe.expire(rpm_key, self.MINUTE_COUNTER_TTL_SECONDS)
        pipe.get(rpm_prev_key)
        pipe.incr(rpd_key)
        pipe.expire(rpd_key, self.DAILY_RPD_TTL_SECONDS)
        results = await pipe.execute()

        curr_conc = results[0]
        curr_rpm = results[1]
        prev_rpm = self._to_int(results[3])
        curr_rpd = results[4]

        estimated_rpm = curr_rpm + prev_rpm * prev_weight

        over_limit = (
            curr_conc > key.max_concurrency
            or estimated_rpm > key.rpm_limit
            or (key.rpd_limit > 0 and curr_rpd > key.rpd_limit)
        )
        if over_limit:
            # Undo exactly the three increments made above.
            for redis_key in (concurrency_key, rpm_key, rpd_key):
                await self.redis.decr(redis_key)
            return False

        return True

    async def release_lease(self, key: APIKey):
        """Decrement the key's concurrency counter."""
        await self.redis.decr(f"keypool:concurrency:{key.id}")

    async def record_usage(self, key: APIKey, tokens: int):
        """Add `tokens` to the key's current TPM bucket; no-op when tokens <= 0."""
        if tokens <= 0:
            return

        current_minute = int(time.time() / 60)
        tpm_key = f"keypool:usage:tpm:{key.id}:{current_minute}"

        pipe = self.redis.pipeline()
        pipe.incrby(tpm_key, tokens)
        pipe.expire(tpm_key, self.MINUTE_COUNTER_TTL_SECONDS)
        await pipe.execute()

    async def report_failure(self, key: APIKey, is_rate_limit: bool = False, error_message: str = ""):
        """
        Report a failure. A 429 moves the key to the back of the priority queue.

        `error_message` is accepted for interface compatibility and is not used.
        Failures that are not rate limits are ignored.
        """
        if not is_rate_limit:
            return
        # A global, ever-increasing sequence: the latest offender sorts last.
        seq = await self.redis.incr(self.RATE_LIMIT_PENALTY_SEQ_KEY)
        await self.redis.set(self._get_rate_limit_penalty_key(key.id), seq)

    async def get_all_key_stats(self) -> List[Dict[str, Any]]:
        """
        Return one stats dict per key: identity, limits and current usage.

        An ACTIVE key with a 429 penalty is reported with status COOLDOWN.
        `cooldown_remaining` is always 0.
        """
        keys = self.key_manager.get_all_keys()
        if not keys:
            return []

        now = time.time()
        current_minute, prev_weight = self._minute_window(now)
        current_day = self._get_day_str()

        redis_keys = []
        for key in keys:
            redis_keys.extend(self._counter_keys(key.id, current_minute, current_day))

        # Fetch all values in a single round-trip.
        values = await self.redis.mget(redis_keys)

        stats = []
        for i, key in enumerate(keys):
            conc, penalty, rpm_curr, rpm_prev, tpm_curr, tpm_prev, rpd_curr = (
                self._to_int(v) for v in values[i * 7:(i + 1) * 7]
            )

            status = key.status
            if key.status == KeyStatus.ACTIVE and penalty > 0:
                status = KeyStatus.COOLDOWN

            stats.append({
                "id": key.id,
                "model_name": key.model_name,
                "provider": key.provider,
                "status": status,
                "enabled": key.enabled,
                "config_id": getattr(key, 'config_id', None),
                "owner": getattr(key, 'owner', None),
                "endpoint_idx": getattr(key, 'endpoint_idx', None),
                "api_base": getattr(key, 'api_base', None),
                "limits": {
                    "rpm": key.rpm_limit,
                    "tpm": key.tpm_limit,
                    "rpd": key.rpd_limit,
                    "max_concurrency": key.max_concurrency
                },
                "usage": {
                    "current_concurrency": conc,
                    "current_rpm": int(rpm_curr + rpm_prev * prev_weight),
                    "current_tpm": int(tpm_curr + tpm_prev * prev_weight),
                    "current_rpd": rpd_curr
                },
                "cooldown_remaining": 0
            })

        return stats

    async def get_usage_summary(self) -> Dict[str, Any]:
        """
        Return usage totals across all keys.

        Includes total RPM, TPM, RPD (today and yesterday), concurrency and the
        number of active, cooldown and disabled keys.
        """
        summary = {
            "total_rpm": 0,
            "total_tpm": 0,
            "total_rpd": 0,
            "total_yesterday_rpd": 0,
            "total_concurrency": 0,
            "active_keys": 0,
            "cooldown_keys": 0,
            "disabled_keys": 0,
            "total_keys": 0
        }

        keys = self.key_manager.get_all_keys()
        if not keys:
            return summary

        now = time.time()
        current_minute, prev_weight = self._minute_window(now)
        current_day = self._get_day_str()
        yesterday_day = self._get_day_str(-1)

        redis_keys = []
        for key in keys:
            redis_keys.extend(self._counter_keys(key.id, current_minute, current_day))
            redis_keys.append(f"keypool:usage:rpd:{key.id}:{yesterday_day}")

        values = await self.redis.mget(redis_keys)

        for i, key in enumerate(keys):
            conc, penalty, rpm_curr, rpm_prev, tpm_curr, tpm_prev, rpd_curr, rpd_yesterday = (
                self._to_int(v) for v in values[i * 8:(i + 1) * 8]
            )

            summary["total_rpm"] += int(rpm_curr + rpm_prev * prev_weight)
            summary["total_tpm"] += int(tpm_curr + tpm_prev * prev_weight)
            summary["total_rpd"] += rpd_curr
            summary["total_yesterday_rpd"] += rpd_yesterday
            summary["total_concurrency"] += conc

            if key.status == KeyStatus.ACTIVE and penalty == 0:
                summary["active_keys"] += 1
            elif key.status == KeyStatus.DISABLED:
                summary["disabled_keys"] += 1
            elif penalty > 0:
                summary["cooldown_keys"] += 1
            else:
                summary["active_keys"] += 1

        summary["total_keys"] = len(keys)
        return summary

    async def get_daily_usage_history(self, days: int = 7) -> List[Dict[str, Any]]:
        """
        Aggregate total RPD usage for the most recent N days, oldest first.
        Uses the same UTC day buckets as runtime RPD limiting.

        Returns an empty list when there are no keys or `days` <= 0.
        """
        keys = self.key_manager.get_all_keys()
        if not keys or days <= 0:
            return []

        day_labels = [self._get_day_str(-offset) for offset in range(days - 1, -1, -1)]

        redis_keys = [
            f"keypool:usage:rpd:{key.id}:{day_str}"
            for day_str in day_labels
            for key in keys
        ]
        values = await self.redis.mget(redis_keys)

        history = []
        key_count = len(keys)
        for day_index, day_str in enumerate(day_labels):
            day_values = values[day_index * key_count:(day_index + 1) * key_count]
            history.append({
                "day": day_str,
                "label": f"{day_str[4:6]}-{day_str[6:8]}",
                "total_rpd": sum(self._to_int(v) for v in day_values),
            })

        return history
cfg: + configs[cfg['id']] = cfg + + for key_group in data.get('keys', []): + config_id = key_group.get('config_id') + group_enabled = key_group.get('enabled', True) + keys_list = key_group.get('keys', []) + + if config_id: + self._config_ids.add(config_id) + self._config_enabled_map[config_id] = group_enabled + + if config_id and config_id in configs: + cfg = configs[config_id] + cfg_copy = cfg.copy() + cfg_copy.pop('id', None) + + # Check if this config has multiple endpoints (for local llama.cpp load balancing) + api_endpoints = cfg_copy.get('api_endpoints', []) + + if api_endpoints and len(api_endpoints) > 0: + # Generate virtual keys for each endpoint + for idx, endpoint_item in enumerate(api_endpoints): + # Handle both old format (string URL) and new format (dict with url/enabled) + if isinstance(endpoint_item, dict): + endpoint_url = endpoint_item.get('url', '') + endpoint_enabled = endpoint_item.get('enabled', True) + else: + endpoint_url = endpoint_item + endpoint_enabled = True + + # Store endpoint info for dashboard + self._endpoint_enabled_map[(config_id, idx)] = endpoint_enabled + self._endpoint_urls[(config_id, idx)] = endpoint_url + + final_data = {} + final_data.update(cfg_copy) + final_data.pop('api_endpoints', None) # Remove the list from individual key + final_data['config_id'] = config_id + # Endpoint is enabled only if both group and endpoint itself are enabled + final_data['enabled'] = group_enabled and endpoint_enabled + final_data['api_base'] = endpoint_url # Use individual endpoint + final_data['key'] = f"local-{config_id}-{idx}" # Virtual key for local service + final_data['endpoint_idx'] = idx # Track endpoint index + + # Virtual key ID that includes endpoint info for tracking + endpoint_hash = hash(endpoint_url) & 0xFFFF + key_id = f"{config_id}:endpoint:{idx}:{endpoint_hash}" + final_data['id'] = key_id + + if not group_enabled or not endpoint_enabled: + final_data['status'] = KeyStatus.DISABLED + + try: + key = APIKey(**final_data) + 
self._keys[key.id] = key + + if key.model_name not in self._keys_by_model: + self._keys_by_model[key.model_name] = [] + self._keys_by_model[key.model_name].append(key) + except Exception as e: + print(f"Error loading virtual key {key_id} for endpoint {endpoint_url}: {e}") + else: + # Standard key loading (original logic) + for key_data in keys_list: + final_data = {} + final_data.update(cfg_copy) + final_data['config_id'] = config_id + final_data['enabled'] = group_enabled + + if key_data.get('owner'): + final_data['owner'] = key_data['owner'] + + final_data['key'] = key_data['key'] + + if not group_enabled: + final_data['status'] = KeyStatus.DISABLED + + key_id = f"{config_id}:{hash(key_data['key']) & 0xFFFFFFFF}" + final_data['id'] = key_id + + try: + key = APIKey(**final_data) + self._keys[key.id] = key + + if key.model_name not in self._keys_by_model: + self._keys_by_model[key.model_name] = [] + self._keys_by_model[key.model_name].append(key) + except Exception as e: + print(f"Error loading key {key_id}: {e}") + else: + for key_data in keys_list: + final_data = {} + final_data.update(key_group) + final_data.pop('keys', None) + final_data.update(key_data) + + if 'id' not in final_data: + final_data['id'] = key_data.get('key', 'unknown')[:16] + + if not final_data.get('enabled', True): + final_data['status'] = KeyStatus.DISABLED + + try: + key = APIKey(**final_data) + self._keys[key.id] = key + + if key.model_name not in self._keys_by_model: + self._keys_by_model[key.model_name] = [] + self._keys_by_model[key.model_name].append(key) + except Exception as e: + print(f"Error loading key {final_data.get('id', 'unknown')}: {e}") + + self._last_load_time = time.time() + + def _save_config(self) -> bool: + """Save current configuration to YAML file.""" + try: + for key_group in self._raw_config.get('keys', []): + config_id = key_group.get('config_id') + if config_id and config_id in self._config_enabled_map: + key_group['enabled'] = self._config_enabled_map[config_id] 
def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
    """Enable or disable every key of ``config_id`` and persist the change.

    Args:
        config_id: Identifier of the config group to toggle.
        enabled: Desired state for the group and each of its keys.

    Returns:
        ``False`` when ``config_id`` is not a known config. ``True`` once the
        new state has been recorded, including when the config currently
        has no loaded keys.
    """
    if config_id not in self._config_ids:
        return False

    self._config_enabled_map[config_id] = enabled

    for key in self._keys.values():
        if getattr(key, 'config_id', None) != config_id:
            continue
        key.enabled = enabled
        key.status = KeyStatus.ACTIVE if enabled else KeyStatus.DISABLED

    # Save even when no key matched: the enabled map was already changed
    # above, and skipping the save would leave memory and the config file
    # out of sync while reporting a known config as "not found".
    # _save_config() reports its own failures; its result is not propagated,
    # so a write error is not mistaken for an unknown config_id.
    self._save_config()
    return True
def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
    """List the endpoints registered for ``config_id``, ordered by index.

    Each entry carries ``idx``, ``url``, ``enabled`` (defaults to ``True``
    when no explicit state is stored) and the derived ``key_id``.

    NOTE(review): the ``key_id`` suffix comes from the built-in ``hash()`` of
    the URL string. Python salts ``str`` hashes per process unless
    PYTHONHASHSEED is fixed, so these IDs may differ between runs; confirm
    whether anything persists them.
    """
    matching = [
        {
            'idx': position,
            'url': address,
            'enabled': self._endpoint_enabled_map.get((owner, position), True),
            'key_id': f"{config_id}:endpoint:{position}:{hash(address) & 0xFFFF}",
        }
        for (owner, position), address in self._endpoint_urls.items()
        if owner == config_id
    ]
    matching.sort(key=lambda entry: entry['idx'])
    return matching
def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
    """Return the derived key ID of one endpoint.

    Returns ``None`` when no URL is stored for ``(config_id, endpoint_idx)``
    or when the stored URL is empty.
    """
    address = self._endpoint_urls.get((config_id, endpoint_idx))
    if not address:
        return None
    suffix = hash(address) & 0xFFFF
    return f"{config_id}:endpoint:{endpoint_idx}:{suffix}"
def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert an OpenAI-style chat completion into the standardized result.

    Args:
        data: Decoded JSON body of a chat-completions response.

    Returns:
        A dict with ``content`` (always a string), ``reasoning_content``
        (string or ``None``) and ``usage`` token counts. When the body has
        no usable first choice, ``content`` is empty, ``reasoning_content``
        is ``None`` and ``usage`` is an empty dict.
    """
    try:
        choice = data["choices"][0]
        message = choice["message"]
        text = message["content"]
        reasoning = message.get("reasoning_content") or choice.get("reasoning_content")
        # A literal ``"usage": null`` would otherwise break the lookups below.
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0
        total_tokens = usage.get("total_tokens")
    except (KeyError, IndexError, TypeError, AttributeError):
        # Error-shaped or malformed body: return the same keys as success so
        # callers can read ``reasoning_content`` without checking for it.
        return {"content": "", "reasoning_content": None, "usage": {}}

    if total_tokens is None:
        # Some compatible backends omit the total; derive it from the parts.
        total_tokens = input_tokens + output_tokens

    return {
        # ``content`` may be null in the response; the contract is a string.
        "content": text if text is not None else "",
        "reasoning_content": reasoning,
        "usage": {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
        },
    }
def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
    """Build a Gemini request body from OpenAI-style messages and params.

    Args:
        messages: Chat messages with ``role`` and ``content``; ``content`` is
            a string or a list of typed parts (``text`` / ``image_url``).
        params: Request-level options in OpenAI naming. Entries whose value
            is ``None`` are treated as "not provided".

    Returns:
        A dict with ``contents`` and ``generationConfig``, plus
        ``system_instruction`` when a system message was present.
    """
    gemini_contents = []
    system_instruction = None

    for msg in messages:
        role = "model" if msg["role"] == "assistant" else msg["role"]
        content = msg["content"]

        if role == "system":
            if isinstance(content, str):
                text_val = content
            else:
                # Join every text part rather than indexing part 0, which
                # fails on an empty list or a non-text first part.
                text_val = "\n".join(
                    part["text"]
                    for part in content
                    if part.get("type", "text") == "text" and "text" in part
                )
            system_instruction = {"parts": [{"text": text_val}]}
            continue

        parts = []
        content_list = content if isinstance(content, list) else [{"type": "text", "text": content}]

        for item in content_list:
            kind = item.get("type")
            if kind == "text":
                parts.append({"text": item["text"]})
            elif kind == "image_url":
                url = item.get("image_url", {}).get("url", "")
                # Only base64 data URLs are forwarded; other URLs are dropped.
                match = re.match(r'data:(.*?);base64,(.*)', url)
                if match:
                    parts.append({
                        "inline_data": {
                            "mime_type": match.group(1),
                            "data": match.group(2)
                        }
                    })

        if parts:
            gemini_contents.append({"role": role, "parts": parts})

    # 1. Key-level defaults are expected to use Gemini naming already.
    generation_config = dict(self.api_key.params)

    # 2. Request-level overrides, mapped from OpenAI naming. None values are
    #    dropped so an unset option cannot overwrite a key default with null.
    requested = {name: value for name, value in params.items() if value is not None}
    for openai_name, gemini_name in (
        ("max_tokens", "maxOutputTokens"),
        ("temperature", "temperature"),
        ("top_p", "topP"),
    ):
        if openai_name in requested:
            generation_config[gemini_name] = requested[openai_name]

    # 3. Thinking config: key default first, then request override.
    extra_config = self.api_key.extra_config
    if extra_config and "thinking_config" in extra_config:
        generation_config["thinkingConfig"] = extra_config["thinking_config"]
    if "thinking_config" in requested:
        generation_config["thinkingConfig"] = requested["thinking_config"]

    payload = {
        "contents": gemini_contents,
        "generationConfig": generation_config
    }

    if system_instruction:
        payload["system_instruction"] = system_instruction

    return payload
"x-api-key": self.api_key.key + } + + # Check extra_config for version + if self.api_key.extra_config and "anthropic_version" in self.api_key.extra_config: + headers["anthropic-version"] = self.api_key.extra_config["anthropic_version"] + + return headers + + def get_url(self) -> str: + base = self.api_key.api_base or "https://api.anthropic.com/v1/messages" + return base + + def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]: + claude_msgs = [] + system_prompt = "" + + for msg in messages: + role = "assistant" if msg["role"] == "assistant" else msg["role"] + content = msg["content"] + + if role == "system": + system_prompt = content if isinstance(content, str) else content[0]["text"] + continue + + # Simplified content handling for text only for now, can expand if needed + claude_msgs.append({"role": role, "content": content}) + + # Start with Key's default params + payload = { + "model": self.api_key.model_name, + "messages": claude_msgs, + "max_tokens": 1024 # Default fallback + } + + # Merge Key params + payload.update(self.api_key.params) + + # Merge request params + if "max_tokens" in params: + payload["max_tokens"] = params["max_tokens"] + if "temperature" in params: + payload["temperature"] = params["temperature"] + + if system_prompt: + payload["system"] = system_prompt + + return payload + + def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]: + try: + content_list = data.get("content", []) + text = "" + for item in content_list: + if item.get("type") == "text": + text += item.get("text", "") + + usage = data.get("usage", {}) + input_tokens = usage.get("input_tokens", 0) + output_tokens = usage.get("output_tokens", 0) + + return { + "content": text, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens + } + } + except Exception: + return {"content": "", "usage": {}} + +def get_provider_handler(api_key: APIKey) -> BaseProvider: + 
async def start(self):
    """Start the background export loop unless one is already running.

    Calling this again while the previous loop task is still alive is a
    no-op, so a repeated startup cannot leave an untracked second loop.
    A finished or cancelled task is replaced by a fresh one.
    """
    if self._task is not None and not self._task.done():
        return
    self._running = True
    self._task = asyncio.create_task(self._run_loop())
endpoint status from key_manager for a config.""" + return self.dispatcher.key_manager.get_endpoint_info(config_id) + + async def _export_stats(self): + await self._collect_dashboard_data() + + async def _collect_dashboard_data(self) -> tuple[list, dict, list, dict]: + stats = await self.dispatcher.get_all_key_stats() + usage_summary = await self.dispatcher.get_usage_summary() + live_daily_usage = await self.dispatcher.get_daily_usage_history(days=7) + self.store.save_export(stats, usage_summary, live_daily_usage) + daily_usage = self.store.get_daily_usage(days=7) or live_daily_usage + config_stats = self._group_by_config(stats) + return stats, usage_summary, daily_usage, config_stats + + def _group_by_config(self, stats: list) -> dict: + grouped = {} + for s in stats: + config_id = s.get('config_id', 'default') + if config_id not in grouped: + grouped[config_id] = { + 'config_id': config_id, + 'model_name': s['model_name'], + 'provider': s['provider'], + 'enabled': s['enabled'], + 'keys': [], + 'endpoints': [], + 'total_concurrency': 0, + 'total_rpm': 0, + 'total_tpm': 0, + 'total_rpd': 0, + 'active_count': 0, + 'cooldown_count': 0, + 'disabled_count': 0, + } + grouped[config_id]['keys'].append(s) + grouped[config_id]['total_concurrency'] += s['usage']['current_concurrency'] + grouped[config_id]['total_rpm'] += s['usage']['current_rpm'] + grouped[config_id]['total_tpm'] += s['usage']['current_tpm'] + grouped[config_id]['total_rpd'] += s['usage']['current_rpd'] + if s['status'] == 'active': + grouped[config_id]['active_count'] += 1 + elif s['status'] == 'cooldown': + grouped[config_id]['cooldown_count'] += 1 + else: + grouped[config_id]['disabled_count'] += 1 + # Collect endpoint info from key data + endpoint_idx = s.get('endpoint_idx') + if endpoint_idx is not None: + endpoint = { + 'idx': endpoint_idx, + 'key_id': s['id'], + 'status': s['status'], + 'enabled': s['enabled'], + 'api_base': s.get('api_base', ''), + 'usage': s['usage'], + 'limits': s['limits'], + } + 
# Avoid duplicates + if not any(e['idx'] == endpoint_idx for e in grouped[config_id]['endpoints']): + grouped[config_id]['endpoints'].append(endpoint) + # Sort endpoints by idx + for cfg in grouped.values(): + cfg['endpoints'].sort(key=lambda x: x['idx']) + return grouped + + def _generate_html(self, config_stats: dict, all_stats: list, usage_summary: dict, daily_usage: list) -> str: + now_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + daily_usage_section = self._generate_daily_usage_section(daily_usage) + + config_rows = "" + for config_id, cfg in config_stats.items(): + status_color = "#10b981" if cfg['enabled'] else "#64748b" + toggle_text = "禁用" if cfg['enabled'] else "启用" + toggle_class = "btn-disable" if cfg['enabled'] else "btn-enable" + + active_badge = f"{cfg['active_count']}" if cfg['active_count'] > 0 else "" + cooldown_badge = f"{cfg['cooldown_count']}" if cfg['cooldown_count'] > 0 else "" + disabled_badge = f"{cfg['disabled_count']}" if cfg['disabled_count'] > 0 else "" + + config_rows += f""" + + + {config_id[:25]}... + + {cfg['model_name'][:30]} + {cfg['provider']} + + + {'ON' if cfg['enabled'] else 'OFF'} + + + {cfg['total_concurrency']} + {cfg['total_rpm']} + {self._format_number(cfg['total_tpm'])} + {cfg['total_rpd']} + {len(cfg['keys'])} {active_badge}{cooldown_badge}{disabled_badge} + + + + + + + +
+ {self._generate_endpoints_section(cfg)} + + + + + + + + + + + + + + + {self._generate_key_rows(cfg['keys'])} + +
Key IDOwnerStatusConcurrencyRPMTPMRPD429
+
+ + """ + + return f""" + + + + + + KeyPool Dashboard + + + +
+
+
+
+

KeyPool Dashboard

+
Last updated: {now_str} | Auto-refresh: 10s
+
+ +
+
+ + {daily_usage_section} + + {self._generate_usage_summary_cards(usage_summary)} + +
+ + + + + + + + + + + + + + + + + {config_rows} + +
Config IDModelProviderStatusConcRPMTPMRPDKeysActions
+
+ + +
+ +
+ + + +""" + + def _generate_daily_usage_section(self, daily_usage: list) -> str: + if not daily_usage: + return """ +
+
+
+

Daily Usage

+
+
+
""" + + totals = [item.get("total_rpd", 0) for item in daily_usage] + total_7d = sum(totals) + avg_7d = int(total_7d / len(totals)) if totals else 0 + peak_7d = max(totals) if totals else 0 + + columns = "".join(self._generate_daily_usage_bar(item, peak_7d) for item in daily_usage) + + return f""" +
+
+
+

Daily Usage

+
+
+
+ 7D Total + {self._format_number(total_7d)} +
+
+ 7D Average + {self._format_number(avg_7d)} +
+
+ Peak Day + {self._format_number(peak_7d)} +
+
+
+
+ {columns} +
+
""" + + def _generate_daily_usage_bar(self, item: dict, peak_value: int) -> str: + value = item.get("total_rpd", 0) + peak = max(peak_value, 1) + height_pct = max(8, int(value / peak * 100)) if value > 0 else 8 + bar_class = "daily-bar" if value > 0 else "daily-bar is-zero" + title = f"{item.get('day', item.get('label', ''))}: {value}" + + return f""" +
+
{self._format_number(value)}
+
+
+
+
{item.get('label', '-')}
+
""" + + def _generate_usage_summary_cards(self, usage_summary: dict) -> str: + cards = [ + ("total", "K", usage_summary["total_keys"], "Total Keys"), + ("active", "A", usage_summary["active_keys"], "Active"), + ("cooldown", "C", usage_summary["cooldown_keys"], "Cooldown"), + ("disabled", "D", usage_summary["disabled_keys"], "Disabled"), + ("concurrency", "Q", usage_summary["total_concurrency"], "Concurrency"), + ("rpm", "R", usage_summary["total_rpm"], "RPM"), + ("tpm", "T", self._format_number(usage_summary["total_tpm"]), "TPM"), + ("rpd", "D", usage_summary["total_rpd"], "RPD"), + ("rpd", "Y", usage_summary.get("total_yesterday_rpd", 0), "Yesterday RPD"), + ] + + card_html = "".join( + f""" +
+
{icon}
+
{value}
+
""" + for card_class, icon, value, label in cards + ) + + return f""" +
+ {card_html} +
""" + + def _generate_endpoints_section(self, cfg: dict) -> str: + """Generate endpoint management section for configs with multiple endpoints.""" + endpoints = cfg.get('endpoints', []) + if not endpoints: + return "" + + config_id = cfg['config_id'] + rows = "" + for ep in endpoints: + status_color = "#4ade80" if ep['status'] == 'active' else "#6b7280" + toggle_text = "禁用" if ep['enabled'] else "启用" + toggle_class = "btn-disable" if ep['enabled'] else "btn-enable" + + # Get IP from URL for display + api_base = ep.get('api_base', '') + url_display = api_base.replace('http://', '').replace('https://', '')[:30] + + conc_pct = (ep['usage']['current_concurrency'] / ep['limits']['max_concurrency'] * 100) if ep['limits']['max_concurrency'] > 0 else 0 + rpm_pct = (ep['usage']['current_rpm'] / ep['limits']['rpm'] * 100) if ep['limits']['rpm'] > 0 else 0 + + rows += f""" + + EP-{ep['idx']} + {url_display} + {ep['status']} + {self._progress_bar(conc_pct, ep['usage']['current_concurrency'], ep['limits']['max_concurrency'])} + {self._progress_bar(rpm_pct, ep['usage']['current_rpm'], ep['limits']['rpm'])} + + + + """ + + return f""" +
+

+ 本地节点 (Endpoints) +

+ + + + + + + + + + + + + {rows} + +
NodeURLStatusConcurrencyRPMAction
+
""" + + def _generate_key_rows(self, keys: list) -> str: + rows = "" + for k in keys: + status_color = self._get_status_color(k["status"]) + + conc_pct = (k["usage"]["current_concurrency"] / k["limits"]["max_concurrency"] * 100) if k["limits"]["max_concurrency"] > 0 else 0 + rpm_pct = (k["usage"]["current_rpm"] / k["limits"]["rpm"] * 100) if k["limits"]["rpm"] > 0 else 0 + tpm_pct = (k["usage"]["current_tpm"] / k["limits"]["tpm"] * 100) if k["limits"]["tpm"] > 0 else 0 + rpd_pct = (k["usage"]["current_rpd"] / k["limits"]["rpd"] * 100) if k["limits"]["rpd"] > 0 else 0 + + cooldown_str = "429" if k["status"] == "cooldown" else "-" + + rows += f""" + + {k['id'][:20]}... + {k.get('owner', '-')} + {k['status']} + {self._progress_bar(conc_pct, k['usage']['current_concurrency'], k['limits']['max_concurrency'])} + {self._progress_bar(rpm_pct, k['usage']['current_rpm'], k['limits']['rpm'])} + {self._progress_bar(tpm_pct, k['usage']['current_tpm'], k['limits']['tpm'])} + {self._progress_bar(rpd_pct, k['usage']['current_rpd'], k['limits']['rpd'])} + {cooldown_str} + """ + return rows + + def _get_status_color(self, status: str) -> str: + colors = { + "active": "#4ade80", + "cooldown": "#f97316", + "disabled": "#6b7280", + "error": "#ef4444", + } + return colors.get(status.lower(), "#6b7280") + + def _progress_bar(self, percentage: float, current: int, limit: int) -> str: + if limit == 0: + return f"N/A" + + percentage = min(100, max(0, percentage)) + + if percentage < 50: + color = "#4ade80" + elif percentage < 80: + color = "#fbbf24" + else: + color = "#ef4444" + + return f""" +
+
+
+
+ {current}/{limit} +
""" + + def _format_number(self, num: int) -> str: + if num >= 1000000: + return f"{num/1000000:.1f}M" + elif num >= 1000: + return f"{num/1000:.1f}K" + return str(num) diff --git a/app/core/stats_store.py b/app/core/stats_store.py new file mode 100755 index 0000000..e88762f --- /dev/null +++ b/app/core/stats_store.py @@ -0,0 +1,207 @@ +import sqlite3 +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any + + +class StatsStore: + DAILY_RETENTION_DAYS = 7 + + def __init__(self, db_path: str | Path): + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_db() + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + return conn + + def _init_db(self): + with self._connect() as conn: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS latest_summary ( + id INTEGER PRIMARY KEY CHECK (id = 1), + updated_at INTEGER NOT NULL, + total_keys INTEGER NOT NULL, + active_keys INTEGER NOT NULL, + cooldown_keys INTEGER NOT NULL, + disabled_keys INTEGER NOT NULL, + total_concurrency INTEGER NOT NULL, + total_rpm INTEGER NOT NULL, + total_tpm INTEGER NOT NULL, + total_rpd INTEGER NOT NULL, + total_yesterday_rpd INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS latest_key_stats ( + key_id TEXT PRIMARY KEY, + updated_at INTEGER NOT NULL, + model_name TEXT NOT NULL, + provider TEXT NOT NULL, + status TEXT NOT NULL, + enabled INTEGER NOT NULL, + config_id TEXT, + owner TEXT, + current_concurrency INTEGER NOT NULL, + current_rpm INTEGER NOT NULL, + current_tpm INTEGER NOT NULL, + current_rpd INTEGER NOT NULL, + cooldown_remaining REAL NOT NULL, + rpm_limit INTEGER NOT NULL, + tpm_limit INTEGER NOT NULL, + rpd_limit INTEGER NOT NULL, + max_concurrency INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS daily_usage ( + day TEXT PRIMARY KEY, + label TEXT NOT NULL, + total_rpd INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); 
+ """ + ) + + def save_export(self, stats: list[dict[str, Any]], usage_summary: dict[str, Any], daily_usage: list[dict[str, Any]]): + now = int(datetime.now(timezone.utc).timestamp()) + cutoff = (datetime.now(timezone.utc) - timedelta(days=self.DAILY_RETENTION_DAYS)).strftime("%Y%m%d") + + with self._connect() as conn: + conn.execute( + """ + INSERT INTO latest_summary ( + id, updated_at, total_keys, active_keys, cooldown_keys, disabled_keys, + total_concurrency, total_rpm, total_tpm, total_rpd, total_yesterday_rpd + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + updated_at = excluded.updated_at, + total_keys = excluded.total_keys, + active_keys = excluded.active_keys, + cooldown_keys = excluded.cooldown_keys, + disabled_keys = excluded.disabled_keys, + total_concurrency = excluded.total_concurrency, + total_rpm = excluded.total_rpm, + total_tpm = excluded.total_tpm, + total_rpd = excluded.total_rpd, + total_yesterday_rpd = excluded.total_yesterday_rpd + """, + ( + 1, + now, + usage_summary["total_keys"], + usage_summary["active_keys"], + usage_summary["cooldown_keys"], + usage_summary["disabled_keys"], + usage_summary["total_concurrency"], + usage_summary["total_rpm"], + usage_summary["total_tpm"], + usage_summary["total_rpd"], + usage_summary.get("total_yesterday_rpd", 0), + ), + ) + + conn.execute("DELETE FROM latest_key_stats") + conn.executemany( + """ + INSERT INTO latest_key_stats ( + key_id, updated_at, model_name, provider, status, enabled, config_id, owner, + current_concurrency, current_rpm, current_tpm, current_rpd, cooldown_remaining, + rpm_limit, tpm_limit, rpd_limit, max_concurrency + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + ( + item["id"], + now, + item["model_name"], + item["provider"], + item["status"], + 1 if item["enabled"] else 0, + item.get("config_id"), + item.get("owner"), + item["usage"]["current_concurrency"], + item["usage"]["current_rpm"], + item["usage"]["current_tpm"], + item["usage"]["current_rpd"], + item["cooldown_remaining"], + item["limits"]["rpm"], + item["limits"]["tpm"], + item["limits"]["rpd"], + item["limits"]["max_concurrency"], + ) + for item in stats + ], + ) + + conn.executemany( + """ + INSERT INTO daily_usage (day, label, total_rpd, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(day) DO UPDATE SET + label = excluded.label, + total_rpd = MAX(daily_usage.total_rpd, excluded.total_rpd), + updated_at = excluded.updated_at + """, + [ + ( + item["day"], + item["label"], + item["total_rpd"], + now, + ) + for item in daily_usage + ], + ) + conn.execute("DELETE FROM daily_usage WHERE day < ?", (cutoff,)) + + def get_daily_usage(self, days: int = 7) -> list[dict[str, Any]]: + if days <= 0: + return [] + + with self._connect() as conn: + rows = conn.execute( + """ + SELECT day, label, total_rpd + FROM daily_usage + ORDER BY day DESC + LIMIT ? 
+                """,
+                (days,),
+            ).fetchall()
+
+        rows = list(reversed(rows))
+        return [
+            {
+                "day": row["day"],
+                "label": row["label"],
+                "total_rpd": row["total_rpd"],
+            }
+            for row in rows
+        ]
+
+    def get_latest_summary(self) -> dict[str, Any] | None:
+        with self._connect() as conn:
+            row = conn.execute("SELECT * FROM latest_summary WHERE id = 1").fetchone()
+
+        if not row:
+            return None
+
+        return dict(row)
+
+    def get_latest_key_stats(self) -> list[dict[str, Any]]:
+        with self._connect() as conn:
+            rows = conn.execute(
+                """
+                SELECT *
+                FROM latest_key_stats
+                ORDER BY provider, model_name, key_id
+                """
+            ).fetchall()
+
+        return [dict(row) for row in rows]
+
+    def delete_day(self, day: str):
+        with self._connect() as conn:
+            conn.execute("DELETE FROM daily_usage WHERE day = ?", (day,))
diff --git a/app/models.py b/app/models.py
new file mode 100755
index 0000000..35741ea
--- /dev/null
+++ b/app/models.py
@@ -0,0 +1,39 @@
+from enum import Enum
+from pydantic import BaseModel, Field
+from typing import Optional, Dict, Any
+
+class KeyStatus(str, Enum):
+    """Status values a key can have."""
+    ACTIVE = "active"
+    COOLDOWN = "cooldown"
+    DISABLED = "disabled"
+
+class APIKey(BaseModel):
+    """Configuration and in-memory runtime counters for one API key."""
+    id: str = Field(..., description="Unique identifier for the key")
+    key: str = Field(..., description="The actual API Key string")
+    provider: str = Field(..., description="Provider identifier (e.g., google, zhipu)")
+    api_base: Optional[str] = Field(None, description="Base URL for the API")
+    model_name: str = Field(..., description="The physical model name")
+    params: Dict[str, Any] = Field(default_factory=dict, description="Default model parameters (e.g., temperature)")
+    extra_config: Dict[str, Any] = Field(default_factory=dict, description="Provider specific configuration")
+    enabled: bool = Field(default=True, description="Whether the key is enabled in configuration")
+    status: KeyStatus = Field(default=KeyStatus.ACTIVE, description="Current status of the key")
+    rpm_limit: int = Field(..., description="Requests Per Minute limit")
+    tpm_limit: int = Field(..., description="Tokens Per Minute limit")
+    rpd_limit: int = Field(default=0, description="Requests Per Day limit (0 = unlimited)")
+    max_concurrency: int = Field(..., description="Maximum concurrent requests allowed")
+    config_id: Optional[str] = Field(None, description="Config ID this key belongs to")
+    owner: Optional[str] = Field(None, description="Owner of this key")
+
+    # Runtime stats (Not loaded from config usually, but part of the object in memory)
+    current_concurrency: int = Field(default=0, description="Current concurrent requests")
+    current_rpm: int = Field(default=0, description="Current RPM usage")
+    current_tpm: int = Field(default=0, description="Current TPM usage")
+    current_rpd: int = Field(default=0, description="Current RPD usage")
+    rate_limit_penalty: int = Field(default=0, description="429 demotion order. 0 means never rate limited.")
+    endpoint_idx: Optional[int] = Field(default=None, description="Endpoint index for local llama load balancing")
+
+    # Populate enum-typed fields with the enum's value rather than the member.
+    class Config:
+        use_enum_values = True
diff --git a/app/runtime.py b/app/runtime.py
new file mode 100755
index 0000000..69607c1
--- /dev/null
+++ b/app/runtime.py
@@ -0,0 +1,10 @@
+import os
+
+from app.core.dispatcher import Dispatcher
+from app.core.key_manager import KeyManager
+from app.core.stats_exporter import StatsExporter
+
+
+key_manager = KeyManager(os.getenv("KEY_CONFIG_PATH", "config/keys.yaml"))
+dispatcher = Dispatcher(key_manager)
+stats_exporter = StatsExporter(dispatcher, output_dir="stats", interval=10)
diff --git a/app/utils/redis_client.py b/app/utils/redis_client.py
new file mode 100755
index 0000000..8ffe045
--- /dev/null
+++ b/app/utils/redis_client.py
@@ -0,0 +1,32 @@
+import redis.asyncio as redis
+import os
+
+class RedisClient:
+    """Holder for one lazily created, shared asyncio Redis client."""
+
+    _instance = None
+
+    @classmethod
+    def get_instance(cls):
+        """Return the shared client, creating it from REDIS_URL on first use."""
+        if cls._instance is None:
+            redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
+            cls._instance = redis.from_url(redis_url, decode_responses=True)
+        return cls._instance
+
+    @classmethod
+    async def close(cls):
+        """Close the shared client, if one exists, and forget it.
+
+        Does nothing when no client has been created.
+        """
+        client = cls._instance
+        if client is None:
+            return
+        # Drop the reference first so a failing close cannot leave a dead
+        # client cached for later get_instance() calls.
+        cls._instance = None
+        # redis-py 5.0.1+ deprecates close() on asyncio clients in favour of
+        # aclose(); fall back to close() for older releases.
+        closer = getattr(client, "aclose", None) or client.close
+        await closer()
diff --git a/config/keys_example.yaml b/config/keys_example.yaml
new file mode 100755
index 0000000..e3e501b
--- /dev/null
+++ b/config/keys_example.yaml
@@ -0,0 +1,41 @@
+model_configs:
+- id: nvidia-glm-cfg
+  provider: nvidia
+  model_name: z-ai/glm-5.1
+  api_base: https://integrate.api.nvidia.com/v1
+  rpm_limit: 40
+  tpm_limit: 100000
+  max_concurrency: 8
+  params:
+    temperature: 1
+    top_p: 1
+    max_tokens: 16384
+    stream: false
+    chat_template_kwargs:
+      enable_thinking: false
+      clear_thinking: false
+- id: llama-local-cfg
+  provider: openai
+  model_name: llama-local
+  api_base: http://192.168.2.101:8848/v1
+  api_endpoints:
+  - url: http://192.168.2.101:8848/v1
+    enabled: true
+  - url: http://192.168.1.51:8848/v1
+    enabled: true
+  - url: http://192.168.1.61:8848/v1
+    enabled: true
+  rpm_limit: 999999
+  tpm_limit: 999999999
+  max_concurrency: 999
+  params:
+    temperature: 1.0
+    max_tokens: 6144
+    chat_template_kwargs:
+      enable_thinking: false
+keys:
+- config_id: nvidia-glm-cfg
+  enabled: true
+  keys:
+  - key: nvapi-key-1
+    owner: name
diff --git a/main.py b/main.py
new file mode 100755
index 0000000..55ecc5f
--- /dev/null
+++ b/main.py
@@ -0,0 +1,49 @@
+import os
+
+from fastapi import FastAPI
+from fastapi.responses import HTMLResponse
+
+from app.api.routes import router
+from app.runtime import stats_exporter
+from app.utils.redis_client import RedisClient
+
+app = FastAPI(title="Smart KeyPool Gateway")
+
+app.include_router(router)
+
+@app.get("/", response_class=HTMLResponse)
+async def dashboard_home():
+    """Return the dashboard HTML rendered by the stats exporter (site root)."""
+    return await stats_exporter.render_dashboard()
+
+@app.get("/dashboard", response_class=HTMLResponse)
+async def dashboard_page():
+    """Return the dashboard HTML rendered by the stats exporter (/dashboard)."""
+    return await stats_exporter.render_dashboard()
+
+@app.on_event("startup")
+async def startup_event():
+    """Ping Redis and start the stats exporter."""
+    redis = RedisClient.get_instance()
+    try:
+        await redis.ping()
+        print("Connected to Redis")
+    except Exception as e:
+        # Best effort: a failed ping is reported but does not abort startup.
+        print(f"Failed to connect to Redis: {e}")
+
+    await stats_exporter.start()
+    print("Stats exporter started - dashboard available at / or /dashboard")
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Stop the stats exporter and close the shared Redis client."""
+    try:
+        await stats_exporter.stop()
+    finally:
+        # Release the Redis connection even when stopping the exporter fails.
+        await RedisClient.close()
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8888")))
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000..a28626e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+fastapi
+uvicorn
+redis
+pydantic
+pyyaml
+httpx
diff --git a/test_keys.py b/test_keys.py
new file mode 100755
index 0000000..023b044
--- /dev/null
+++ b/test_keys.py
@@ -0,0 +1,298 @@
+#!/usr/bin/env python3
+"""
+Check that every enabled API key still works (concurrent version).
+"""
+import asyncio
+import os
+import sys
+import time
+import httpx
+from typing import Dict, Any, List, Optional
+from dataclasses import dataclass
+from enum import Enum
+from concurrent.futures import ThreadPoolExecutor
+
+# Make the `app` package importable wherever the script is launched from:
+# it lives next to this file, so put this file's directory on sys.path.
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from app.core.key_manager import KeyManager
+from app.core.providers import get_provider_handler
+from app.models import APIKey
+
+
+class TestStatus(Enum):
+    """Outcome of probing a single key."""
+
+    SUCCESS = "success"
+    FAILED = "failed"
+    SKIPPED = "skipped"
+
+
+@dataclass
+class TestResult:
+    """Result record for one probed key."""
+
+    key_id: str
+    provider: str
+    owner: Optional[str]
+    status: TestStatus
+    message: str
+    response_time: float = 0.0  # elapsed seconds; 0.0 for skipped keys
+    details: Optional[Dict] = None  # content preview and usage on success
+
+
+async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
+    """Send one small chat request with ``api_key`` and report the outcome.
+
+    Args:
+        api_key: Key to probe; a key whose ``enabled`` is false is skipped.
+        timeout: Timeout in seconds passed to the HTTP client.
+
+    Returns:
+        A TestResult. Non-200 responses, timeouts, connection errors and any
+        other exception raised while probing are reported in the result.
+    """
+    start_time = time.perf_counter()
+    owner = getattr(api_key, 'owner', None)
+
+    # Skip keys that are disabled in the configuration.
+    if not api_key.enabled:
+        return TestResult(
+            key_id=api_key.id,
+            provider=api_key.provider,
+            owner=owner,
+            status=TestStatus.SKIPPED,
+            message="密钥已禁用 (enabled: false)",
+            response_time=0.0
+        )
+
+    try:
+        provider = get_provider_handler(api_key)
+
+        # Minimal prompt for the probe request.
+        test_messages = [
+            {"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
+        ]
+
+        # Build the request from the provider handler.
+        headers = provider.get_headers()
+        url = provider.get_url()
+        payload = provider.get_payload(test_messages, {})
+
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            response = await client.post(url, headers=headers, json=payload)
+            response_time = time.perf_counter() - start_time
+
+            if response.status_code == 200:
+                data = response.json()
+                parsed = provider.parse_response(data)
+                # The parsed content may be None; treat it as empty text so a
+                # successful call is not reported as failed by len(None).
+                content = parsed.get("content") or ""
+
+                return TestResult(
+                    key_id=api_key.id,
+                    provider=api_key.provider,
+                    owner=owner,
+                    status=TestStatus.SUCCESS,
+                    message=f"测试成功,响应时间: {response_time:.2f}s",
+                    response_time=response_time,
+                    details={
+                        "content_preview": content[:100] + "..." if len(content) > 100 else content,
+                        "usage": parsed.get("usage", {})
+                    }
+                )
+            else:
+                error_text = response.text[:500] if response.text else "No error details"
+                return TestResult(
+                    key_id=api_key.id,
+                    provider=api_key.provider,
+                    owner=owner,
+                    status=TestStatus.FAILED,
+                    message=f"HTTP错误 {response.status_code}: {error_text}",
+                    response_time=time.perf_counter() - start_time
+                )
+
+    except httpx.TimeoutException:
+        return TestResult(
+            key_id=api_key.id,
+            provider=api_key.provider,
+            owner=owner,
+            status=TestStatus.FAILED,
+            message=f"请求超时 (>{timeout}s)",
+            response_time=time.perf_counter() - start_time
+        )
+    except httpx.ConnectError as e:
+        return TestResult(
+            key_id=api_key.id,
+            provider=api_key.provider,
+            owner=owner,
+            status=TestStatus.FAILED,
+            message=f"连接错误: {str(e)}",
+            response_time=time.perf_counter() - start_time
+        )
+    except Exception as e:
+        return TestResult(
+            key_id=api_key.id,
+            provider=api_key.provider,
+            owner=owner,
+            status=TestStatus.FAILED,
+            message=f"异常: {type(e).__name__}: {str(e)}",
+            response_time=time.perf_counter() - start_time
+        )
+
+
+async def test_all_keys_concurrent(config_path: str = "config/keys.yaml", max_concurrency: int = 10):
+    """Probe every enabled key concurrently and print a report.
+
+    Args:
+        config_path: Key configuration file handed to KeyManager.
+        max_concurrency: Maximum number of probes in flight at once.
+
+    Returns:
+        The list of TestResult objects, or None when the configuration
+        cannot be loaded or no key is enabled.
+    """
+    print("=" * 80)
+    print("API密钥测试工具 (并发模式)")
+    print("=" * 80)
+    print()
+
+    # Load the key manager.
+    try:
+        km = KeyManager(config_path)
+        all_keys = km.get_all_keys()
+        print(f"共加载 {len(all_keys)} 个密钥配置")
+    except Exception as e:
+        print(f"加载配置失败: {e}")
+        return
+
+    # Keep only the enabled keys.
+    enabled_keys = [k for k in all_keys if k.enabled]
+    print(f"其中 {len(enabled_keys)} 个密钥已启用")
+    print(f"并发数: {max_concurrency}")
+    print()
+
+    if not enabled_keys:
+        print("没有启用的密钥需要测试")
+        return
+
+    # A semaphore caps the number of probes running at once.
+    semaphore = asyncio.Semaphore(max_concurrency)
+
+    async def test_with_semaphore(key: APIKey) -> TestResult:
+        async with semaphore:
+            return await test_single_key(key)
+
+    # One task per enabled key.
+    tasks = [test_with_semaphore(key) for key in enabled_keys]
+
+    # Run all probes concurrently.
+    print("开始并发测试...")
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Turn any exception that escaped a task into a FAILED result.
+    processed_results: List[TestResult] = []
+    for i, result in enumerate(results):
+        if isinstance(result, Exception):
+            key = enabled_keys[i]
+            owner = getattr(key, 'owner', None)
+            processed_results.append(TestResult(
+                key_id=key.id,
+                provider=key.provider,
+                owner=owner,
+                status=TestStatus.FAILED,
+                message=f"测试任务异常: {type(result).__name__}: {str(result)}",
+                response_time=0.0
+            ))
+        else:
+            processed_results.append(result)
+
+    results = processed_results
+
+    # Group the results by provider for display.
+    results_by_provider: Dict[str, List[TestResult]] = {}
+    for r in results:
+        results_by_provider.setdefault(r.provider, []).append(r)
+
+    # Print the detailed results.
+    for provider, provider_results in sorted(results_by_provider.items()):
+        print(f"\n{'='*80}")
+        print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
+        print(f"{'='*80}")
+
+        for r in provider_results:
+            status_icon = "✓" if r.status == TestStatus.SUCCESS else "✗" if r.status == TestStatus.FAILED else "○"
+            owner_str = f" ({r.owner})" if r.owner else ""
+            print(f"\n [{status_icon}] {r.key_id}{owner_str}")
+            print(f" 状态: {r.message}")
+            if r.details and r.status == TestStatus.SUCCESS:
+                print(f" 回复: {r.details['content_preview']}")
+
+    # Print the summary report.
+    print("\n" + "=" * 80)
+    print("测试汇总报告")
+    print("=" * 80)
+
+    success_count = sum(1 for r in results if r.status == TestStatus.SUCCESS)
+    failed_count = sum(1 for r in results if r.status == TestStatus.FAILED)
+    skipped_count = sum(1 for r in results if r.status == TestStatus.SKIPPED)
+
+    print(f"\n总计: {len(results)} 个密钥")
+    print(f" ✓ 成功: {success_count}")
+    print(f" ✗ 失败: {failed_count}")
+    print(f" ○ 跳过: {skipped_count}")
+
+    # Per-provider counts.
+    print("\n按提供商统计:")
+    provider_stats: Dict[str, Dict[str, int]] = {}
+    for r in results:
+        stats = provider_stats.setdefault(r.provider, {"success": 0, "failed": 0, "skipped": 0})
+        # TestStatus values ("success"/"failed"/"skipped") are the counter keys.
+        stats[r.status.value] += 1
+
+    for provider, stats in sorted(provider_stats.items()):
+        total = stats["success"] + stats["failed"] + stats["skipped"]
+        success_rate = (stats["success"] / total * 100) if total > 0 else 0
+        print(f" {provider:15s}: {stats['success']:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {stats['failed']} 失败, {stats['skipped']} 跳过")
+
+    # Details of the failed keys.
+    failed_results = [r for r in results if r.status == TestStatus.FAILED]
+    if failed_results:
+        print("\n" + "-" * 80)
+        print("失败的密钥详情:")
+        print("-" * 80)
+        for r in failed_results:
+            print(f"\n [{r.provider}] {r.key_id}")
+            if r.owner:
+                print(f" 拥有者: {r.owner}")
+            print(f" 错误: {r.message}")
+
+    # Response-time statistics for the successful probes.
+    success_results = [r for r in results if r.status == TestStatus.SUCCESS and r.response_time > 0]
+    if success_results:
+        avg_time = sum(r.response_time for r in success_results) / len(success_results)
+        min_time = min(r.response_time for r in success_results)
+        max_time = max(r.response_time for r in success_results)
+        print("\n" + "-" * 80)
+        print("响应时间统计 (成功请求):")
+        print("-" * 80)
+        print(f" 平均: {avg_time:.2f}s")
+        print(f" 最快: {min_time:.2f}s")
+        print(f" 最慢: {max_time:.2f}s")
+
+    print("\n" + "=" * 80)
+    print("测试完成")
+    print("=" * 80)
+
+    return results
+
+
+if __name__ == "__main__":
+    # Work from the script's directory so the relative default config path
+    # ("config/keys.yaml") is resolved against it.
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    os.chdir(script_dir)
+
+    # Run the probes with a concurrency of 11 (the function's own default is 10).
+    asyncio.run(test_all_keys_concurrent(max_concurrency=11))