autool-test/app/core/key_manager.py
2026-06-17 11:13:11 +08:00

295 lines
13 KiB
Python
Executable File

import hashlib
import time
from typing import List, Dict, Optional, Set, Any, Tuple

import yaml

from app.models import APIKey, KeyStatus
class KeyManager:
    """Load API keys from a YAML config file and manage their enabled state.

    ``load_keys`` reads two top-level sections of the file:

    * ``model_configs``: a list of dicts identified by ``id``. All other fields
      are copied into every key built for that config. A config may define
      ``api_endpoints`` (plain URL strings or ``{url, enabled}`` dicts); then
      one virtual key is generated per endpoint and the group's ``keys`` list
      is not used.
    * ``keys``: a list of key groups with ``config_id``, ``enabled`` and a
      ``keys`` list whose entries carry ``key`` and optionally ``owner``.
      Groups whose ``config_id`` does not match a model config are loaded in a
      legacy mode where the group and entry fields go to ``APIKey`` directly.

    ``set_config_enabled`` and ``set_endpoint_enabled`` write the new flags back
    to the file; ``set_key_enabled`` only changes in-memory state.
    """

    def __init__(self, config_path: str):
        self.config_path = config_path
        self._keys: Dict[str, "APIKey"] = {}                # key id -> key
        self._keys_by_model: Dict[str, List["APIKey"]] = {}  # model_name -> keys
        self._config_ids: Set[str] = set()                  # config_ids seen in key groups
        self._last_load_time: float = 0
        self._raw_config: Dict[str, Any] = {}               # parsed YAML, re-dumped on save
        self._config_enabled_map: Dict[str, bool] = {}      # config_id -> group 'enabled'
        # Track endpoint enable status: {(config_id, idx): enabled}
        self._endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
        # Track endpoint URLs: {(config_id, idx): url}
        self._endpoint_urls: Dict[Tuple[str, int], str] = {}
        self.load_keys()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _stable_hash(text: Any, mask: int) -> int:
        """Return a process-independent hash of ``str(text)``, truncated by ``mask``.

        The builtin ``hash()`` of a ``str`` is salted per interpreter process
        (PYTHONHASHSEED), so IDs derived from it change on every restart and
        differ between worker processes. A SHA-256 prefix keeps IDs stable.
        """
        digest = hashlib.sha256(str(text).encode('utf-8')).digest()
        return int.from_bytes(digest[:4], 'big') & mask

    @classmethod
    def _endpoint_key_id(cls, config_id: str, idx: int, url: Any) -> str:
        """Build the virtual key ID for endpoint ``idx`` of ``config_id``."""
        return f"{config_id}:endpoint:{idx}:{cls._stable_hash(url, 0xFFFF)}"

    @staticmethod
    def _parse_endpoint(endpoint_item: Any) -> Tuple[Any, Any]:
        """Normalise an ``api_endpoints`` entry to ``(url, enabled)``.

        Supports both the old format (a plain URL string) and the new format
        (a dict with ``url`` and optional ``enabled``, default True).
        """
        if isinstance(endpoint_item, dict):
            return endpoint_item.get('url', ''), endpoint_item.get('enabled', True)
        return endpoint_item, True

    @staticmethod
    def _apply_enabled(key: "APIKey", enabled: bool) -> None:
        """Set ``key.enabled`` and the matching ACTIVE/DISABLED status."""
        key.enabled = enabled
        key.status = KeyStatus.ACTIVE if enabled else KeyStatus.DISABLED

    # ------------------------------------------------------------ load / save

    def load_keys(self):
        """Load (or reload) keys from the YAML configuration file.

        All new state is built in local containers and only swapped into the
        instance at the end. A failure part-way through (unreadable file,
        invalid YAML, a key entry without a ``key`` field, ...) propagates to
        the caller and leaves the previously loaded keys untouched. Individual
        keys that ``APIKey`` rejects are reported with ``print`` and skipped.
        """
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # An empty file parses to None; treat it as an empty config.
            data = yaml.safe_load(f) or {}

        configs = {
            cfg['id']: cfg
            for cfg in (data.get('model_configs') or [])
            if isinstance(cfg, dict) and 'id' in cfg
        }

        keys: Dict[str, "APIKey"] = {}
        keys_by_model: Dict[str, List["APIKey"]] = {}
        config_ids: Set[str] = set()
        config_enabled_map: Dict[str, bool] = {}
        endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
        endpoint_urls: Dict[Tuple[str, int], str] = {}

        def add_key(final_data: Dict[str, Any], label: str) -> None:
            # Build one APIKey and index it; invalid entries are reported and skipped.
            try:
                key = APIKey(**final_data)
                keys[key.id] = key
                keys_by_model.setdefault(key.model_name, []).append(key)
            except Exception as e:
                print(f"Error loading {label}: {e}")

        for key_group in data.get('keys') or []:
            config_id = key_group.get('config_id')
            group_enabled = key_group.get('enabled', True)
            keys_list = key_group.get('keys') or []

            if config_id:
                config_ids.add(config_id)
                config_enabled_map[config_id] = group_enabled

            if config_id and config_id in configs:
                cfg_copy = configs[config_id].copy()
                cfg_copy.pop('id', None)
                api_endpoints = cfg_copy.get('api_endpoints') or []

                if api_endpoints:
                    # One virtual key per endpoint (local llama.cpp load balancing);
                    # the group's own 'keys' list is not used in this mode.
                    base = {k: v for k, v in cfg_copy.items() if k != 'api_endpoints'}
                    for idx, endpoint_item in enumerate(api_endpoints):
                        endpoint_url, endpoint_enabled = self._parse_endpoint(endpoint_item)
                        # Store endpoint info for the dashboard.
                        endpoint_enabled_map[(config_id, idx)] = endpoint_enabled
                        endpoint_urls[(config_id, idx)] = endpoint_url

                        key_id = self._endpoint_key_id(config_id, idx, endpoint_url)
                        # Endpoint is enabled only if both group and endpoint itself are enabled.
                        effective = group_enabled and endpoint_enabled
                        final_data = dict(base)
                        final_data.update({
                            'config_id': config_id,
                            'enabled': effective,
                            'api_base': endpoint_url,              # individual endpoint
                            'key': f"local-{config_id}-{idx}",     # virtual key for local service
                            'endpoint_idx': idx,
                            'id': key_id,
                        })
                        if not effective:
                            final_data['status'] = KeyStatus.DISABLED
                        add_key(final_data, f"virtual key {key_id} for endpoint {endpoint_url}")
                else:
                    # Standard key loading: config fields + per-key secret/owner.
                    for key_data in keys_list:
                        final_data = dict(cfg_copy)
                        final_data['config_id'] = config_id
                        final_data['enabled'] = group_enabled
                        if key_data.get('owner'):
                            final_data['owner'] = key_data['owner']
                        final_data['key'] = key_data['key']
                        if not group_enabled:
                            final_data['status'] = KeyStatus.DISABLED
                        key_id = f"{config_id}:{self._stable_hash(key_data['key'], 0xFFFFFFFF)}"
                        final_data['id'] = key_id
                        add_key(final_data, f"key {key_id}")
            else:
                # Legacy layout: group fields (minus 'keys') + entry fields go to APIKey as-is.
                group_fields = {k: v for k, v in key_group.items() if k != 'keys'}
                for key_data in keys_list:
                    final_data = dict(group_fields)
                    final_data.update(key_data)
                    if 'id' not in final_data:
                        # NOTE(review): this exposes the first 16 characters of the secret
                        # as its ID; a hash would be safer but would change existing IDs.
                        final_data['id'] = key_data.get('key', 'unknown')[:16]
                    if not final_data.get('enabled', True):
                        final_data['status'] = KeyStatus.DISABLED
                    add_key(final_data, f"key {final_data.get('id', 'unknown')}")

        # Everything parsed: publish the new state in one go (this also drops
        # endpoints that no longer exist in the file).
        self._raw_config = data
        self._keys = keys
        self._keys_by_model = keys_by_model
        self._config_ids = config_ids
        self._config_enabled_map = config_enabled_map
        self._endpoint_enabled_map = endpoint_enabled_map
        self._endpoint_urls = endpoint_urls
        self._last_load_time = time.time()

    def _save_config(self) -> bool:
        """Write the current enable flags back into the YAML config file.

        Group ``enabled`` flags and per-endpoint ``enabled`` flags are copied
        into the parsed document from the last load (plain-string endpoints
        are converted to ``{url, enabled}`` dicts), and the whole document is
        re-dumped to ``config_path``. Only data present in the parsed document
        is written; YAML comments are presumably lost.

        Returns:
            True on success; False if anything failed (the error is printed).
        """
        try:
            for key_group in self._raw_config.get('keys') or []:
                config_id = key_group.get('config_id')
                if config_id and config_id in self._config_enabled_map:
                    key_group['enabled'] = self._config_enabled_map[config_id]

            # Save endpoint enable status.
            for cfg in self._raw_config.get('model_configs') or []:
                if not isinstance(cfg, dict):
                    continue
                cfg_id = cfg.get('id')
                endpoints = cfg.get('api_endpoints') or []
                for idx, endpoint in enumerate(endpoints):
                    if (cfg_id, idx) not in self._endpoint_enabled_map:
                        continue
                    enabled = self._endpoint_enabled_map[(cfg_id, idx)]
                    if isinstance(endpoint, dict):
                        endpoint['enabled'] = enabled
                    else:
                        # Convert string to dict with url and enabled.
                        endpoints[idx] = {'url': endpoint, 'enabled': enabled}

            with open(self.config_path, 'w', encoding='utf-8') as f:
                yaml.dump(self._raw_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
            return True
        except Exception as e:
            print(f"Error saving config: {e}")
            return False

    def reload_keys(self) -> bool:
        """Reload keys from config file. Returns True if successful.

        On failure the error is printed, False is returned, and the previously
        loaded keys stay in effect.
        """
        try:
            self.load_keys()
            return True
        except Exception as e:
            print(f"Error reloading keys: {e}")
            return False

    # ------------------------------------------------------- enable / disable

    def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
        """Enable or disable all keys of ``config_id`` and save the flag to file.

        For endpoint-backed configs, an individually disabled endpoint stays
        disabled when the group is enabled, mirroring ``load_keys``.

        Returns:
            False if ``config_id`` is unknown, or if no loaded key belongs to it.
            In the latter case the in-memory group flag is still updated but
            the file is not written. True otherwise.
        """
        if config_id not in self._config_ids:
            return False
        self._config_enabled_map[config_id] = enabled

        # Per-endpoint flags of this config, keyed by their virtual key id.
        endpoint_flags = {
            self._endpoint_key_id(cfg_id, idx, url): self._endpoint_enabled_map.get((cfg_id, idx), True)
            for (cfg_id, idx), url in self._endpoint_urls.items()
            if cfg_id == config_id
        }

        updated = False
        for key in self._keys.values():
            if hasattr(key, 'config_id') and key.config_id == config_id:
                self._apply_enabled(key, enabled and endpoint_flags.get(key.id, True))
                updated = True
        if updated:
            self._save_config()
        return updated

    def set_key_enabled(self, key_id: str, enabled: bool) -> bool:
        """Enable or disable a single key in memory (nothing is saved to file).

        Returns:
            False if ``key_id`` is unknown, True otherwise.
        """
        if key_id not in self._keys:
            return False
        self._apply_enabled(self._keys[key_id], enabled)
        return True

    def set_endpoint_enabled(self, config_id: str, endpoint_idx: int, enabled: bool) -> bool:
        """Enable or disable a specific endpoint and save to file.

        The endpoint's virtual key becomes active only if its group is enabled
        as well, mirroring ``load_keys``.

        Returns:
            False if the endpoint is unknown, True otherwise.
        """
        if (config_id, endpoint_idx) not in self._endpoint_enabled_map:
            return False
        self._endpoint_enabled_map[(config_id, endpoint_idx)] = enabled

        # Update the corresponding virtual key's status.
        url = self._endpoint_urls.get((config_id, endpoint_idx), '')
        key = self._keys.get(self._endpoint_key_id(config_id, endpoint_idx, url))
        if key is not None:
            self._apply_enabled(key, enabled and self._config_enabled_map.get(config_id, True))
        self._save_config()
        return True

    # ---------------------------------------------------------------- queries

    def get_config_ids(self) -> Set[str]:
        """Get all config IDs (a copy)."""
        return self._config_ids.copy()

    def get_keys_by_config(self, config_id: str) -> List["APIKey"]:
        """Get all keys for a specific config_id."""
        return [k for k in self._keys.values() if hasattr(k, 'config_id') and k.config_id == config_id]

    def get_key(self, key_id: str) -> Optional["APIKey"]:
        """Return the key with ``key_id``, or None."""
        return self._keys.get(key_id)

    def get_candidate_keys(self, model_name: Optional[str] = None) -> List["APIKey"]:
        """Return all keys for ``model_name`` regardless of status, as a new list.

        If ``model_name`` is falsy, all keys are returned. A copy is returned so
        callers cannot mutate the internal per-model index.
        """
        if not model_name:
            return self.get_all_keys()
        return list(self._keys_by_model.get(model_name, []))

    def get_all_keys(self) -> List["APIKey"]:
        """Return all loaded keys as a new list."""
        return list(self._keys.values())

    def get_last_load_time(self) -> float:
        """Return the ``time.time()`` timestamp of the last successful load."""
        return self._last_load_time

    def get_config_enabled(self, config_id: str) -> bool:
        """Get enabled status for a config_id (True if unknown)."""
        return self._config_enabled_map.get(config_id, True)

    def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
        """Return ``{idx, url, enabled, key_id}`` dicts for each endpoint of a config, sorted by idx."""
        endpoints = [
            {
                'idx': idx,
                'url': url,
                'enabled': self._endpoint_enabled_map.get((cfg_id, idx), True),
                'key_id': self._endpoint_key_id(cfg_id, idx, url),
            }
            for (cfg_id, idx), url in self._endpoint_urls.items()
            if cfg_id == config_id
        ]
        return sorted(endpoints, key=lambda item: item['idx'])

    def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
        """Get the virtual key_id for a known endpoint, or None if it is unknown.

        Endpoints with an empty URL still resolve, matching the ID that
        ``load_keys`` assigned to their virtual key.
        """
        if (config_id, endpoint_idx) not in self._endpoint_urls:
            return None
        url = self._endpoint_urls[(config_id, endpoint_idx)]
        return self._endpoint_key_id(config_id, endpoint_idx, url)