285 lines
9.9 KiB
Python
Executable File
285 lines
9.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
Test whether every enabled API key works correctly (concurrent version).
|
||
"""
|
||
import asyncio
|
||
import sys
|
||
import httpx
|
||
from typing import Dict, Any, List, Optional
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
# Add the directory that contains the `app` package to the module search path.
|
||
sys.path.insert(0, 'd:\\DPI\\keypool')
|
||
|
||
from app.core.key_manager import KeyManager
|
||
from app.core.providers import get_provider_handler
|
||
from app.models import APIKey
|
||
|
||
|
||
class TestStatus(Enum):
    """Possible outcomes of checking a single API key."""

    # The key was exercised and the check succeeded.
    SUCCESS = "success"

    # The key was exercised but the check did not succeed.
    FAILED = "failed"

    # The key was not exercised at all.
    SKIPPED = "skipped"
|
||
|
||
|
||
@dataclass
class TestResult:
    """Record describing the outcome of checking one API key."""

    # Identifier of the key that was checked.
    key_id: str

    # Name of the provider the key belongs to.
    provider: str

    # Owner of the key, or None when no owner is known.
    owner: Optional[str]

    # Outcome category of the check.
    status: TestStatus

    # Human-readable description of the outcome.
    message: str

    # Elapsed time of the check in seconds; 0.0 when nothing was measured.
    response_time: float = 0.0

    # Optional extra information about the check.
    details: Optional[Dict] = None
|
||
|
||
|
||
async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
    """Send one small chat request with ``api_key`` and report the outcome.

    Args:
        api_key: The key to check. Its ``id``, ``provider`` and ``enabled``
            attributes are read; ``owner`` is read when present.
        timeout: Timeout in seconds handed to ``httpx.AsyncClient``.

    Returns:
        A ``TestResult`` whose status is:

        * ``SKIPPED`` when the key is disabled (no request is sent),
        * ``SUCCESS`` when the provider answers with HTTP 200 and the
          response can be parsed,
        * ``FAILED`` for any other HTTP status, a timeout, a connection
          error, or any other ``Exception`` raised while building, sending
          or parsing the request.
    """
    import time

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    started = time.perf_counter()
    owner = getattr(api_key, 'owner', None)

    def elapsed() -> float:
        """Return the seconds elapsed since the check started."""
        return time.perf_counter() - started

    def build_result(
        status: TestStatus,
        message: str,
        response_time: float,
        details: Optional[Dict] = None,
    ) -> TestResult:
        """Build a TestResult pre-filled with this key's identity."""
        return TestResult(
            key_id=api_key.id,
            provider=api_key.provider,
            owner=owner,
            status=status,
            message=message,
            response_time=response_time,
            details=details,
        )

    # Disabled keys are skipped without sending a request.
    if not api_key.enabled:
        return build_result(
            TestStatus.SKIPPED,
            "密钥已禁用 (enabled: false)",
            0.0,
        )

    try:
        provider = get_provider_handler(api_key)

        # Minimal prompt used only to verify that the key is accepted.
        test_messages = [
            {"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
        ]

        # Let the provider handler build the request.
        headers = provider.get_headers()
        url = provider.get_url()
        payload = provider.get_payload(test_messages, {})

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, headers=headers, json=payload)
            response_time = elapsed()

        if response.status_code != 200:
            error_text = response.text[:500] if response.text else "No error details"
            return build_result(
                TestStatus.FAILED,
                f"HTTP错误 {response.status_code}: {error_text}",
                elapsed(),
            )

        parsed = provider.parse_response(response.json())

        # A missing or null "content" must not turn an otherwise successful
        # call into a failure, so normalise it to a string before slicing.
        content = parsed.get("content") or ""
        if not isinstance(content, str):
            content = str(content)
        preview = content[:100] + "..." if len(content) > 100 else content

        return build_result(
            TestStatus.SUCCESS,
            f"测试成功,响应时间: {response_time:.2f}s",
            response_time,
            details={
                "content_preview": preview,
                "usage": parsed.get("usage", {}),
            },
        )

    except httpx.TimeoutException:
        return build_result(
            TestStatus.FAILED,
            f"请求超时 (>{timeout}s)",
            elapsed(),
        )
    except httpx.ConnectError as e:
        return build_result(
            TestStatus.FAILED,
            f"连接错误: {e}",
            elapsed(),
        )
    except Exception as e:
        # Last-resort boundary: any other error is reported as a failed
        # check rather than propagated to the caller.
        return build_result(
            TestStatus.FAILED,
            f"异常: {type(e).__name__}: {e}",
            elapsed(),
        )
