autool-test/app/core/providers.py
2026-06-17 11:13:11 +08:00

312 lines
12 KiB
Python
Executable File

import re
from typing import Dict, Any, List, Optional, Tuple, Union
from app.models import APIKey
class BaseProvider:
    """Common interface for provider adapters.

    An adapter wraps one ``APIKey`` record and turns an OpenAI-style chat
    request into the provider's HTTP request (URL, headers, JSON body), then
    normalizes the provider's JSON reply. Every method other than
    ``__init__`` must be overridden; the base versions raise
    ``NotImplementedError``.
    """

    def __init__(self, api_key: APIKey):
        # Key record that subclasses read credentials, endpoint and model
        # settings from.
        self.api_key = api_key

    def get_headers(self) -> Dict[str, str]:
        """Return the HTTP headers for a request."""
        raise NotImplementedError()

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the JSON request body built from chat ``messages`` and request ``params``."""
        raise NotImplementedError()

    def get_url(self) -> str:
        """Return the endpoint URL the request is sent to."""
        raise NotImplementedError()

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert the provider's decoded JSON response into the shared result format:

            {
                "content": str,
                "reasoning_content": Optional[str],
                "usage": {
                    "input_tokens": int,
                    "output_tokens": int,
                    "total_tokens": int
                }
            }
        """
        raise NotImplementedError()
class OpenAIProvider(BaseProvider):
    """Adapter for the OpenAI Chat Completions API and OpenAI-compatible servers."""

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers authenticating with ``Authorization: Bearer <key>``."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key.key}"
        }

    def get_url(self) -> str:
        """Return ``<api_base>/chat/completions``; the base defaults to ``https://api.openai.com/v1``."""
        base = self.api_key.api_base or "https://api.openai.com/v1"
        return f"{base.rstrip('/')}/chat/completions"

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Build a non-streaming chat completion request body.

        Later layers override earlier ones:

        1. ``messages``, the key's ``model_name`` and ``stream=False``;
        2. the key's default ``params``;
        3. request-level ``params`` whose value is not ``None``;
        4. ``extra_config["extra_body"]`` from the key (vendor-specific
           fields, e.g. for Qwen).
        """
        data: Dict[str, Any] = {"messages": messages, "model": self.api_key.model_name, "stream": False}
        # Key defaults; ``or {}`` tolerates a key whose params were never set.
        data.update(self.api_key.params or {})
        # Request overrides. None means "not specified" and must not clobber
        # a key default.
        data.update({k: v for k, v in (params or {}).items() if v is not None})
        # Vendor-specific body fields are applied last so they win.
        extra_config = self.api_key.extra_config or {}
        data.update(extra_config.get("extra_body") or {})
        return data

    @staticmethod
    def _empty_result() -> Dict[str, Any]:
        """Return the standardized result used when a response cannot be parsed."""
        return {
            "content": "",
            "reasoning_content": None,
            "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
        }

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a chat completion response into the shared result format.

        ``content`` is the first choice's message text, or ``""`` when the
        message carries no text (e.g. a tool-call-only reply).
        ``reasoning_content`` is read from the message, falling back to the
        choice itself. Missing token counts are reported as 0.

        If the response has no usable first choice, the result has empty
        content and zero usage; the method does not raise.
        """
        try:
            choice = data["choices"][0]
            message = choice["message"]
            text = message.get("content")
            reasoning = message.get("reasoning_content") or choice.get("reasoning_content")
            # Some compatible servers send ``"usage": null``.
            usage = data.get("usage") or {}
            return {
                "content": text or "",
                "reasoning_content": reasoning,
                "usage": {
                    "input_tokens": usage.get("prompt_tokens") or 0,
                    "output_tokens": usage.get("completion_tokens") or 0,
                    "total_tokens": usage.get("total_tokens") or 0
                }
            }
        except (KeyError, IndexError, TypeError, AttributeError):
            return self._empty_result()
