autool-test/test_keys.py
2026-06-17 11:13:11 +08:00

285 lines
9.9 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试所有启用的API密钥是否正常工作并发版本
"""
import asyncio
import sys
import httpx
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
from concurrent.futures import ThreadPoolExecutor
# Make the project's `app` package importable.
# The default is the original author's checkout location; set the KEYPOOL_ROOT
# environment variable to use a different checkout without editing this file.
import os
sys.path.insert(0, os.environ.get("KEYPOOL_ROOT", 'd:\\DPI\\keypool'))
from app.core.key_manager import KeyManager
from app.core.providers import get_provider_handler
from app.models import APIKey
class TestStatus(Enum):
    """Outcome of testing a single API key."""

    # This module is named test_keys.py, so pytest would otherwise try to
    # collect this Test*-named class as a test case. Dunder names are plain
    # class attributes in an Enum, not members.
    __test__ = False

    SUCCESS = "success"  # HTTP 200 and the response parsed
    FAILED = "failed"    # HTTP error, timeout, connection error or other exception
    SKIPPED = "skipped"  # key is disabled (enabled: false)
@dataclass
class TestResult:
    """Result of testing one API key."""

    # Not a pytest test class despite the Test* name (this file is
    # test_keys.py). Left un-annotated so dataclass does not make it a field.
    __test__ = False

    key_id: str                      # APIKey.id of the tested key
    provider: str                    # APIKey.provider; used to group the report
    owner: Optional[str]             # the key's optional `owner` attribute
    status: TestStatus               # overall outcome
    message: str                     # human-readable outcome or error text
    response_time: float = 0.0       # elapsed seconds for the test request
    details: Optional[Dict] = None   # on success: reply preview and usage info
async def test_single_key(api_key: APIKey, timeout: int = 240) -> TestResult:
    """Send one short chat request using ``api_key`` and report the outcome.

    Disabled keys (``api_key.enabled`` falsy) are not contacted and are
    reported as SKIPPED. Otherwise the key's provider handler builds the
    request. An HTTP 200 whose body parses is SUCCESS. Every other outcome
    is reported as FAILED instead of being raised: a non-200 status, a
    timeout, a connection error, or any other ``Exception`` raised while
    building, sending or parsing the request.

    Args:
        api_key: The key to test.
        timeout: Timeout in seconds handed to ``httpx.AsyncClient``.

    Returns:
        A TestResult. ``details`` holds a reply preview and usage info only
        on success.
    """
    import time

    # perf_counter is monotonic. Unlike time.time() differences, the measured
    # latency is not distorted by system clock adjustments.
    start_time = time.perf_counter()
    owner = getattr(api_key, 'owner', None)

    def elapsed() -> float:
        """Seconds since this key's test started."""
        return time.perf_counter() - start_time

    def make_result(status: TestStatus, message: str, response_time: float,
                    details: Optional[Dict] = None) -> TestResult:
        """Build a TestResult carrying this key's identifying fields."""
        return TestResult(
            key_id=api_key.id,
            provider=api_key.provider,
            owner=owner,
            status=status,
            message=message,
            response_time=response_time,
            details=details,
        )

    # Disabled keys are not sent any request.
    if not api_key.enabled:
        return make_result(TestStatus.SKIPPED, "密钥已禁用 (enabled: false)", 0.0)

    try:
        provider = get_provider_handler(api_key)
        # Minimal prompt: this only checks that the key can complete a request.
        test_messages = [
            {"role": "user", "content": "Hello, this is a test message. Please reply with 'OK'."}
        ]
        headers = provider.get_headers()
        url = provider.get_url()
        payload = provider.get_payload(test_messages, {})

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, headers=headers, json=payload)
            response_time = elapsed()

        if response.status_code != 200:
            error_text = response.text[:500] if response.text else "No error details"
            return make_result(
                TestStatus.FAILED,
                f"HTTP错误 {response.status_code}: {error_text}",
                elapsed(),
            )

        parsed = provider.parse_response(response.json())
        # A provider may report "content": None (e.g. an empty reply).
        # Normalise it to "" so the preview below cannot raise TypeError and
        # misreport a working key as FAILED.
        content = parsed.get("content") or ""
        preview = (content[:100] + "...") if len(content) > 100 else content
        return make_result(
            TestStatus.SUCCESS,
            f"测试成功,响应时间: {response_time:.2f}s",
            response_time,
            details={"content_preview": preview, "usage": parsed.get("usage", {})},
        )
    except httpx.TimeoutException:
        return make_result(TestStatus.FAILED, f"请求超时 (>{timeout}s)", elapsed())
    except httpx.ConnectError as e:
        return make_result(TestStatus.FAILED, f"连接错误: {str(e)}", elapsed())
    except Exception as e:
        return make_result(
            TestStatus.FAILED, f"异常: {type(e).__name__}: {str(e)}", elapsed()
        )