|
||
|
||
|
||
async def _run_key_tests(keys: List[APIKey], max_concurrency: int) -> List[TestResult]:
    """Check ``keys`` concurrently (bounded by a semaphore) and return results in input order."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def test_with_semaphore(key: APIKey) -> TestResult:
        async with semaphore:
            return await test_single_key(key)

    outcomes = await asyncio.gather(
        *(test_with_semaphore(key) for key in keys),
        return_exceptions=True,
    )

    results: List[TestResult] = []
    for key, outcome in zip(keys, outcomes):
        # With return_exceptions=True, gather() can also hand back
        # asyncio.CancelledError, which derives from BaseException (not
        # Exception) on Python 3.8+, so check for the base class.
        if isinstance(outcome, BaseException):
            results.append(TestResult(
                key_id=key.id,
                provider=key.provider,
                owner=getattr(key, 'owner', None),
                status=TestStatus.FAILED,
                message=f"测试任务异常: {type(outcome).__name__}: {str(outcome)}",
                response_time=0.0,
            ))
        else:
            results.append(outcome)
    return results


def _group_by_provider(results: List[TestResult]) -> Dict[str, List[TestResult]]:
    """Group results by provider name, keeping the original order inside each group."""
    grouped: Dict[str, List[TestResult]] = {}
    for r in results:
        grouped.setdefault(r.provider, []).append(r)
    return grouped


def _print_provider_details(grouped: Dict[str, List[TestResult]]) -> None:
    """Print the per-key outcome for every provider, providers in sorted order."""
    for provider, provider_results in sorted(grouped.items()):
        print(f"\n{'='*80}")
        print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
        print(f"{'='*80}")

        for r in provider_results:
            if r.status == TestStatus.SUCCESS:
                status_icon = "✓"
            elif r.status == TestStatus.FAILED:
                status_icon = "✗"
            else:
                status_icon = "○"
            owner_str = f" ({r.owner})" if r.owner else ""
            print(f"\n [{status_icon}] {r.key_id}{owner_str}")
            print(f" 状态: {r.message}")
            if r.details and r.status == TestStatus.SUCCESS:
                print(f" 回复: {r.details['content_preview']}")


def _print_summary(results: List[TestResult], grouped: Dict[str, List[TestResult]]) -> None:
    """Print overall totals, per-provider statistics, failure details and timing statistics."""
    print("\n" + "=" * 80)
    print("测试汇总报告")
    print("=" * 80)

    def count(items: List[TestResult], status: TestStatus) -> int:
        return sum(1 for r in items if r.status == status)

    print(f"\n总计: {len(results)} 个密钥")
    print(f" ✓ 成功: {count(results, TestStatus.SUCCESS)}")
    print(f" ✗ 失败: {count(results, TestStatus.FAILED)}")
    print(f" ○ 跳过: {count(results, TestStatus.SKIPPED)}")

    # Per-provider statistics.
    print("\n按提供商统计:")
    for provider, provider_results in sorted(grouped.items()):
        stats = {
            "success": count(provider_results, TestStatus.SUCCESS),
            "failed": count(provider_results, TestStatus.FAILED),
        }
        # Anything that is neither SUCCESS nor FAILED is reported as skipped.
        stats["skipped"] = len(provider_results) - stats["success"] - stats["failed"]
        total = len(provider_results)
        success_rate = (stats["success"] / total * 100) if total > 0 else 0
        print(f" {provider:15s}: {stats['success']:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {stats['failed']} 失败, {stats['skipped']} 跳过")

    # Details of every failed key.
    failed_results = [r for r in results if r.status == TestStatus.FAILED]
    if failed_results:
        print("\n" + "-" * 80)
        print("失败的密钥详情:")
        print("-" * 80)
        for r in failed_results:
            print(f"\n [{r.provider}] {r.key_id}")
            if r.owner:
                print(f" 拥有者: {r.owner}")
            print(f" 错误: {r.message}")

    # Response-time statistics over successful requests only.
    timings = [
        r.response_time
        for r in results
        if r.status == TestStatus.SUCCESS and r.response_time > 0
    ]
    if timings:
        print("\n" + "-" * 80)
        print("响应时间统计 (成功请求):")
        print("-" * 80)
        print(f" 平均: {sum(timings) / len(timings):.2f}s")
        print(f" 最快: {min(timings):.2f}s")
        print(f" 最慢: {max(timings):.2f}s")


async def test_all_keys_concurrent(
    config_path: str = "config/keys.yaml",
    max_concurrency: int = 10,
) -> Optional[List[TestResult]]:
    """Check every enabled key concurrently and print a report to stdout.

    Args:
        config_path: Path of the key configuration passed to ``KeyManager``.
        max_concurrency: Maximum number of keys checked at the same time.
            Must be at least 1.

    Returns:
        The list of ``TestResult`` objects, in the order of the enabled keys,
        or ``None`` when the configuration cannot be loaded or no key is
        enabled.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1. A semaphore
            with a value of 0 would otherwise block every check forever.
    """
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

    print("=" * 80)
    print("API密钥测试工具 (并发模式)")
    print("=" * 80)
    print()

    # Load the key manager; a load failure is reported, not raised.
    try:
        km = KeyManager(config_path)
        all_keys = km.get_all_keys()
        print(f"共加载 {len(all_keys)} 个密钥配置")
    except Exception as e:
        print(f"加载配置失败: {e}")
        return None

    # Only enabled keys are checked.
    enabled_keys = [k for k in all_keys if k.enabled]
    print(f"其中 {len(enabled_keys)} 个密钥已启用")
    print(f"并发数: {max_concurrency}")
    print()

    if not enabled_keys:
        print("没有启用的密钥需要测试")
        return None

    print("开始并发测试...")
    results = await _run_key_tests(enabled_keys, max_concurrency)

    grouped = _group_by_provider(results)
    _print_provider_details(grouped)
    _print_summary(results, grouped)

    print("\n" + "=" * 80)
    print("测试完成")
    print("=" * 80)

    return results
|
||
|
||
|
||
if __name__ == "__main__":
    import os

    # Make the directory containing this script the working directory
    # (presumably so the relative default config path resolves against it).
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # Run the concurrent check with a concurrency limit of 11.
    asyncio.run(test_all_keys_concurrent(max_concurrency=11))
|