diff --git a/optillm/__init__.py b/optillm/__init__.py
index f495d37..af21e78 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -2,7 +2,7 @@
 import os
 
 # Version information
-__version__ = "0.1.15"
+__version__ = "0.1.16"
 
 # Get the path to the root optillm.py
 spec = util.spec_from_file_location(
diff --git a/optillm/inference.py b/optillm/inference.py
index 624b5a5..07640a4 100644
--- a/optillm/inference.py
+++ b/optillm/inference.py
@@ -16,6 +16,8 @@
 import time
 import threading
 import traceback
+import platform
+import sys
 
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode
@@ -26,6 +28,17 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# MLX Support for Apple Silicon
+try:
+    import mlx.core as mx
+    from mlx_lm import load as mlx_load, generate as mlx_generate
+    from mlx_lm.tokenizer_utils import TokenizerWrapper
+    MLX_AVAILABLE = True
+    logger.info("MLX framework available")
+except ImportError:
+    MLX_AVAILABLE = False
+    logger.debug("MLX framework not available - falling back to PyTorch")
+
 @dataclass
 class ModelConfig:
     base_model_id: str
@@ -162,6 +175,302 @@ def calculate_logprobs(
         bytes_per_token=all_bytes
     )
 
+# MLX Support Functions and Classes
+
+def is_apple_silicon() -> bool:
+    """Check if running on Apple Silicon"""
+    return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+def should_use_mlx(model_id: str) -> bool:
+    """Determine if a model should use MLX instead of PyTorch"""
+    if not MLX_AVAILABLE or not is_apple_silicon():
+        return False
+
+    # Models that should use MLX
+    mlx_patterns = [
+        "mlx-community/",
+        "mlx-"
+    ]
+
+    # Models with known issues on the MPS backend; warn and suggest an MLX alternative
+    problematic_models = [
+        "Qwen/Qwen3-",
+        "google/gemma-3-",
+        "google/gemma3-"
+    ]
+
+    model_lower = model_id.lower()
+
+    # Direct MLX model detection
+    for pattern in mlx_patterns:
+        if pattern.lower() in model_lower:
+            return True
+
+    # Problematic model detection
+    for pattern in problematic_models:
+        if pattern.lower() in model_lower:
+            logger.warning(f"Model {model_id} detected as potentially problematic with MPS backend")
+            suggested_mlx = suggest_mlx_alternative(model_id)
+            logger.warning(f"Consider using MLX model: {suggested_mlx}")
+            # Don't auto-switch, only recommend
+            return False
+
+    return False
+
+def suggest_mlx_alternative(model_id: str) -> str:
+    """Suggest an MLX alternative for a given model"""
+    mlx_alternatives = {
+        # Qwen3 models
+        "Qwen/Qwen3-0.6B": "mlx-community/Qwen3-0.6B-4bit",
+        "Qwen/Qwen3-1.7B": "mlx-community/Qwen3-1.7B-4bit",
+        "Qwen/Qwen3-4B": "mlx-community/Qwen3-4B-4bit",
+        "Qwen/Qwen3-8B": "mlx-community/Qwen3-8B-4bit",
+        "Qwen/Qwen3-14B": "mlx-community/Qwen3-14B-4bit",
+        "Qwen/Qwen3-32B": "mlx-community/Qwen3-32B-4bit",
+
+        # Gemma 3 models
+        "google/gemma-3-1b-it": "mlx-community/gemma-3-1b-it-4bit",
+        "google/gemma-3-4b-it": "mlx-community/gemma-3-4b-it-4bit",
+        "google/gemma-3-12b-it": "mlx-community/gemma-3-12b-it-4bit",
+        "google/gemma-3-27b-it": "mlx-community/gemma-3-27b-it-4bit",
+    }
+
+    return mlx_alternatives.get(model_id, f"mlx-community/{model_id.split('/')[-1]}-4bit")
+
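+# Illustrative behavior of the helpers above (informal sketch, not executed here;
+# assumes MLX_AVAILABLE is True and the host is Apple Silicon):
+#
+#   should_use_mlx("mlx-community/Qwen3-4B-4bit")   -> True   (explicit MLX repo)
+#   should_use_mlx("Qwen/Qwen3-4B")                 -> False  (logs a warning and
+#                                                    recommends the MLX variant)
+#   suggest_mlx_alternative("google/gemma-3-4b-it") -> "mlx-community/gemma-3-4b-it-4bit"
+#   suggest_mlx_alternative("acme/CustomNet-7B")    -> "mlx-community/CustomNet-7B-4bit"
+#       ("acme/CustomNet-7B" is a hypothetical id; the fallback only rewrites the
+#        name and does not guarantee that such a repo exists)
+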
+@dataclass
+class MLXModelConfig:
+    """Configuration for MLX models"""
+    model_id: str
+    max_new_tokens: int = 4096
+    temperature: float = 0.7
+    top_p: float = 0.9
+    repetition_penalty: float = 1.0
+    enable_prompt_caching: bool = True
+
+class MLXInferencePipeline:
+    """MLX-based inference pipeline that mirrors the PyTorch pipeline interface"""
+
+    def __init__(self, model_config: MLXModelConfig, cache_manager):
+        self.model_config = model_config
+        self.cache_manager = cache_manager
+        self.last_used = time.time()
+
+        if not MLX_AVAILABLE:
+            raise RuntimeError("MLX framework not available. Install with: pip install mlx-lm")
+
+        if not is_apple_silicon():
+            raise RuntimeError("MLX framework is only supported on Apple Silicon")
+
+        try:
+            logger.info(f"Loading MLX model: {model_config.model_id}")
+            self.model, self.tokenizer = self._load_mlx_model(model_config.model_id)
+            logger.info("MLX model loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load MLX model: {str(e)}")
+            raise
+
+    def _load_mlx_model(self, model_id: str):
+        """Load MLX model and tokenizer with caching"""
+        def _load_model():
+            start_time = time.time()
+            logger.info(f"Loading MLX model: {model_id}")
+
+            try:
+                model, tokenizer = mlx_load(model_id)
+                load_time = time.time() - start_time
+                logger.info(f"MLX model loaded in {load_time:.2f}s")
+                return model, tokenizer
+            except Exception as e:
+                logger.error(f"Error loading MLX model {model_id}: {str(e)}")
+                raise
+
+        return self.cache_manager.get_or_load_model(f"mlx_{model_id}", _load_model)
+
+    def generate(
+        self,
+        prompt: str,
+        generation_params: Optional[Dict[str, Any]] = None
+    ) -> Tuple[List[str], List[int], List[Optional[Dict]]]:
+        """Generate text using MLX"""
+        start_time = time.time()
+
+        if generation_params is None:
+            generation_params = {}
+
+        # Extract parameters with defaults
+        max_tokens = generation_params.get("max_new_tokens", self.model_config.max_new_tokens)
+        temperature = generation_params.get("temperature", self.model_config.temperature)
+        top_p = generation_params.get("top_p", self.model_config.top_p)
+        repetition_penalty = generation_params.get("repetition_penalty", self.model_config.repetition_penalty)
+        num_return_sequences = generation_params.get("num_return_sequences", 1)
+
+        # Handle seed
+        if generation_params.get("seed") is not None:
+            mx.random.seed(generation_params["seed"])
+
+        responses = []
+        token_counts = []
+        logprobs_results = []
+
+        # Generate multiple sequences if requested
+        for _ in range(num_return_sequences):
+            try:
+                logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")
+
+                # Use robust MLX generation with multiple fallback approaches
+                response = self._robust_mlx_generate(
+                    prompt, max_tokens, temperature, top_p, repetition_penalty
+                )
+
+                responses.append(response)
+
+                # Approximate the token count; MLX may return either generated text
+                # or a token sequence, so handle both cases
+                if isinstance(response, str):
+                    token_count = len(self.tokenizer.encode(response))
+                else:
+                    token_count = len(response) if hasattr(response, '__len__') else 0
+                token_counts.append(token_count)
+
+                # MLX doesn't provide logprobs by default
+                logprobs_results.append(None)
+
+            except Exception as e:
+                logger.error(f"Error during MLX generation: {str(e)}")
+                logger.error(f"MLX generation parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
+                responses.append("")
+                token_counts.append(0)
+                logprobs_results.append(None)
+
+        generation_time = time.time() - start_time
+        logger.info(f"MLX generation completed in {generation_time:.2f}s")
+
+        return responses, token_counts, logprobs_results
+
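+    # Informal usage sketch for generate() above; values are illustrative:
+    #
+    #   responses, token_counts, logprobs = pipeline.generate(
+    #       "Hello", {"max_new_tokens": 64, "temperature": 0.2, "seed": 0}
+    #   )
+    #   responses     -> one string per requested sequence, e.g. ["Hi there!"]
+    #   token_counts  -> approximate counts via tokenizer.encode, e.g. [4]
+    #   logprobs      -> [None]; MLX does not expose logprobs in this path
+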
+    def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
+        """Robust MLX generation with multiple parameter combinations"""
+
+        # Try different parameter combinations based on MLX-LM version
+        parameter_combinations = [
+            # Version 1: Current style with positional args and temp
+            {
+                "style": "positional_temp",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 2: All keyword arguments with temp
+            {
+                "style": "keyword_temp",
+                "args": (),
+                "kwargs": {
+                    "model": self.model,
+                    "tokenizer": self.tokenizer,
+                    "prompt": prompt,
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 3: Using temperature instead of temp
+            {
+                "style": "positional_temperature",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 4: Minimal parameters only
+            {
+                "style": "minimal",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "verbose": False
+                }
+            },
+            # Version 5: Just essential parameters
+            {
+                "style": "essential",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens
+                }
+            }
+        ]
+
+        last_error = None
+
+        for combo in parameter_combinations:
+            try:
+                logger.debug(f"Trying MLX generation with style: {combo['style']}")
+                response = mlx_generate(*combo["args"], **combo["kwargs"])
+                logger.debug(f"Successfully generated with style: {combo['style']}")
+                return response
+
+            except Exception as e:
+                last_error = e
+                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
+                continue
+
+        # If all combinations failed, raise the last error
+        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
+
+    def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
+        """Format the prompt according to the model's chat template"""
+        if hasattr(self.tokenizer, 'apply_chat_template'):
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt}
+            ]
+            try:
+                return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            except Exception as e:
+                logger.warning(f"Failed to apply chat template: {e}, using fallback")
+                return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
+        else:
+            return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
+
+class MLXManager:
+    """Manager for MLX models and operations"""
+
+    def __init__(self, cache_manager):
+        self.cache_manager = cache_manager
+        self.available = MLX_AVAILABLE and is_apple_silicon()
+
+        if self.available:
+            logger.info("MLX manager initialized - Apple Silicon detected")
+        else:
+            logger.debug("MLX manager not available - requires Apple Silicon and mlx-lm")
+
+    def create_pipeline(self, model_id: str, **kwargs) -> MLXInferencePipeline:
+        """Create an MLX inference pipeline"""
+        if not self.available:
+            raise RuntimeError("MLX not available on this platform")
+
+        config = MLXModelConfig(
+            model_id=model_id,
+            **kwargs
+        )
+
+        return MLXInferencePipeline(config, self.cache_manager)
+
+    def is_mlx_model(self, model_id: str) -> bool:
+        """Check if model should use MLX"""
+        return should_use_mlx(model_id)
+
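+# End-to-end sketch of the MLX path (illustrative, not executed here; assumes
+# mlx-lm is installed on an Apple Silicon host and cache_manager is the
+# client's shared cache manager):
+#
+#   manager = MLXManager(cache_manager)
+#   if manager.is_mlx_model("mlx-community/Qwen3-4B-4bit"):
+#       pipeline = manager.create_pipeline("mlx-community/Qwen3-4B-4bit",
+#                                          max_new_tokens=512, temperature=0.2)
+#       prompt = pipeline.format_chat_prompt("You are helpful.", "Hi!")
+#       responses, counts, _ = pipeline.generate(prompt)
+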
 class MemoryEfficientAttention(nn.Module):
     """
     Memory-efficient attention using linear attention mechanism.
@@ -1286,18 +1595,27 @@ def __init__(self):
         self.device_manager = DeviceManager()
         self.model_manager = ModelManager(self.cache_manager, self.device_manager)
         self.lora_manager = LoRAManager(self.cache_manager)
+        self.mlx_manager = MLXManager(self.cache_manager)
         self.chat = self.Chat(self)
         self.models = self.Models()
 
-    def get_pipeline(self, model: str) -> 'InferencePipeline':
-        model_config = parse_model_string(model)
-        return InferencePipeline(
-            model_config,
-            self.cache_manager,
-            self.device_manager,
-            self.model_manager,
-            self.lora_manager
-        )
+    def get_pipeline(self, model: str):
+        """Get inference pipeline - automatically chooses MLX or PyTorch based on the model"""
+        # Check whether this model should use MLX
+        if self.mlx_manager.available and should_use_mlx(model):
+            logger.info(f"Using MLX pipeline for model: {model}")
+            return self.mlx_manager.create_pipeline(model)
+        else:
+            # Fall back to the existing PyTorch pipeline
+            logger.info(f"Using PyTorch pipeline for model: {model}")
+            model_config = parse_model_string(model)
+            return InferencePipeline(
+                model_config,
+                self.cache_manager,
+                self.device_manager,
+                self.model_manager,
+                self.lora_manager
+            )
 
     class Chat:
         """OpenAI-compatible chat interface"""
diff --git a/requirements.txt b/requirements.txt
index 1996b88..aead2cf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,4 +28,6 @@ cerebras_cloud_sdk
 outlines[transformers]
 sentencepiece
 adaptive-classifier
-mcp
\ No newline at end of file
+mcp
+# MLX support for Apple Silicon optimization
+mlx-lm>=0.24.0; platform_machine=="arm64" and sys_platform=="darwin"
\ No newline at end of file
diff --git a/setup.py b/setup.py
index bb5b152..95d18f9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 
 setup(
     name="optillm",
-    version="0.1.15",
+    version="0.1.16",
     packages=find_packages(include=['optillm', 'optillm.*']),  # This ensures all subpackages are included
     py_modules=['optillm'],
     package_data={
@@ -46,6 +46,8 @@
         "sentencepiece",
         "mcp",
         "adaptive-classifier",
+        # MLX support for Apple Silicon optimization
+        'mlx-lm>=0.24.0; platform_machine=="arm64" and sys_platform=="darwin"',
     ],
     entry_points={
         'console_scripts': [