# Despite its name (and the test_keys.py file name), this is a helper
# coroutine, not a pytest test. Stop pytest from collecting it.
test_single_key.__test__ = False
# Row marker printed before each key in the per-provider listing.
# The same symbols are used in the summary counts.
_STATUS_ICONS = {
    TestStatus.SUCCESS: "✓",
    TestStatus.FAILED: "✗",
    TestStatus.SKIPPED: "○",
}


def _exception_to_result(key: APIKey, exc: BaseException) -> TestResult:
    """Turn an exception that escaped a per-key test task into a FAILED result."""
    return TestResult(
        key_id=key.id,
        provider=key.provider,
        owner=getattr(key, 'owner', None),
        status=TestStatus.FAILED,
        message=f"测试任务异常: {type(exc).__name__}: {str(exc)}",
        response_time=0.0,
    )


def _print_results_by_provider(results: List[TestResult]) -> None:
    """Print every result, grouped by provider name in sorted order."""
    grouped: Dict[str, List[TestResult]] = {}
    for r in results:
        grouped.setdefault(r.provider, []).append(r)

    for provider in sorted(grouped):
        provider_results = grouped[provider]
        print(f"\n{'='*80}")
        print(f"提供商: {provider.upper()} ({len(provider_results)} 个密钥)")
        print(f"{'='*80}")
        for r in provider_results:
            status_icon = _STATUS_ICONS.get(r.status, "○")
            owner_str = f" ({r.owner})" if r.owner else ""
            print(f"\n [{status_icon}] {r.key_id}{owner_str}")
            print(f" 状态: {r.message}")
            if r.details and r.status == TestStatus.SUCCESS:
                print(f" 回复: {r.details['content_preview']}")


def _print_summary(results: List[TestResult]) -> None:
    """Print overall and per-provider success/failure/skip counts."""
    print("\n" + "=" * 80)
    print("测试汇总报告")
    print("=" * 80)

    success_count = sum(1 for r in results if r.status == TestStatus.SUCCESS)
    failed_count = sum(1 for r in results if r.status == TestStatus.FAILED)
    skipped_count = sum(1 for r in results if r.status == TestStatus.SKIPPED)
    print(f"\n总计: {len(results)} 个密钥")
    print(f" ✓ 成功: {success_count}")
    print(f" ✗ 失败: {failed_count}")
    print(f" ○ 跳过: {skipped_count}")

    print("\n按提供商统计:")
    # Per-provider counters. Any status other than SUCCESS/FAILED counts as skipped.
    provider_stats: Dict[str, Dict[str, int]] = {}
    for r in results:
        stats = provider_stats.setdefault(
            r.provider, {"success": 0, "failed": 0, "skipped": 0}
        )
        if r.status == TestStatus.SUCCESS:
            stats["success"] += 1
        elif r.status == TestStatus.FAILED:
            stats["failed"] += 1
        else:
            stats["skipped"] += 1

    for provider in sorted(provider_stats):
        stats = provider_stats[provider]
        total = stats["success"] + stats["failed"] + stats["skipped"]
        success_rate = (stats["success"] / total * 100) if total > 0 else 0
        print(f" {provider:15s}: {stats['success']:2d}/{total:2d} 成功 ({success_rate:5.1f}%), {stats['failed']} 失败, {stats['skipped']} 跳过")


