295 lines
13 KiB
Python
Executable File
295 lines
13 KiB
Python
Executable File
import time
import zlib
from typing import List, Dict, Optional, Set, Any, Tuple

import yaml

from app.models import APIKey, KeyStatus
|
|
|
|
class KeyManager:
    """Load API keys from a YAML config file and track their enabled state.

    The config file is read as a mapping with two optional top-level lists:

    * ``model_configs`` -- entries carrying an ``id`` plus fields that are
      copied into every key of that config. An entry may list
      ``api_endpoints`` (plain URL strings, or dicts with ``url`` and
      ``enabled``); each endpoint then becomes one "virtual" key.
    * ``keys`` -- key groups with ``config_id``, ``enabled`` and a ``keys``
      list whose items provide ``key`` and optionally ``owner``.

    Key IDs are derived from a CRC-32 digest rather than the builtin
    ``hash()``, because ``hash()`` of a string is salted per process and
    would give every key a different ID after each restart.
    """

    def __init__(self, config_path: str):
        """Remember ``config_path`` and immediately load the keys from it.

        Raises whatever ``load_keys`` raises (e.g. ``OSError`` when the file
        cannot be opened, or a YAML parsing error).
        """
        self.config_path = config_path
        self._keys: Dict[str, "APIKey"] = {}
        self._keys_by_model: Dict[str, List["APIKey"]] = {}
        self._config_ids: Set[str] = set()
        self._last_load_time: float = 0
        self._raw_config: Dict[str, Any] = {}
        self._config_enabled_map: Dict[str, bool] = {}
        # Track endpoint enable status: {(config_id, idx): enabled}
        self._endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
        # Track endpoint URLs: {(config_id, idx): url}
        self._endpoint_urls: Dict[Tuple[str, int], str] = {}
        self.load_keys()

    @staticmethod
    def _stable_hash(value: Any, mask: int) -> int:
        """Return a process-independent digest of ``value`` limited by ``mask``."""
        return zlib.crc32(str(value).encode('utf-8')) & mask

    @classmethod
    def _endpoint_key_id(cls, config_id: str, endpoint_idx: int, url: Any) -> str:
        """Build the ID of the virtual key that represents one endpoint."""
        return f"{config_id}:endpoint:{endpoint_idx}:{cls._stable_hash(url, 0xFFFF)}"

    @staticmethod
    def _register_key(fields: Dict[str, Any], label: str,
                      keys: Dict[str, "APIKey"],
                      keys_by_model: Dict[str, List["APIKey"]]) -> None:
        """Build an APIKey from ``fields`` and index it by ID and model name.

        A key that cannot be built is reported on stdout and skipped so that
        one bad entry does not prevent the remaining keys from loading.
        """
        try:
            key = APIKey(**fields)
            keys[key.id] = key
            keys_by_model.setdefault(key.model_name, []).append(key)
        except Exception as e:
            print(f"Error loading {label}: {e}")

    @staticmethod
    def _apply_enabled(key: "APIKey", enabled: bool) -> None:
        """Set ``enabled`` on a key and move its status to ACTIVE/DISABLED."""
        key.enabled = enabled
        key.status = KeyStatus.ACTIVE if enabled else KeyStatus.DISABLED

    def load_keys(self):
        """Loads keys from the YAML configuration file.

        The new state is assembled in local containers and only committed
        once the whole file has been processed, so an exception (unreadable
        file, invalid YAML, a key entry without ``key``) leaves the
        previously loaded state untouched.
        """
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            data = yaml.safe_load(f) or {}

        keys: Dict[str, "APIKey"] = {}
        keys_by_model: Dict[str, List["APIKey"]] = {}
        config_ids: Set[str] = set()
        config_enabled: Dict[str, bool] = {}
        # Rebuilt on every load so endpoints removed from the file disappear.
        endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
        endpoint_urls: Dict[Tuple[str, int], str] = {}

        configs = {
            cfg['id']: cfg
            for cfg in (data.get('model_configs') or [])
            if 'id' in cfg
        }

        for key_group in (data.get('keys') or []):
            config_id = key_group.get('config_id')
            group_enabled = key_group.get('enabled', True)
            keys_list = key_group.get('keys') or []

            if config_id:
                config_ids.add(config_id)
                config_enabled[config_id] = group_enabled

            if config_id and config_id in configs:
                cfg_copy = dict(configs[config_id])
                cfg_copy.pop('id', None)

                # Check if this config has multiple endpoints (for local llama.cpp load balancing)
                api_endpoints = cfg_copy.get('api_endpoints') or []

                if api_endpoints:
                    # Generate virtual keys for each endpoint; the group's
                    # own ``keys`` list is not used in this mode.
                    for idx, endpoint_item in enumerate(api_endpoints):
                        # Handle both old format (string URL) and new format (dict with url/enabled)
                        if isinstance(endpoint_item, dict):
                            endpoint_url = endpoint_item.get('url', '')
                            endpoint_enabled = endpoint_item.get('enabled', True)
                        else:
                            endpoint_url = endpoint_item
                            endpoint_enabled = True

                        # Store endpoint info for dashboard
                        endpoint_enabled_map[(config_id, idx)] = endpoint_enabled
                        endpoint_urls[(config_id, idx)] = endpoint_url

                        final_data = dict(cfg_copy)
                        final_data.pop('api_endpoints', None)  # Remove the list from individual key
                        final_data['config_id'] = config_id
                        # Endpoint is enabled only if both group and endpoint itself are enabled
                        final_data['enabled'] = group_enabled and endpoint_enabled
                        final_data['api_base'] = endpoint_url  # Use individual endpoint
                        final_data['key'] = f"local-{config_id}-{idx}"  # Virtual key for local service
                        final_data['endpoint_idx'] = idx  # Track endpoint index

                        # Virtual key ID that includes endpoint info for tracking
                        key_id = self._endpoint_key_id(config_id, idx, endpoint_url)
                        final_data['id'] = key_id

                        if not group_enabled or not endpoint_enabled:
                            final_data['status'] = KeyStatus.DISABLED

                        self._register_key(
                            final_data,
                            f"virtual key {key_id} for endpoint {endpoint_url}",
                            keys,
                            keys_by_model,
                        )
                else:
                    # Standard key loading: config fields + one entry per key
                    for key_data in keys_list:
                        final_data = dict(cfg_copy)
                        final_data['config_id'] = config_id
                        final_data['enabled'] = group_enabled

                        if key_data.get('owner'):
                            final_data['owner'] = key_data['owner']

                        final_data['key'] = key_data['key']

                        if not group_enabled:
                            final_data['status'] = KeyStatus.DISABLED

                        key_id = f"{config_id}:{self._stable_hash(key_data['key'], 0xFFFFFFFF)}"
                        final_data['id'] = key_id

                        self._register_key(
                            final_data, f"key {key_id}", keys, keys_by_model
                        )
            else:
                # No matching model config: the group's own fields are the
                # defaults and each key entry overrides them.
                for key_data in keys_list:
                    final_data = dict(key_group)
                    final_data.pop('keys', None)
                    final_data.update(key_data)

                    if 'id' not in final_data:
                        # NOTE(review): this falls back to the first 16
                        # characters of the key itself as the ID; confirm that
                        # exposing the prefix is acceptable.
                        final_data['id'] = key_data.get('key', 'unknown')[:16]

                    if not final_data.get('enabled', True):
                        final_data['status'] = KeyStatus.DISABLED

                    self._register_key(
                        final_data,
                        f"key {final_data.get('id', 'unknown')}",
                        keys,
                        keys_by_model,
                    )

        # Commit only after everything was processed successfully.
        self._raw_config = data
        self._keys = keys
        self._keys_by_model = keys_by_model
        self._config_ids = config_ids
        self._config_enabled_map = config_enabled
        self._endpoint_enabled_map = endpoint_enabled_map
        self._endpoint_urls = endpoint_urls
        self._last_load_time = time.time()

    def _save_config(self) -> bool:
        """Save current configuration to YAML file.

        Copies the in-memory group and endpoint enable flags into the raw
        config that was read by ``load_keys`` and writes it back to
        ``config_path``. Returns True on success; on any error the problem is
        printed and False is returned.
        """
        try:
            for key_group in (self._raw_config.get('keys') or []):
                config_id = key_group.get('config_id')
                if config_id and config_id in self._config_enabled_map:
                    key_group['enabled'] = self._config_enabled_map[config_id]

            # Save endpoint enable status
            for cfg in (self._raw_config.get('model_configs') or []):
                cfg_id = cfg.get('id')
                endpoints = cfg.get('api_endpoints') or []
                for idx, endpoint in enumerate(endpoints):
                    if (cfg_id, idx) not in self._endpoint_enabled_map:
                        continue
                    enabled = self._endpoint_enabled_map[(cfg_id, idx)]
                    if isinstance(endpoint, dict):
                        endpoint['enabled'] = enabled
                    else:
                        # Convert string to dict with url and enabled
                        endpoints[idx] = {'url': endpoint, 'enabled': enabled}

            # Serialise before opening the file: opening with 'w' truncates
            # it, so a dump error must not happen after that point.
            text = yaml.dump(
                self._raw_config,
                default_flow_style=False,
                allow_unicode=True,
                sort_keys=False,
            )
            with open(self.config_path, 'w', encoding='utf-8') as f:
                f.write(text)

            return True
        except Exception as e:
            print(f"Error saving config: {e}")
            return False

    def reload_keys(self) -> bool:
        """Reload keys from config file. Returns True if successful."""
        try:
            self.load_keys()
            return True
        except Exception as e:
            print(f"Error reloading keys: {e}")
            return False

    def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
        """Enable or disable all keys for a config_id and save to file.

        Returns False when ``config_id`` is unknown or no loaded key belongs
        to it; the file is only written when at least one key was updated.
        An endpoint key that is individually disabled stays disabled even
        when its group is enabled, matching what ``load_keys`` produces.
        """
        if config_id not in self._config_ids:
            return False

        self._config_enabled_map[config_id] = enabled

        # Per-endpoint switches of this config, indexed by virtual key ID.
        endpoint_state = {
            self._endpoint_key_id(cfg_id, idx, url):
                self._endpoint_enabled_map.get((cfg_id, idx), True)
            for (cfg_id, idx), url in self._endpoint_urls.items()
            if cfg_id == config_id
        }

        updated = False
        for key_id, key in self._keys.items():
            if getattr(key, 'config_id', None) != config_id:
                continue
            effective = bool(enabled and endpoint_state.get(key_id, True))
            self._apply_enabled(key, effective)
            updated = True

        if updated:
            self._save_config()

        return updated

    def set_key_enabled(self, key_id: str, enabled: bool) -> bool:
        """Enable or disable a specific key.

        Only the in-memory key is changed; nothing is written to the config
        file. Returns False when ``key_id`` is unknown.
        """
        key = self._keys.get(key_id)
        if key is None:
            return False

        self._apply_enabled(key, enabled)
        return True

    def get_config_ids(self) -> Set[str]:
        """Get all config IDs."""
        return self._config_ids.copy()

    def get_keys_by_config(self, config_id: str) -> List["APIKey"]:
        """Get all keys for a specific config_id."""
        return [
            k for k in self._keys.values()
            if hasattr(k, 'config_id') and k.config_id == config_id
        ]

    def get_key(self, key_id: str) -> Optional["APIKey"]:
        """Return the key with the given ID, or None if it is not loaded."""
        return self._keys.get(key_id)

    def get_candidate_keys(self, model_name: Optional[str] = None) -> List["APIKey"]:
        """Returns all keys for a given model, regardless of status. If model_name is None, returns all keys."""
        if not model_name:
            return self.get_all_keys()
        # Return a copy so callers cannot mutate the internal index.
        return list(self._keys_by_model.get(model_name, []))

    def get_all_keys(self) -> List["APIKey"]:
        """Return a new list containing every loaded key."""
        return list(self._keys.values())

    def get_last_load_time(self) -> float:
        """Return the ``time.time()`` value recorded by the last successful load."""
        return self._last_load_time

    def get_config_enabled(self, config_id: str) -> bool:
        """Get enabled status for a config_id."""
        return self._config_enabled_map.get(config_id, True)

    def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
        """Get all endpoint info for a config_id.

        Each entry is a dict with ``idx``, ``url``, ``enabled`` and
        ``key_id``; the list is sorted by ``idx``.
        """
        endpoints = [
            {
                'idx': idx,
                'url': url,
                'enabled': self._endpoint_enabled_map.get((cfg_id, idx), True),
                'key_id': self._endpoint_key_id(config_id, idx, url),
            }
            for (cfg_id, idx), url in self._endpoint_urls.items()
            if cfg_id == config_id
        ]
        return sorted(endpoints, key=lambda x: x['idx'])

    def set_endpoint_enabled(self, config_id: str, endpoint_idx: int, enabled: bool) -> bool:
        """Enable or disable a specific endpoint and save to file.

        Returns False when the endpoint is unknown. The endpoint's virtual
        key only becomes active when its group is enabled as well, matching
        what ``load_keys`` produces.
        """
        if (config_id, endpoint_idx) not in self._endpoint_enabled_map:
            return False

        self._endpoint_enabled_map[(config_id, endpoint_idx)] = enabled

        # Update the corresponding key's status
        url = self._endpoint_urls.get((config_id, endpoint_idx), '')
        key = self._keys.get(self._endpoint_key_id(config_id, endpoint_idx, url))
        if key is not None:
            group_enabled = self._config_enabled_map.get(config_id, True)
            self._apply_enabled(key, bool(enabled and group_enabled))

        self._save_config()
        return True

    def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
        """Get the key_id for a specific endpoint.

        Returns None when no such endpoint has been loaded.
        """
        if (config_id, endpoint_idx) not in self._endpoint_urls:
            return None
        url = self._endpoint_urls[(config_id, endpoint_idx)]
        return self._endpoint_key_id(config_id, endpoint_idx, url)
|