class AzureProvider(OpenAIProvider):
    """Adapter for Azure OpenAI deployments.

    Reuses the OpenAI request and response format, but authenticates with an
    ``api-key`` header, targets a deployment URL and appends an
    ``api-version`` query parameter.
    """

    # Matches a URL path that stops at ``/deployments/{deployment}``.
    _DEPLOYMENT_PATH = re.compile(r"/deployments/[^/]+/?$")

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers authenticating with Azure's ``api-key`` header."""
        return {
            "Content-Type": "application/json",
            "api-key": self.api_key.key
        }

    def get_url(self) -> str:
        """Return the chat-completions URL with an ``api-version`` query parameter.

        ``api_base`` is expected to be the deployment URL,
        ``https://{resource}.openai.azure.com/openai/deployments/{deployment}``.
        In that form ``/chat/completions`` is appended. A base that already
        names an operation (e.g. ends in ``/chat/completions``) is used as-is,
        and any existing query string is preserved.

        The version comes from ``extra_config["api_version"]`` and defaults to
        ``2025-01-01-preview``.
        """
        base = self.api_key.api_base or ""
        path, question, query = base.partition("?")
        if self._DEPLOYMENT_PATH.search(path):
            # Azure serves chat completions at {deployment}/chat/completions,
            # not at the deployment root.
            path = path.rstrip("/") + "/chat/completions"
        version = "2025-01-01-preview"
        extra_config = self.api_key.extra_config or {}
        if "api_version" in extra_config:
            version = extra_config["api_version"]
        sep = "&" if question else "?"
        return f"{path}{question}{query}{sep}api-version={version}"

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the OpenAI-style body, minus ``model``.

        The deployment in the URL selects the model, so ``model`` is not sent.
        """
        payload = super().get_payload(messages, params)
        payload.pop("model", None)
        return payload
class GeminiProvider(BaseProvider):
    """Adapter for the Google Gemini ``generateContent`` REST API."""

    # Splits a base64 data URI into (mime type, payload). DOTALL keeps
    # line-wrapped base64 payloads intact.
    _DATA_URI = re.compile(r"data:(.*?);base64,(.*)", re.DOTALL)

    # OpenAI-style request parameter -> Gemini ``generationConfig`` field.
    _PARAM_MAP = (
        ("max_tokens", "maxOutputTokens"),
        ("temperature", "temperature"),
        ("top_p", "topP"),
    )

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers; the API key is sent as a URL query parameter instead."""
        return {"Content-Type": "application/json"}

    def get_url(self) -> str:
        """Return the generateContent URL with the API key as a ``key`` query parameter.

        ``api_base`` is expected to be a full method URL, e.g.
        ``https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent``.
        If it is not set, the public endpoint for ``model_name`` is used.
        """
        # NOTE(review): a key in the query string can end up in HTTP logs; the
        # Gemini API also accepts it via an ``x-goog-api-key`` header.
        base = self.api_key.api_base or f"https://generativelanguage.googleapis.com/v1beta/models/{self.api_key.model_name}:generateContent"
        sep = "&" if "?" in base else "?"
        return f"{base}{sep}key={self.api_key.key}"

    @staticmethod
    def _text_of(content: Any) -> str:
        """Return the text of OpenAI-style content.

        A string is returned as-is. For a part list, the text parts are joined
        with newlines.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "\n".join(
                item.get("text") or ""
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        return ""

    @classmethod
    def _to_parts(cls, content: Any) -> List[Dict[str, Any]]:
        """Convert OpenAI-style message content (a string or part list) into Gemini ``parts``."""
        if content is None:
            return []
        items = content if isinstance(content, list) else [{"type": "text", "text": content}]
        parts: List[Dict[str, Any]] = []
        for item in items:
            kind = item.get("type")
            if kind == "text":
                if item.get("text") is not None:
                    parts.append({"text": item["text"]})
            elif kind == "image_url":
                image = item.get("image_url")
                url = image.get("url", "") if isinstance(image, dict) else (image or "")
                match = cls._DATA_URI.match(url)
                # Only inline base64 images are supported; remote URLs are dropped.
                if match:
                    parts.append({
                        "inline_data": {
                            "mime_type": match.group(1),
                            "data": match.group(2)
                        }
                    })
        return parts

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Translate OpenAI-style ``messages`` and ``params`` into a Gemini request body.

        Messages:

        - Every ``system`` message contributes, in order, one part to
          ``system_instruction``.
        - ``assistant`` becomes Gemini's ``model`` role; other roles pass
          through unchanged.
        - Messages that end up with no parts are skipped.

        ``generationConfig`` is built in layers:

        - the key's ``params`` (assumed to use Gemini field names);
        - request ``max_tokens``/``temperature``/``top_p`` mapped onto Gemini
          names, ignoring ``None`` values;
        - ``thinking_config``, taken from the key's ``extra_config`` and then
          from the request.
        """
        gemini_contents: List[Dict[str, Any]] = []
        system_parts: List[Dict[str, str]] = []
        for msg in messages:
            if msg["role"] == "system":
                text_val = self._text_of(msg["content"])
                if text_val:
                    system_parts.append({"text": text_val})
                continue
            role = "model" if msg["role"] == "assistant" else msg["role"]
            parts = self._to_parts(msg["content"])
            if parts:
                gemini_contents.append({"role": role, "parts": parts})

        # 1. Key defaults, already in Gemini format (e.g. ``maxOutputTokens``).
        generation_config: Dict[str, Any] = dict(self.api_key.params or {})
        # 2. Request params. None means "not specified" (as in the OpenAI
        #    adapter), so it never clobbers a key default.
        request = {k: v for k, v in (params or {}).items() if v is not None}
        for openai_name, gemini_name in self._PARAM_MAP:
            if openai_name in request:
                generation_config[gemini_name] = request[openai_name]
        # 3. Thinking config: key extra_config first, request overrides it.
        extra_config = self.api_key.extra_config or {}
        if "thinking_config" in extra_config:
            generation_config["thinkingConfig"] = extra_config["thinking_config"]
        if "thinking_config" in request:
            generation_config["thinkingConfig"] = request["thinking_config"]

        payload: Dict[str, Any] = {
            "contents": gemini_contents,
            "generationConfig": generation_config
        }
        if system_parts:
            payload["system_instruction"] = {"parts": system_parts}
        return payload

    @staticmethod
    def _parse_usage(data: Any) -> Dict[str, int]:
        """Read token counts from ``usageMetadata``; missing fields count as 0.

        ``thoughtsTokenCount`` is reported separately from
        ``candidatesTokenCount``. It is added to ``output_tokens`` so that
        thinking tokens are counted as generated output.
        """
        meta = data.get("usageMetadata") if isinstance(data, dict) else None
        if not isinstance(meta, dict):
            meta = {}
        input_tokens = meta.get("promptTokenCount") or 0
        output_tokens = (meta.get("candidatesTokenCount") or 0) + (meta.get("thoughtsTokenCount") or 0)
        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens
        }

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a ``generateContent`` response into the shared result format.

        Text parts of the first candidate are joined with newlines into
        ``content``. Parts flagged ``"thought": true`` are returned when the
        thinking config requests thoughts. They go to ``reasoning_content``
        instead of ``content``. A ``thinking_process`` field on the candidate
        is used as a fallback.

        If there is no usable candidate (e.g. the prompt was blocked), the
        result has empty content but still reports whatever usage was sent.
        """
        usage = self._parse_usage(data)
        try:
            candidate = data["candidates"][0]
            parts = (candidate.get("content") or {}).get("parts") or []
        except (KeyError, IndexError, TypeError, AttributeError):
            return {"content": "", "reasoning_content": None, "usage": usage}

        text_parts: List[str] = []
        thought_parts: List[str] = []
        for part in parts:
            text = part.get("text") if isinstance(part, dict) else None
            if not isinstance(text, str):
                continue
            (thought_parts if part.get("thought") else text_parts).append(text)

        reasoning = "\n".join(thought_parts) if thought_parts else candidate.get("thinking_process")
        return {
            "content": "\n".join(text_parts),
            "reasoning_content": reasoning,
            "usage": usage
        }
class ClaudeProvider(BaseProvider):
    """Adapter for the Anthropic Messages API."""

    # Splits a base64 data URI into (media type, payload).
    _DATA_URI = re.compile(r"data:(.*?);base64,(.*)", re.DOTALL)

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers authenticating with ``x-api-key``.

        ``anthropic-version`` defaults to ``2023-06-01`` and can be overridden
        with ``extra_config["anthropic_version"]``.
        """
        extra_config = self.api_key.extra_config or {}
        return {
            "Content-Type": "application/json",
            "anthropic-version": extra_config.get("anthropic_version", "2023-06-01"),
            "x-api-key": self.api_key.key
        }

    def get_url(self) -> str:
        """Return ``api_base``, defaulting to the public Messages endpoint."""
        return self.api_key.api_base or "https://api.anthropic.com/v1/messages"

    @staticmethod
    def _text_of(content: Any) -> str:
        """Return the text of OpenAI-style content.

        A string is returned as-is. For a part list, the text parts are joined
        with newlines.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "\n".join(
                item.get("text") or ""
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        return ""

    @classmethod
    def _convert_part(cls, part: Any) -> Any:
        """Convert an OpenAI ``image_url`` part into an Anthropic ``image`` block.

        Any other part (e.g. ``{"type": "text", ...}``, which Anthropic
        accepts as-is) is returned unchanged. An image URL that is neither a
        base64 data URI nor http(s) is also returned unchanged.
        """
        if not isinstance(part, dict) or part.get("type") != "image_url":
            return part
        image = part.get("image_url")
        url = image.get("url", "") if isinstance(image, dict) else (image or "")
        if not isinstance(url, str):
            return part
        match = cls._DATA_URI.match(url)
        if match:
            return {
                "type": "image",
                "source": {"type": "base64", "media_type": match.group(1), "data": match.group(2)}
            }
        if url.startswith(("http://", "https://")):
            return {"type": "image", "source": {"type": "url", "url": url}}
        return part

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Translate OpenAI-style ``messages`` and ``params`` into a Messages API body.

        Messages:

        - ``system`` messages are removed from the list. Their text is joined
          (blank-line separated) into the top-level ``system`` field.
        - Other messages keep their role. String content passes through; list
          content is converted part by part (see ``_convert_part``).

        Parameters, later layers overriding earlier ones:

        - ``max_tokens`` starts at 1024, because the Messages API requires it;
        - the key's ``params``;
        - request ``max_tokens``/``temperature``, unless ``None``.
        """
        claude_msgs: List[Dict[str, Any]] = []
        system_texts: List[str] = []
        for msg in messages:
            content = msg["content"]
            if msg["role"] == "system":
                text_val = self._text_of(content)
                if text_val:
                    system_texts.append(text_val)
                continue
            if isinstance(content, list):
                content = [self._convert_part(part) for part in content]
            claude_msgs.append({"role": msg["role"], "content": content})

        payload: Dict[str, Any] = {
            "model": self.api_key.model_name,
            "messages": claude_msgs,
            "max_tokens": 1024  # Default fallback
        }
        # Key defaults; ``or {}`` tolerates a key whose params were never set.
        payload.update(self.api_key.params or {})
        # Request overrides. None means "not specified" (as in the OpenAI
        # adapter) and must not replace the required integer max_tokens.
        request = {k: v for k, v in (params or {}).items() if v is not None}
        for name in ("max_tokens", "temperature"):
            if name in request:
                payload[name] = request[name]
        if system_texts:
            payload["system"] = "\n\n".join(system_texts)
        return payload

    @staticmethod
    def _empty_result() -> Dict[str, Any]:
        """Return the standardized result used when a response cannot be parsed."""
        return {
            "content": "",
            "reasoning_content": None,
            "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
        }

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a Messages API response into the shared result format.

        ``text`` blocks are concatenated into ``content``. ``thinking`` blocks
        are joined with newlines into ``reasoning_content``, which is ``None``
        when there are none. ``total_tokens`` is ``input_tokens +
        output_tokens``.

        A malformed response yields empty content and zero usage; the method
        does not raise.
        """
        try:
            text_chunks: List[str] = []
            thinking_chunks: List[str] = []
            for block in data.get("content") or []:
                kind = block.get("type")
                if kind == "text":
                    text_chunks.append(block.get("text") or "")
                elif kind == "thinking":
                    thinking_chunks.append(block.get("thinking") or "")
            usage = data.get("usage") or {}
            input_tokens = usage.get("input_tokens") or 0
            output_tokens = usage.get("output_tokens") or 0
        except (AttributeError, TypeError):
            return self._empty_result()
        return {
            "content": "".join(text_chunks),
            "reasoning_content": "\n".join(thinking_chunks) if thinking_chunks else None,
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens
            }
        }
def get_provider_handler(api_key: APIKey) -> BaseProvider:
    """Instantiate the provider adapter for ``api_key.provider``.

    String names are matched case-insensitively and ignoring surrounding
    whitespace, so ``"Gemini"`` or ``" openai "`` resolve like ``"gemini"``
    and ``"openai"``. Unknown or missing names fall back to
    ``OpenAIProvider``, because many vendors (e.g. Zhipu) expose an
    OpenAI-compatible API.
    """
    providers = {
        "google": GeminiProvider,
        "gemini": GeminiProvider,
        "openai": OpenAIProvider,
        "azure": AzureProvider,
        "claude": ClaudeProvider,
        "anthropic": ClaudeProvider,
        "zhipu": OpenAIProvider  # Zhipu is mostly OpenAI compatible
    }
    name = api_key.provider
    # Only strings are normalized. Anything else (e.g. None) is looked up
    # as-is and falls back to the OpenAI-compatible adapter.
    if isinstance(name, str):
        name = name.strip().lower()
    provider_class = providers.get(name, OpenAIProvider)
    return provider_class(api_key)