312 lines
12 KiB
Python
Executable File
312 lines
12 KiB
Python
Executable File
import re
|
|
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
from app.models import APIKey
|
|
|
|
class BaseProvider:
    """Interface shared by all upstream LLM provider adapters.

    An adapter wraps one ``APIKey`` record and exposes four hooks that
    together describe a single chat request: where to send it (``get_url``),
    how to authenticate (``get_headers``), what to send (``get_payload``) and
    how to normalize the reply (``parse_response``). Every hook raises
    ``NotImplementedError`` here; concrete adapters override them.
    """

    def __init__(self, api_key: APIKey):
        # The key record; subclass hooks read credentials, endpoint, model
        # name and default parameters from it.
        self.api_key = api_key

    def get_url(self) -> str:
        """Return the endpoint URL the request is sent to."""
        raise NotImplementedError()

    def get_headers(self) -> Dict[str, str]:
        """Return the HTTP headers for the request."""
        raise NotImplementedError()

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Translate chat ``messages`` plus request ``params`` into a JSON body.

        ``messages`` are chat-format dicts with ``role`` and ``content`` keys;
        ``params`` holds request-level generation options.
        """
        raise NotImplementedError()

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a decoded provider response into the common result shape.

        The returned dict looks like::

            {
                "content": str,
                "reasoning_content": Optional[str],
                "usage": {
                    "input_tokens": int,
                    "output_tokens": int,
                    "total_tokens": int,
                },
            }
        """
        raise NotImplementedError()
|
|
|
|
class OpenAIProvider(BaseProvider):
    """Adapter for the OpenAI Chat Completions API.

    ``AzureProvider`` builds on this class, and ``get_provider_handler`` also
    routes OpenAI-compatible and unrecognized provider names here.
    """

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers with bearer-token authentication."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key.key}",
        }

    def get_url(self) -> str:
        """Return the chat-completions endpoint under ``api_base``.

        ``api_base`` defaults to ``https://api.openai.com/v1``. A base that
        already ends in ``/chat/completions`` is used as-is rather than
        getting the suffix appended a second time.
        """
        base = (self.api_key.api_base or "https://api.openai.com/v1").rstrip("/")
        if base.endswith("/chat/completions"):
            return base
        return f"{base}/chat/completions"

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the Chat Completions request body.

        Fields are layered so that later sources win:

        1. ``messages``, ``model`` (the key's ``model_name``) and ``stream=False``;
        2. the key's default ``params``;
        3. request ``params`` whose value is not ``None``;
        4. ``extra_config["extra_body"]`` (e.g. Qwen extra_body options),
           merged into the top level of the body.

        Steps 2-4 can therefore override the fields set in step 1.
        """
        data: Dict[str, Any] = {
            "messages": messages,
            "model": self.api_key.model_name,
            "stream": False,
        }
        # ``or {}`` tolerates a key or request that carries no params at all.
        data.update(self.api_key.params or {})
        # None means "not specified" and must not clobber the key defaults.
        data.update({k: v for k, v in (params or {}).items() if v is not None})

        extra_body = (self.api_key.extra_config or {}).get("extra_body")
        if extra_body:
            data.update(extra_body)
        return data

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a Chat Completions response to the common result shape.

        ``content`` is the first choice's message content (``""`` if missing
        or null). ``reasoning_content`` is read from the message, falling back
        to the choice object, else ``None``. Token counts come from ``usage``
        and default to 0; ``total_tokens`` falls back to input + output.
        A malformed response never raises: it yields empty content and the
        usage that could still be read, always with all three usage keys.
        """
        text: Optional[str] = None
        reasoning: Optional[str] = None
        try:
            choice = data["choices"][0]
            message = choice["message"]
            # ``content`` may be null, e.g. for tool-call-only replies.
            text = message.get("content")
            reasoning = message.get("reasoning_content") or choice.get("reasoning_content")
        except (KeyError, IndexError, TypeError, AttributeError):
            pass  # no usable choice: fall through with empty content

        usage = data.get("usage") if isinstance(data, dict) else None
        if not isinstance(usage, dict):
            usage = {}  # tolerate a missing or null ``usage`` object
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0
        return {
            "content": text or "",
            "reasoning_content": reasoning,
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": usage.get("total_tokens") or (input_tokens + output_tokens),
            },
        }