def _print_failures(results: List[TestResult]) -> None:
    """Print provider, key id, owner and error message for every FAILED result."""
    failed_results = [r for r in results if r.status == TestStatus.FAILED]
    if not failed_results:
        return
    print("\n" + "-" * 80)
    print("失败的密钥详情:")
    print("-" * 80)
    for r in failed_results:
        print(f"\n [{r.provider}] {r.key_id}")
        if r.owner:
            print(f" 拥有者: {r.owner}")
        print(f" 错误: {r.message}")


def _print_response_times(results: List[TestResult]) -> None:
    """Print average/min/max response time over successful results with a positive time."""
    times = [
        r.response_time
        for r in results
        if r.status == TestStatus.SUCCESS and r.response_time > 0
    ]
    if not times:
        return
    print("\n" + "-" * 80)
    print("响应时间统计 (成功请求):")
    print("-" * 80)
    print(f" 平均: {sum(times) / len(times):.2f}s")
    print(f" 最快: {min(times):.2f}s")
    print(f" 最慢: {max(times):.2f}s")


async def test_all_keys_concurrent(config_path: str = "config/keys.yaml", max_concurrency: int = 10):
    """Test every enabled key from ``config_path`` concurrently and print a report.

    Args:
        config_path: Configuration path handed to ``KeyManager``.
        max_concurrency: Maximum number of keys tested at the same time.
            Must be at least 1.

    Returns:
        A list with one TestResult per enabled key, in configuration order.
        Returns None if the configuration fails to load or no key is enabled.

    Raises:
        ValueError: If ``max_concurrency`` is less than 1. A zero-permit
            semaphore would otherwise make every task wait forever.
    """
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

    print("=" * 80)
    print("API密钥测试工具 (并发模式)")
    print("=" * 80)
    print()

    # Load the key manager. A broken config is reported, not raised.
    try:
        km = KeyManager(config_path)
        all_keys = km.get_all_keys()
        print(f"共加载 {len(all_keys)} 个密钥配置")
    except Exception as e:
        print(f"加载配置失败: {e}")
        return

    enabled_keys = [k for k in all_keys if k.enabled]
    print(f"其中 {len(enabled_keys)} 个密钥已启用")
    print(f"并发数: {max_concurrency}")
    print()
    if not enabled_keys:
        print("没有启用的密钥需要测试")
        return

    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded_test(key: APIKey) -> TestResult:
        # At most max_concurrency key tests run at once.
        async with semaphore:
            return await test_single_key(key)

    print("开始并发测试...")
    # return_exceptions=True keeps one crashing task from aborting the rest.
    # gather preserves input order, so outcomes line up with enabled_keys.
    outcomes = await asyncio.gather(
        *(bounded_test(key) for key in enabled_keys), return_exceptions=True
    )
    # Check BaseException, not just Exception: a cancelled task comes back as
    # asyncio.CancelledError and must not reach the report as a "result".
    results: List[TestResult] = [
        _exception_to_result(key, outcome) if isinstance(outcome, BaseException) else outcome
        for key, outcome in zip(enabled_keys, outcomes)
    ]

    _print_results_by_provider(results)
    _print_summary(results)
    _print_failures(results)
    _print_response_times(results)

    print("\n" + "=" * 80)
    print("测试完成")
    print("=" * 80)
    return results


# A script entry coroutine, not a pytest test, despite its name.
test_all_keys_concurrent.__test__ = False
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(description="Concurrently test all enabled API keys.")
    # The default of 11 keeps the value this script has always passed. The old
    # comment claimed 10, which is test_all_keys_concurrent's own default.
    parser.add_argument(
        "-c", "--concurrency", type=int, default=11,
        help="maximum number of keys tested at once (default: 11)",
    )
    args = parser.parse_args()

    # Change to the script's directory so the relative default config path
    # (config/keys.yaml) resolves there, wherever the script was launched from.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

    asyncio.run(test_all_keys_concurrent(max_concurrency=args.concurrency))