#!/usr/bin/env python3
"""
Test whether every enabled API key works (concurrent version).

Loads keys through ``KeyManager``, sends one short chat request per enabled
key with bounded concurrency, and prints a per-provider report followed by
summary statistics.
"""
import asyncio
import os
import sys
import time
import httpx
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
from concurrent.futures import ThreadPoolExecutor

# Make the project root importable so the ``app`` package resolves.
# NOTE(review): this is a hard-coded, machine-specific Windows path; on other
# machines ``app`` must already be importable some other way.
sys.path.insert(0, 'd:\\DPI\\keypool')

from app.core.key_manager import KeyManager
from app.core.providers import get_provider_handler
from app.models import APIKey


class TestStatus(Enum):
    """Outcome of testing a single key."""
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"


# Console marker printed next to each key in the per-provider report.
_STATUS_ICONS: Dict[TestStatus, str] = {
    TestStatus.SUCCESS: "✓",
    TestStatus.FAILED: "✗",
    TestStatus.SKIPPED: "○",
}


@dataclass
class TestResult:
    """Result of testing one API key."""
    key_id: str
    provider: str
    owner: Optional[str]
    status: TestStatus
    message: str
    # Elapsed seconds; 0.0 for skipped keys and task-level exceptions.
    response_time: float = 0.0
    # Only set on success: {"content_preview": str, "usage": ...}.
    details: Optional[Dict] = None


def _make_result(api_key: APIKey, status: TestStatus, message: str,
                 response_time: float = 0.0,
                 details: Optional[Dict] = None) -> TestResult:
    """Build a TestResult carrying the identifying fields of ``api_key``."""
    return TestResult(
        key_id=api_key.id,
        provider=api_key.provider,
        owner=getattr(api_key, 'owner', None),
        status=status,
        message=message,
        response_time=response_time,
        details=details,
    )


async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
    """Send one short chat request with ``api_key`` and classify the outcome.

    Args:
        api_key: Key to test. Disabled keys are reported as SKIPPED without
            sending a request.
        timeout: httpx timeout, in seconds, for the request.

    Returns:
        A TestResult. Ordinary errors never propagate: every ``Exception``
        raised while building or sending the request becomes a FAILED result.
    """
    if not api_key.enabled:
        return _make_result(api_key, TestStatus.SKIPPED, "密钥已禁用 (enabled: false)")

    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    start_time = time.perf_counter()

    def elapsed() -> float:
        return time.perf_counter() - start_time

    try:
        provider = get_provider_handler(api_key)

        test_messages = [
            {"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
        ]
        headers = provider.get_headers()
        url = provider.get_url()
        payload = provider.get_payload(test_messages, {})

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, headers=headers, json=payload)
        response_time = elapsed()

        if response.status_code != 200:
            error_text = response.text[:500] if response.text else "No error details"
            return _make_result(
                api_key, TestStatus.FAILED,
                f"HTTP错误 {response.status_code}: {error_text}",
                response_time,
            )

        parsed = provider.parse_response(response.json())
        # Guard against a missing/None or non-string "content" from the
        # provider parser so building the preview below cannot raise.
        content = parsed.get("content") or ""
        if not isinstance(content, str):
            content = str(content)
        preview = content[:100] + "..." if len(content) > 100 else content

        return _make_result(
            api_key, TestStatus.SUCCESS,
            f"测试成功,响应时间: {response_time:.2f}s",
            response_time,
            details={
                "content_preview": preview,
                "usage": parsed.get("usage", {}),
            },
        )
    except httpx.TimeoutException:
        return _make_result(api_key, TestStatus.FAILED,
                            f"请求超时 (>{timeout}s)", elapsed())
    except httpx.ConnectError as e:
        return _make_result(api_key, TestStatus.FAILED,
                            f"连接错误: {str(e)}", elapsed())
    except Exception as e:
        # Report boundary: record any other failure instead of aborting the run.
        return _make_result(api_key, TestStatus.FAILED,
                            f"异常: {type(e).__name__}: {str(e)}", elapsed())


def _collect_results(keys: List[APIKey], outcomes: List[Any]) -> List[TestResult]:
    """Pair gather() outcomes with their keys; turn raised exceptions into FAILED results."""
    results: List[TestResult] = []
    for key, outcome in zip(keys, outcomes):
        # BaseException, not Exception: with return_exceptions=True, gather also
        # returns asyncio.CancelledError, a BaseException subclass since 3.8.
        if isinstance(outcome, BaseException):
            results.append(_make_result(
                key, TestStatus.FAILED,
                f"测试任务异常: {type(outcome).__name__}: {str(outcome)}",
            ))
        else:
            results.append(outcome)
    return results


def _group_by_provider(results: List[TestResult]) -> Dict[str, List[TestResult]]:
    """Group results by provider name, preserving their original order."""
    groups: Dict[str, List[TestResult]] = {}
    for r in results:
        groups.setdefault(r.provider, []).append(r)
    return groups