|
|
|
|
class AzureProvider(OpenAIProvider):
    """Adapter for Azure OpenAI deployments.

    Reuses the OpenAI request/response format, but authenticates with an
    ``api-key`` header, selects the deployment through the URL and requires
    an ``api-version`` query parameter.
    """

    DEFAULT_API_VERSION = "2025-01-01-preview"

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers authenticated with Azure's ``api-key`` header."""
        return {
            "Content-Type": "application/json",
            "api-key": self.api_key.key,
        }

    def get_url(self) -> str:
        """Return the deployment's chat-completions URL with ``api-version``.

        ``api_base`` is expected to be the deployment URL, e.g.
        ``https://{resource}.openai.azure.com/openai/deployments/{deployment}``.
        ``/chat/completions`` is appended to its path unless already present,
        so a base that already names the full endpoint keeps working. The
        version is ``extra_config["api_version"]`` when set (non-empty), else
        ``DEFAULT_API_VERSION``, and is added to any existing query string.
        If ``api_base`` is unset the result has no host and must be configured.
        """
        path, _, query = (self.api_key.api_base or "").partition("?")
        path = path.rstrip("/")
        # Azure serves chat at .../deployments/{deployment}/chat/completions;
        # the bare deployment URL is not itself the chat endpoint.
        if path and not path.endswith("/chat/completions"):
            path += "/chat/completions"

        version = (self.api_key.extra_config or {}).get("api_version") or self.DEFAULT_API_VERSION
        query = query.rstrip("&")
        query = f"{query}&api-version={version}" if query else f"api-version={version}"
        return f"{path}?{query}"

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the OpenAI-style request body without the ``model`` field.

        The deployment named in the URL selects the model, so the ``model``
        entry produced by ``OpenAIProvider.get_payload`` is removed.
        """
        payload = super().get_payload(messages, params)
        payload.pop("model", None)
        return payload
|
|
|
|
class GeminiProvider(BaseProvider):
    """Adapter for the Google Gemini ``generateContent`` REST API."""

    # ``data:<mime>;base64,<payload>`` URLs; group 1 = MIME type, group 2 = data.
    _DATA_URL_RE = re.compile(r"data:(.*?);base64,(.*)")

    # Request-level (OpenAI-style) parameter name -> generationConfig field.
    _PARAM_MAP = (
        ("max_tokens", "maxOutputTokens"),
        ("temperature", "temperature"),
        ("top_p", "topP"),
    )

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers; the API key is sent in the URL instead."""
        return {"Content-Type": "application/json"}

    def get_url(self) -> str:
        """Return the ``generateContent`` URL with the API key as ``key=``.

        ``api_base`` may supply the full method URL, e.g.
        ``https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent``;
        otherwise it is built from the key's ``model_name``. The returned URL
        contains the API key, so avoid logging it.
        """
        base = self.api_key.api_base or (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.api_key.model_name}:generateContent"
        )
        sep = "&" if "?" in base else "?"
        return f"{base}{sep}key={self.api_key.key}"

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Translate chat messages and params into a ``generateContent`` body.

        * ``assistant`` messages get role ``model``; other roles pass through.
        * The non-empty texts of every ``system`` message become parts of
          ``system_instruction``.
        * Text parts map to ``{"text": ...}``; base64 ``data:`` image URLs map
          to ``inline_data``. Other image URLs are dropped, and messages left
          with no parts are omitted.
        * See ``_generation_config`` for how ``generationConfig`` is built.
        """
        contents: List[Dict[str, Any]] = []
        system_parts: List[Dict[str, Any]] = []

        for msg in messages:
            content = msg["content"]
            if msg["role"] == "system":
                system_parts.extend({"text": text} for text in self._system_texts(content))
                continue
            role = "model" if msg["role"] == "assistant" else msg["role"]
            parts = self._to_parts(content)
            if parts:
                contents.append({"role": role, "parts": parts})

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": self._generation_config(params or {}),
        }
        if system_parts:
            payload["system_instruction"] = {"parts": system_parts}
        return payload

    @staticmethod
    def _system_texts(content: Any) -> List[str]:
        """Return the non-empty text pieces of a system message's content."""
        if isinstance(content, str):
            return [content] if content else []
        return [item["text"] for item in content or [] if isinstance(item, dict) and item.get("text")]

    def _to_parts(self, content: Any) -> List[Dict[str, Any]]:
        """Convert one message's content (str or list of parts) to Gemini parts."""
        if content is None:
            return []
        items = content if isinstance(content, list) else [{"type": "text", "text": content}]
        parts: List[Dict[str, Any]] = []
        for item in items:
            kind = item.get("type")
            if kind == "text":
                # Empty/None text carries nothing and may be rejected by the API.
                if item.get("text"):
                    parts.append({"text": item["text"]})
            elif kind == "image_url":
                url = (item.get("image_url") or {}).get("url") or ""
                match = self._DATA_URL_RE.match(url)
                if match:
                    parts.append({
                        "inline_data": {
                            "mime_type": match.group(1),
                            "data": match.group(2),
                        }
                    })
        return parts

    def _generation_config(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``generationConfig`` from key defaults and request params.

        Starts from the key's ``params`` (expected to already use Gemini field
        names), maps non-``None`` request ``max_tokens``/``temperature``/
        ``top_p`` onto it, and sets ``thinkingConfig`` from the request's
        ``thinking_config`` or else ``extra_config["thinking_config"]``.
        """
        config: Dict[str, Any] = dict(self.api_key.params or {})
        # Drop None-valued request params (as OpenAIProvider does) so they
        # neither clobber key defaults nor get sent as JSON null.
        request = {k: v for k, v in params.items() if v is not None}
        for src, dst in self._PARAM_MAP:
            if src in request:
                config[dst] = request[src]

        thinking = request.get(
            "thinking_config",
            (self.api_key.extra_config or {}).get("thinking_config"),
        )
        if thinking is not None:
            config["thinkingConfig"] = thinking
        return config

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a ``generateContent`` response to the common result shape.

        Text parts of the first candidate are joined with newlines into
        ``content``; parts flagged ``"thought": true`` (thought summaries)
        go to ``reasoning_content`` instead. A ``thinking_process`` field on
        the candidate, if present, is the fallback ``reasoning_content``.
        Usage is read from ``usageMetadata`` even when there is no candidate
        (e.g. a blocked prompt). Malformed input never raises.
        """
        if not isinstance(data, dict):
            data = {}
        usage = self._usage(data.get("usageMetadata"))

        candidates = data.get("candidates")
        candidate = candidates[0] if isinstance(candidates, list) and candidates else None
        if not isinstance(candidate, dict):
            return {"content": "", "reasoning_content": None, "usage": usage}

        answer: List[str] = []
        thoughts: List[str] = []
        content = candidate.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        for part in parts or []:
            if not isinstance(part, dict) or not isinstance(part.get("text"), str):
                continue
            (thoughts if part.get("thought") else answer).append(part["text"])

        reasoning = "\n".join(thoughts) if thoughts else candidate.get("thinking_process")
        return {
            "content": "\n".join(answer),
            "reasoning_content": reasoning,
            "usage": usage,
        }

    @staticmethod
    def _usage(meta: Any) -> Dict[str, int]:
        """Map ``usageMetadata`` onto input/output/total token counts."""
        if not isinstance(meta, dict):
            meta = {}
        input_tokens = meta.get("promptTokenCount") or 0
        # Thinking tokens are reported separately from candidatesTokenCount;
        # count them as output, like OpenAI's completion_tokens does.
        output_tokens = (meta.get("candidatesTokenCount") or 0) + (meta.get("thoughtsTokenCount") or 0)
        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        }
