import hashlib
import time
from typing import List, Dict, Optional, Set, Any, Tuple

import yaml

from app.models import APIKey, KeyStatus


class KeyManager:
    """Load API keys from a YAML config file and track their enabled state.

    The config file is expected to hold two top-level lists:

    * ``model_configs`` - shared settings, each identified by ``id``; a config
      may carry an ``api_endpoints`` list, in which case one "virtual" key is
      generated per endpoint.
    * ``keys`` - key groups, each optionally pointing at a model config via
      ``config_id`` and holding its own ``keys`` list.

    Enable/disable changes made through ``set_config_enabled`` and
    ``set_endpoint_enabled`` are written back to the same file.
    """

    def __init__(self, config_path: str):
        """Remember ``config_path`` and load the keys immediately.

        Raises whatever ``load_keys`` raises (e.g. ``OSError`` if the file
        cannot be opened).
        """
        self.config_path = config_path
        self._keys: Dict[str, APIKey] = {}
        self._keys_by_model: Dict[str, List[APIKey]] = {}
        self._config_ids: Set[str] = set()
        self._last_load_time: float = 0
        self._raw_config: Dict[str, Any] = {}
        self._config_enabled_map: Dict[str, bool] = {}
        # Track endpoint enable status: {(config_id, idx): enabled}
        self._endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
        # Track endpoint URLs: {(config_id, idx): url}
        self._endpoint_urls: Dict[Tuple[str, int], str] = {}
        self.load_keys()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _stable_hash(value: Any, mask: int) -> int:
        """Return a process-independent hash of ``value`` limited by ``mask``.

        The builtin ``hash()`` is salted per interpreter process for strings,
        so IDs derived from it change on every restart. A SHA-256 digest is
        used instead so the same input always yields the same ID.
        """
        digest = hashlib.sha256(str(value).encode('utf-8')).digest()
        return int.from_bytes(digest[:4], 'big') & mask

    @classmethod
    def _endpoint_key_id(cls, config_id: str, idx: int, url: Any) -> str:
        """Build the virtual key ID for endpoint ``idx`` of ``config_id``."""
        return f"{config_id}:endpoint:{idx}:{cls._stable_hash(url, 0xFFFF)}"

    @staticmethod
    def _register_key(
        keys: Dict[str, APIKey],
        keys_by_model: Dict[str, List[APIKey]],
        final_data: Dict[str, Any],
        description: str,
    ) -> None:
        """Construct an ``APIKey`` from ``final_data`` and index it.

        Construction errors are reported with ``print`` and the entry is
        skipped so that one bad entry does not abort the whole load.
        """
        try:
            key = APIKey(**final_data)
            keys[key.id] = key
            keys_by_model.setdefault(key.model_name, []).append(key)
        except Exception as e:
            print(f"Error loading {description}: {e}")

    # ------------------------------------------------------------------
    # Loading / saving
    # ------------------------------------------------------------------

    def load_keys(self):
        """Loads keys from the YAML configuration file.

        All state is built in local containers and committed only after the
        whole file has been processed, so a failure part-way through leaves
        the previously loaded keys untouched.

        Raises:
            OSError: if the config file cannot be opened.
            yaml.YAMLError: if the file is not valid YAML.
            KeyError: if a key entry under a known ``config_id`` (without
                ``api_endpoints``) lacks the ``key`` field.
        """
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            data = yaml.safe_load(f) or {}

        keys: Dict[str, APIKey] = {}
        keys_by_model: Dict[str, List[APIKey]] = {}
        config_ids: Set[str] = set()
        config_enabled: Dict[str, bool] = {}
        endpoint_enabled_map: Dict[Tuple[str, int], bool] = {}
        endpoint_urls: Dict[Tuple[str, int], str] = {}

        # ``or []`` guards against keys that are present but empty (None).
        configs = {
            cfg['id']: cfg
            for cfg in data.get('model_configs') or []
            if 'id' in cfg
        }

        for key_group in data.get('keys') or []:
            config_id = key_group.get('config_id')
            group_enabled = key_group.get('enabled', True)
            keys_list = key_group.get('keys') or []

            if config_id:
                config_ids.add(config_id)
                config_enabled[config_id] = group_enabled

            if not (config_id and config_id in configs):
                # No matching model config: the group itself supplies the
                # shared fields, and each key entry overrides them.
                for key_data in keys_list:
                    final_data = dict(key_group)
                    final_data.pop('keys', None)
                    final_data.update(key_data)

                    if 'id' not in final_data:
                        final_data['id'] = key_data.get('key', 'unknown')[:16]

                    if not final_data.get('enabled', True):
                        final_data['status'] = KeyStatus.DISABLED

                    self._register_key(
                        keys, keys_by_model, final_data,
                        f"key {final_data.get('id', 'unknown')}",
                    )
                continue

            cfg_copy = dict(configs[config_id])
            cfg_copy.pop('id', None)

            # Check if this config has multiple endpoints
            # (for local llama.cpp load balancing)
            api_endpoints = cfg_copy.get('api_endpoints') or []

            if api_endpoints:
                # Generate one virtual key per endpoint.
                for idx, endpoint_item in enumerate(api_endpoints):
                    # Handle both old format (string URL) and new format
                    # (dict with url/enabled).
                    if isinstance(endpoint_item, dict):
                        endpoint_url = endpoint_item.get('url', '')
                        endpoint_enabled = endpoint_item.get('enabled', True)
                    else:
                        endpoint_url = endpoint_item
                        endpoint_enabled = True

                    # Store endpoint info for dashboard
                    endpoint_enabled_map[(config_id, idx)] = endpoint_enabled
                    endpoint_urls[(config_id, idx)] = endpoint_url

                    final_data = dict(cfg_copy)
                    # The endpoint list does not belong on an individual key.
                    final_data.pop('api_endpoints', None)
                    final_data['config_id'] = config_id
                    # Enabled only if both the group and the endpoint are.
                    final_data['enabled'] = group_enabled and endpoint_enabled
                    final_data['api_base'] = endpoint_url
                    # Virtual key for local service
                    final_data['key'] = f"local-{config_id}-{idx}"
                    final_data['endpoint_idx'] = idx

                    key_id = self._endpoint_key_id(config_id, idx, endpoint_url)
                    final_data['id'] = key_id

                    if not group_enabled or not endpoint_enabled:
                        final_data['status'] = KeyStatus.DISABLED

                    self._register_key(
                        keys, keys_by_model, final_data,
                        f"virtual key {key_id} for endpoint {endpoint_url}",
                    )
            else:
                # Standard key loading: config fields + per-key secret.
                for key_data in keys_list:
                    final_data = dict(cfg_copy)
                    final_data['config_id'] = config_id
                    final_data['enabled'] = group_enabled

                    if key_data.get('owner'):
                        final_data['owner'] = key_data['owner']

                    final_data['key'] = key_data['key']

                    if not group_enabled:
                        final_data['status'] = KeyStatus.DISABLED

                    key_hash = self._stable_hash(key_data['key'], 0xFFFFFFFF)
                    key_id = f"{config_id}:{key_hash}"
                    final_data['id'] = key_id

                    self._register_key(
                        keys, keys_by_model, final_data, f"key {key_id}",
                    )

        # Commit everything at once; endpoint maps are replaced too so that
        # endpoints removed from the file do not linger after a reload.
        self._raw_config = data
        self._keys = keys
        self._keys_by_model = keys_by_model
        self._config_ids = config_ids
        self._config_enabled_map = config_enabled
        self._endpoint_enabled_map = endpoint_enabled_map
        self._endpoint_urls = endpoint_urls
        self._last_load_time = time.time()

    def _save_config(self) -> bool:
        """Save current configuration to YAML file.

        Copies the in-memory group and endpoint enable flags into the raw
        config and writes it back to ``config_path``.

        Returns:
            True on success, False if anything went wrong (the error is
            printed, not raised).
        """
        try:
            for key_group in self._raw_config.get('keys') or []:
                config_id = key_group.get('config_id')
                if config_id and config_id in self._config_enabled_map:
                    key_group['enabled'] = self._config_enabled_map[config_id]

            # Save endpoint enable status
            for cfg in self._raw_config.get('model_configs') or []:
                cfg_id = cfg.get('id')
                endpoints = cfg.get('api_endpoints') or []
                for idx, item in enumerate(endpoints):
                    if (cfg_id, idx) not in self._endpoint_enabled_map:
                        continue
                    enabled = self._endpoint_enabled_map[(cfg_id, idx)]
                    if isinstance(item, dict):
                        item['enabled'] = enabled
                    else:
                        # Convert string to dict with url and enabled
                        endpoints[idx] = {'url': item, 'enabled': enabled}

            # Serialize before opening the file: opening with 'w' truncates
            # it, so a serialization error must not happen after that point.
            serialized = yaml.dump(
                self._raw_config,
                default_flow_style=False,
                allow_unicode=True,
                sort_keys=False,
            )
            with open(self.config_path, 'w', encoding='utf-8') as f:
                f.write(serialized)
            return True
        except Exception as e:
            print(f"Error saving config: {e}")
            return False

    def reload_keys(self) -> bool:
        """Reload keys from config file. Returns True if successful."""
        try:
            self.load_keys()
            return True
        except Exception as e:
            print(f"Error reloading keys: {e}")
            return False

    # ------------------------------------------------------------------
    # Enable / disable
    # ------------------------------------------------------------------

    def set_config_enabled(self, config_id: str, enabled: bool) -> bool:
        """Enable or disable all keys for a config_id and save to file.

        An endpoint that was individually disabled stays disabled when its
        group is enabled, matching the rule applied by ``load_keys``.

        Returns:
            False if ``config_id`` is unknown or has no keys, True otherwise.
        """
        if config_id not in self._config_ids:
            return False

        self._config_enabled_map[config_id] = enabled

        # IDs of this config's virtual keys whose endpoint is switched off.
        disabled_endpoint_ids = {
            self._endpoint_key_id(
                cfg_id, idx, self._endpoint_urls.get((cfg_id, idx), '')
            )
            for (cfg_id, idx), ep_enabled in self._endpoint_enabled_map.items()
            if cfg_id == config_id and not ep_enabled
        }

        updated = False
        for key_id, key in self._keys.items():
            if getattr(key, 'config_id', None) != config_id:
                continue
            effective = bool(enabled) and key_id not in disabled_endpoint_ids
            key.enabled = effective
            key.status = KeyStatus.ACTIVE if effective else KeyStatus.DISABLED
            updated = True

        if updated:
            self._save_config()

        return updated

    def set_key_enabled(self, key_id: str, enabled: bool) -> bool:
        """Enable or disable a specific key.

        The change is in-memory only; it is not written to the config file.
        Returns False if ``key_id`` is unknown.
        """
        key = self._keys.get(key_id)
        if key is None:
            return False

        key.enabled = enabled
        key.status = KeyStatus.ACTIVE if enabled else KeyStatus.DISABLED
        return True

    def set_endpoint_enabled(self, config_id: str, endpoint_idx: int, enabled: bool) -> bool:
        """Enable or disable a specific endpoint and save to file.

        The endpoint's virtual key becomes active only if its group is
        enabled as well, matching the rule applied by ``load_keys``.

        Returns:
            False if the endpoint is unknown, True otherwise.
        """
        slot = (config_id, endpoint_idx)
        if slot not in self._endpoint_enabled_map:
            return False

        self._endpoint_enabled_map[slot] = enabled

        # Update the corresponding key's status
        key_id = self._endpoint_key_id(
            config_id, endpoint_idx, self._endpoint_urls.get(slot, '')
        )
        key = self._keys.get(key_id)
        if key is not None:
            group_enabled = self._config_enabled_map.get(config_id, True)
            effective = bool(enabled and group_enabled)
            key.enabled = effective
            key.status = KeyStatus.ACTIVE if effective else KeyStatus.DISABLED

        self._save_config()
        return True

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_config_ids(self) -> Set[str]:
        """Get all config IDs (as a copy of the internal set)."""
        return self._config_ids.copy()

    def get_keys_by_config(self, config_id: str) -> List[APIKey]:
        """Get all keys for a specific config_id."""
        return [
            k for k in self._keys.values()
            if hasattr(k, 'config_id') and k.config_id == config_id
        ]

    def get_key(self, key_id: str) -> Optional[APIKey]:
        """Return the key with ``key_id``, or None if there is none."""
        return self._keys.get(key_id)

    def get_candidate_keys(self, model_name: Optional[str] = None) -> List[APIKey]:
        """Returns all keys for a given model, regardless of status.

        If model_name is None (or empty), returns all keys. The returned
        list is a copy, so callers may reorder or filter it freely.
        """
        if not model_name:
            return self.get_all_keys()
        return list(self._keys_by_model.get(model_name, []))

    def get_all_keys(self) -> List[APIKey]:
        """Return every loaded key as a new list."""
        return list(self._keys.values())

    def get_last_load_time(self) -> float:
        """Return the ``time.time()`` value of the last successful load."""
        return self._last_load_time

    def get_config_enabled(self, config_id: str) -> bool:
        """Get enabled status for a config_id (True if unknown)."""
        return self._config_enabled_map.get(config_id, True)

    def get_endpoint_info(self, config_id: str) -> List[Dict[str, Any]]:
        """Get all endpoint info for a config_id.

        Returns a list of dicts with ``idx``, ``url``, ``enabled`` and
        ``key_id``, sorted by ``idx``.
        """
        endpoints = [
            {
                'idx': idx,
                'url': url,
                'enabled': self._endpoint_enabled_map.get((cfg_id, idx), True),
                'key_id': self._endpoint_key_id(config_id, idx, url),
            }
            for (cfg_id, idx), url in self._endpoint_urls.items()
            if cfg_id == config_id
        ]
        return sorted(endpoints, key=lambda x: x['idx'])

    def get_endpoint_key_id(self, config_id: str, endpoint_idx: int) -> Optional[str]:
        """Get the key_id for a specific endpoint, or None if it is unknown."""
        slot = (config_id, endpoint_idx)
        if slot not in self._endpoint_urls:
            return None
        return self._endpoint_key_id(
            config_id, endpoint_idx, self._endpoint_urls[slot]
        )