def _print_provider_details(groups: Dict[str, List[TestResult]]) -> None:
    """Print every key's result, one section per provider (sorted by name)."""
    for provider in sorted(groups):
        provider_results = groups[provider]
        print(f"\n{'='*80}")
        print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
        print(f"{'='*80}")
        for r in provider_results:
            status_icon = _STATUS_ICONS.get(r.status, "○")
            owner_str = f" ({r.owner})" if r.owner else ""
            print(f"\n [{status_icon}] {r.key_id}{owner_str}")
            print(f" 状态: {r.message}")
            if r.details and r.status == TestStatus.SUCCESS:
                print(f" 回复: {r.details['content_preview']}")


def _print_summary(results: List[TestResult],
                   groups: Dict[str, List[TestResult]]) -> None:
    """Print overall status counts and per-provider success rates."""
    print("\n" + "=" * 80)
    print("测试汇总报告")
    print("=" * 80)

    success_count = sum(1 for r in results if r.status == TestStatus.SUCCESS)
    failed_count = sum(1 for r in results if r.status == TestStatus.FAILED)
    skipped_count = sum(1 for r in results if r.status == TestStatus.SKIPPED)

    print(f"\n总计: {len(results)} 个密钥")
    print(f" ✓ 成功: {success_count}")
    print(f" ✗ 失败: {failed_count}")
    print(f" ○ 跳过: {skipped_count}")

    print("\n按提供商统计:")
    for provider in sorted(groups):
        provider_results = groups[provider]
        success = sum(1 for r in provider_results if r.status == TestStatus.SUCCESS)
        failed = sum(1 for r in provider_results if r.status == TestStatus.FAILED)
        total = len(provider_results)
        skipped = total - success - failed  # every other status counts as skipped
        success_rate = (success / total * 100) if total > 0 else 0
        print(f" {provider:15s}: {success:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {failed} 失败, {skipped} 跳过")


def _print_failures(results: List[TestResult]) -> None:
    """Print the error details of every FAILED result, if any."""
    failed_results = [r for r in results if r.status == TestStatus.FAILED]
    if not failed_results:
        return
    print("\n" + "-" * 80)
    print("失败的密钥详情:")
    print("-" * 80)
    for r in failed_results:
        print(f"\n [{r.provider}] {r.key_id}")
        if r.owner:
            print(f" 拥有者: {r.owner}")
        print(f" 错误: {r.message}")


def _print_timing_stats(results: List[TestResult]) -> None:
    """Print average/min/max response time over successful requests, if any."""
    times = [r.response_time for r in results
             if r.status == TestStatus.SUCCESS and r.response_time > 0]
    if not times:
        return
    print("\n" + "-" * 80)
    print("响应时间统计 (成功请求):")
    print("-" * 80)
    print(f" 平均: {sum(times) / len(times):.2f}s")
    print(f" 最快: {min(times):.2f}s")
    print(f" 最慢: {max(times):.2f}s")


async def test_all_keys_concurrent(config_path: str = "config/keys.yaml",
                                   max_concurrency: int = 10,
                                   timeout: int = 240) -> Optional[List[TestResult]]:
    """Test all enabled keys concurrently and print a report.

    Args:
        config_path: Path to the key configuration passed to ``KeyManager``.
        max_concurrency: Maximum number of requests in flight at once; must be >= 1.
        timeout: Per-request timeout in seconds, forwarded to ``test_single_key``.

    Returns:
        The list of results, or None when the configuration fails to load or
        no key is enabled.

    Raises:
        ValueError: If ``max_concurrency`` is less than 1.
    """
    if max_concurrency < 1:
        # asyncio.Semaphore(0) would block every task forever.
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

    print("=" * 80)
    print("API密钥测试工具 (并发模式)")
    print("=" * 80)
    print()

    try:
        km = KeyManager(config_path)
        all_keys = km.get_all_keys()
        print(f"共加载 {len(all_keys)} 个密钥配置")
    except Exception as e:
        print(f"加载配置失败: {e}")
        return None

    enabled_keys = [k for k in all_keys if k.enabled]
    print(f"其中 {len(enabled_keys)} 个密钥已启用")
    print(f"并发数: {max_concurrency}")
    print()

    if not enabled_keys:
        print("没有启用的密钥需要测试")
        return None

    # The semaphore bounds how many requests are in flight at once.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def test_with_semaphore(key: APIKey) -> TestResult:
        async with semaphore:
            return await test_single_key(key, timeout=timeout)

    print("开始并发测试...")
    outcomes = await asyncio.gather(
        *(test_with_semaphore(key) for key in enabled_keys),
        return_exceptions=True,
    )

    results = _collect_results(enabled_keys, outcomes)
    groups = _group_by_provider(results)

    _print_provider_details(groups)
    _print_summary(results, groups)
    _print_failures(results)
    _print_timing_stats(results)

    print("\n" + "=" * 80)
    print("测试完成")
    print("=" * 80)

    return results


if __name__ == "__main__":
    # Run from the script's own directory so the relative default config path resolves.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

    # Concurrency is explicitly 11 here, overriding the function default of 10.
    asyncio.run(test_all_keys_concurrent(max_concurrency=11))