|
|
|
|
class ClaudeProvider(BaseProvider):
    """Adapter for the Anthropic Messages API."""

    # Fallback ``max_tokens`` (a required Messages API field) used when
    # neither the key params nor the request provide one.
    DEFAULT_MAX_TOKENS = 1024

    # ``data:<mime>;base64,<payload>`` URLs; group 1 = MIME type, group 2 = data.
    _DATA_URL_RE = re.compile(r"data:(.*?);base64,(.*)")

    def get_headers(self) -> Dict[str, str]:
        """Return JSON headers with ``x-api-key`` auth and ``anthropic-version``.

        The version defaults to ``2023-06-01`` and can be overridden through
        a non-empty ``extra_config["anthropic_version"]``.
        """
        headers = {
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
            "x-api-key": self.api_key.key,
        }
        version = (self.api_key.extra_config or {}).get("anthropic_version")
        if version:
            headers["anthropic-version"] = version
        return headers

    def get_url(self) -> str:
        """Return ``api_base`` or the public ``/v1/messages`` endpoint."""
        return self.api_key.api_base or "https://api.anthropic.com/v1/messages"

    def get_payload(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Dict[str, Any]:
        """Build a Messages API request body from chat-format messages.

        * ``system`` messages are removed from ``messages``; their non-empty
          texts are joined with blank lines into the top-level ``system``.
        * String content and non-image list parts pass through unchanged;
          ``image_url`` parts become Anthropic ``image`` blocks.
        * ``max_tokens`` starts at ``DEFAULT_MAX_TOKENS``, the key's
          ``params`` are merged over it, then non-``None`` request
          ``max_tokens``/``temperature`` take precedence.
        """
        system_texts: List[str] = []
        claude_msgs: List[Dict[str, Any]] = []
        for msg in messages:
            content = msg["content"]
            if msg["role"] == "system":
                system_texts.extend(self._system_texts(content))
                continue
            claude_msgs.append({"role": msg["role"], "content": self._convert_content(content)})

        payload: Dict[str, Any] = {
            "model": self.api_key.model_name,
            "messages": claude_msgs,
            "max_tokens": self.DEFAULT_MAX_TOKENS,
        }
        payload.update(self.api_key.params or {})

        # None means "not specified"; sending it would override the default
        # with null (``max_tokens`` must be an integer).
        request = {k: v for k, v in (params or {}).items() if v is not None}
        for name in ("max_tokens", "temperature"):
            if name in request:
                payload[name] = request[name]

        if system_texts:
            payload["system"] = "\n\n".join(system_texts)
        return payload

    @staticmethod
    def _system_texts(content: Any) -> List[str]:
        """Return the non-empty text pieces of a system message's content."""
        if isinstance(content, str):
            return [content] if content else []
        return [item["text"] for item in content or [] if isinstance(item, dict) and item.get("text")]

    def _convert_content(self, content: Any) -> Any:
        """Return ``content`` with ``image_url`` parts turned into image blocks.

        Non-list content is returned unchanged; unusable images are dropped.
        """
        if not isinstance(content, list):
            return content
        converted: List[Any] = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "image_url":
                block = self._image_block(item.get("image_url"))
                if block is not None:
                    converted.append(block)
            else:
                converted.append(item)
        return converted

    def _image_block(self, image_url: Any) -> Optional[Dict[str, Any]]:
        """Build an Anthropic ``image`` block from an ``image_url`` value.

        Base64 ``data:`` URLs become ``base64`` sources and http(s) URLs
        become ``url`` sources; anything else returns ``None``.
        """
        url = image_url.get("url") if isinstance(image_url, dict) else image_url
        if not isinstance(url, str) or not url:
            return None
        match = self._DATA_URL_RE.match(url)
        if match:
            source = {"type": "base64", "media_type": match.group(1), "data": match.group(2)}
        elif url.startswith(("http://", "https://")):
            source = {"type": "url", "url": url}
        else:
            return None
        return {"type": "image", "source": source}

    def parse_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a Messages API response to the common result shape.

        ``text`` blocks are concatenated into ``content``; ``thinking`` blocks
        (returned when extended thinking is enabled) are joined with newlines
        into ``reasoning_content`` (``None`` when absent). ``total_tokens`` is
        input + output. Malformed input yields empty content and zero counts
        instead of raising.
        """
        try:
            texts: List[str] = []
            thoughts: List[str] = []
            for block in data.get("content") or []:
                if not isinstance(block, dict):
                    continue
                kind = block.get("type")
                if kind == "text":
                    texts.append(block.get("text") or "")
                elif kind == "thinking" and block.get("thinking"):
                    thoughts.append(block["thinking"])

            usage = data.get("usage") or {}
            input_tokens = usage.get("input_tokens") or 0
            output_tokens = usage.get("output_tokens") or 0
            return {
                "content": "".join(texts),
                "reasoning_content": "\n".join(thoughts) if thoughts else None,
                "usage": {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                },
            }
        except (AttributeError, TypeError):
            # Unexpected shapes: non-dict ``data``/``usage``, non-string text,
            # non-numeric token counts.
            return {
                "content": "",
                "reasoning_content": None,
                "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
            }
|
|
|
|
def get_provider_handler(api_key: APIKey) -> BaseProvider:
    """Instantiate the provider adapter for ``api_key.provider``.

    Known names: ``google``/``gemini`` (Gemini), ``openai``, ``azure``,
    ``claude``/``anthropic`` and ``zhipu`` (OpenAI-compatible). The name is
    looked up as-is first, then (if it is a string) stripped and lowercased,
    so ``"Gemini"`` or ``" OpenAI "`` resolve correctly. Anything else,
    including a missing provider, falls back to ``OpenAIProvider``.
    """
    providers = {
        "google": GeminiProvider,
        "gemini": GeminiProvider,
        "openai": OpenAIProvider,
        "azure": AzureProvider,
        "claude": ClaudeProvider,
        "anthropic": ClaudeProvider,
        "zhipu": OpenAIProvider,  # Zhipu is mostly OpenAI compatible
    }
    name = api_key.provider
    # Exact match first keeps non-plain-str values (e.g. str enums) working.
    provider_class = providers.get(name)
    if provider_class is None and isinstance(name, str):
        provider_class = providers.get(name.strip().lower())
    return (provider_class or OpenAIProvider)(api_key)
|