16 | 16 | import time
17 | 17 | import threading
18 | 18 | import traceback
 | 19 | +import platform
 | 20 | +import sys
19 | 21 |
20 | 22 | from optillm.cot_decoding import cot_decode
21 | 23 | from optillm.entropy_decoding import entropy_decode

26 | 28 | logging.basicConfig(level=logging.INFO)
27 | 29 | logger = logging.getLogger(__name__)
28 | 30 |
 | 31 | +# MLX Support for Apple Silicon
 | 32 | +try:
 | 33 | +    import mlx.core as mx
 | 34 | +    from mlx_lm import load as mlx_load, generate as mlx_generate
 | 35 | +    from mlx_lm.tokenizer_utils import TokenizerWrapper
 | 36 | +    MLX_AVAILABLE = True
 | 37 | +    logger.info("MLX framework available")
 | 38 | +except ImportError:
 | 39 | +    MLX_AVAILABLE = False
 | 40 | +    logger.debug("MLX framework not available - falling back to PyTorch")
 | 41 | +
29 | 42 | @dataclass
30 | 43 | class ModelConfig:
31 | 44 |     base_model_id: str

@@ -162,6 +175,302 @@ def calculate_logprobs(

162 | 175 |         bytes_per_token=all_bytes
163 | 176 |     )
164 | 177 |
 | 178 | +# MLX Support Functions and Classes
 | 179 | +
 | 180 | +def is_apple_silicon() -> bool:
 | 181 | +    """Check if running on Apple Silicon"""
 | 182 | +    return platform.system() == "Darwin" and platform.machine() == "arm64"
 | 183 | +
 | 184 | +def should_use_mlx(model_id: str) -> bool:
 | 185 | +    """Determine if a model should use MLX instead of PyTorch"""
 | 186 | +    if not MLX_AVAILABLE or not is_apple_silicon():
 | 187 | +        return False
 | 188 | +
 | 189 | +    # Models that should use MLX
 | 190 | +    mlx_patterns = [
 | 191 | +        "mlx-community/",
 | 192 | +        "mlx-"
 | 193 | +    ]
 | 194 | +
 | 195 | +    # Known problematic models that should prefer MLX on Apple Silicon
 | 196 | +    problematic_models = [
 | 197 | +        "Qwen/Qwen3-",
 | 198 | +        "google/gemma-3-",
 | 199 | +        "google/gemma3-"
 | 200 | +    ]
 | 201 | +
 | 202 | +    model_lower = model_id.lower()
 | 203 | +
 | 204 | +    # Direct MLX model detection
 | 205 | +    for pattern in mlx_patterns:
 | 206 | +        if pattern.lower() in model_lower:
 | 207 | +            return True
 | 208 | +
 | 209 | +    # Problematic model detection
 | 210 | +    for pattern in problematic_models:
 | 211 | +        if pattern.lower() in model_lower:
 | 212 | +            logger.warning(f"Model {model_id} detected as potentially problematic with MPS backend")
 | 213 | +            suggested_mlx = suggest_mlx_alternative(model_id)
 | 214 | +            logger.warning(f"Consider using MLX model: {suggested_mlx}")
 | 215 | +            # Don't auto-switch, but recommend
 | 216 | +            return False
 | 217 | +
 | 218 | +    return False
 | 219 | +
 | 220 | +def suggest_mlx_alternative(model_id: str) -> str:
 | 221 | +    """Suggest MLX alternative for a given model"""
 | 222 | +    mlx_alternatives = {
 | 223 | +        # Qwen3 models
 | 224 | +        "Qwen/Qwen3-0.6B": "mlx-community/Qwen3-0.6B-4bit",
 | 225 | +        "Qwen/Qwen3-1.7B": "mlx-community/Qwen3-1.7B-4bit",
 | 226 | +        "Qwen/Qwen3-4B": "mlx-community/Qwen3-4B-4bit",
 | 227 | +        "Qwen/Qwen3-8B": "mlx-community/Qwen3-8B-4bit",
 | 228 | +        "Qwen/Qwen3-14B": "mlx-community/Qwen3-14B-4bit",
 | 229 | +        "Qwen/Qwen3-32B": "mlx-community/Qwen3-32B-4bit",
 | 230 | +
 | 231 | +        # Gemma 3 models
 | 232 | +        "google/gemma-3-1b-it": "mlx-community/gemma-3-1b-it-4bit",
 | 233 | +        "google/gemma-3-4b-it": "mlx-community/gemma-3-4b-it-4bit",
 | 234 | +        "google/gemma-3-12b-it": "mlx-community/gemma-3-12b-it-4bit",
 | 235 | +        "google/gemma-3-27b-it": "mlx-community/gemma-3-27b-it-4bit",
 | 236 | +    }
 | 237 | +
 | 238 | +    return mlx_alternatives.get(model_id, f"mlx-community/{model_id.split('/')[-1]}-4bit")
 | 239 | +
 | 240 | +@dataclass
 | 241 | +class MLXModelConfig:
 | 242 | +    """Configuration for MLX models"""
 | 243 | +    model_id: str
 | 244 | +    max_new_tokens: int = 4096
 | 245 | +    temperature: float = 0.7
 | 246 | +    top_p: float = 0.9
 | 247 | +    repetition_penalty: float = 1.0
 | 248 | +    enable_prompt_caching: bool = True
 | 249 | +
 | 250 | +class MLXInferencePipeline:
 | 251 | +    """MLX-based inference pipeline that mirrors PyTorch pipeline interface"""
 | 252 | +
 | 253 | +    def __init__(self, model_config: MLXModelConfig, cache_manager):
 | 254 | +        self.model_config = model_config
 | 255 | +        self.cache_manager = cache_manager
 | 256 | +        self.last_used = time.time()
 | 257 | +
 | 258 | +        if not MLX_AVAILABLE:
 | 259 | +            raise RuntimeError("MLX framework not available. Install with: pip install mlx-lm")
 | 260 | +
 | 261 | +        if not is_apple_silicon():
 | 262 | +            raise RuntimeError("MLX framework is only supported on Apple Silicon")
 | 263 | +
 | 264 | +        try:
 | 265 | +            logger.info(f"Loading MLX model: {model_config.model_id}")
 | 266 | +            self.model, self.tokenizer = self._load_mlx_model(model_config.model_id)
 | 267 | +            logger.info("MLX model loaded successfully")
 | 268 | +        except Exception as e:
 | 269 | +            logger.error(f"Failed to load MLX model: {str(e)}")
 | 270 | +            raise
 | 271 | +
 | 272 | +    def _load_mlx_model(self, model_id: str):
 | 273 | +        """Load MLX model and tokenizer with caching"""
 | 274 | +        def _load_model():
 | 275 | +            start_time = time.time()
 | 276 | +            logger.info(f"Loading MLX model: {model_id}")
 | 277 | +
 | 278 | +            try:
 | 279 | +                model, tokenizer = mlx_load(model_id)
 | 280 | +                load_time = time.time() - start_time
 | 281 | +                logger.info(f"MLX model loaded in {load_time:.2f}s")
 | 282 | +                return model, tokenizer
 | 283 | +            except Exception as e:
 | 284 | +                logger.error(f"Error loading MLX model {model_id}: {str(e)}")
 | 285 | +                raise
 | 286 | +
 | 287 | +        return self.cache_manager.get_or_load_model(f"mlx_{model_id}", _load_model)
 | 288 | +
 | 289 | +    def generate(
 | 290 | +        self,
 | 291 | +        prompt: str,
 | 292 | +        generation_params: Optional[Dict[str, Any]] = None
 | 293 | +    ) -> Tuple[List[str], List[int], List[Optional[Dict]]]:
 | 294 | +        """Generate text using MLX"""
 | 295 | +        start_time = time.time()
 | 296 | +
 | 297 | +        if generation_params is None:
 | 298 | +            generation_params = {}
 | 299 | +
 | 300 | +        # Extract parameters with defaults
 | 301 | +        max_tokens = generation_params.get("max_new_tokens", self.model_config.max_new_tokens)
 | 302 | +        temperature = generation_params.get("temperature", self.model_config.temperature)
 | 303 | +        top_p = generation_params.get("top_p", self.model_config.top_p)
 | 304 | +        repetition_penalty = generation_params.get("repetition_penalty", self.model_config.repetition_penalty)
 | 305 | +        num_return_sequences = generation_params.get("num_return_sequences", 1)
 | 306 | +
 | 307 | +        # Handle seed
 | 308 | +        if generation_params.get("seed") is not None:
 | 309 | +            mx.random.seed(generation_params["seed"])
 | 310 | +
 | 311 | +        responses = []
 | 312 | +        token_counts = []
 | 313 | +        logprobs_results = []
 | 314 | +
 | 315 | +        # Generate multiple sequences if requested
 | 316 | +        for _ in range(num_return_sequences):
 | 317 | +            try:
 | 318 | +                logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")
 | 319 | +
 | 320 | +                # Use robust MLX generation with multiple fallback approaches
 | 321 | +                response = self._robust_mlx_generate(
 | 322 | +                    prompt, max_tokens, temperature, top_p, repetition_penalty
 | 323 | +                )
 | 324 | +
 | 325 | +                responses.append(response)
 | 326 | +
 | 327 | +                # Count tokens (approximate); the response may be a string or a token sequence
 | 328 | +                if isinstance(response, str):
 | 329 | +                    token_count = len(self.tokenizer.encode(response))
 | 330 | +                else:
 | 331 | +                    # Some mlx-lm versions return token sequences rather than text; count them directly
 | 332 | +                    token_count = len(response) if hasattr(response, '__len__') else 0
 | 333 | +                token_counts.append(token_count)
 | 334 | +
 | 335 | +                # MLX doesn't provide logprobs by default
 | 336 | +                logprobs_results.append(None)
 | 337 | +
 | 338 | +            except Exception as e:
 | 339 | +                logger.error(f"Error during MLX generation: {str(e)}")
 | 340 | +                logger.error(f"MLX generation parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
 | 341 | +                responses.append("")
 | 342 | +                token_counts.append(0)
 | 343 | +                logprobs_results.append(None)
 | 344 | +
 | 345 | +        generation_time = time.time() - start_time
 | 346 | +        logger.info(f"MLX generation completed in {generation_time:.2f}s")
 | 347 | +
 | 348 | +        return responses, token_counts, logprobs_results
 | 349 | +
 | 350 | +    def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
 | 351 | +        """Robust MLX generation with multiple parameter combinations"""
 | 352 | +
 | 353 | +        # Try different parameter combinations based on MLX-LM version
 | 354 | +        parameter_combinations = [
 | 355 | +            # Version 1: Current style with positional args and temp
 | 356 | +            {
 | 357 | +                "style": "positional_temp",
 | 358 | +                "args": (self.model, self.tokenizer, prompt),
 | 359 | +                "kwargs": {
 | 360 | +                    "max_tokens": max_tokens,
 | 361 | +                    "temp": temperature,
 | 362 | +                    "top_p": top_p,
 | 363 | +                    "repetition_penalty": repetition_penalty,
 | 364 | +                    "verbose": False
 | 365 | +                }
 | 366 | +            },
 | 367 | +            # Version 2: All keyword arguments with temp
 | 368 | +            {
 | 369 | +                "style": "keyword_temp",
 | 370 | +                "args": (),
 | 371 | +                "kwargs": {
 | 372 | +                    "model": self.model,
 | 373 | +                    "tokenizer": self.tokenizer,
 | 374 | +                    "prompt": prompt,
 | 375 | +                    "max_tokens": max_tokens,
 | 376 | +                    "temp": temperature,
 | 377 | +                    "top_p": top_p,
 | 378 | +                    "repetition_penalty": repetition_penalty,
 | 379 | +                    "verbose": False
 | 380 | +                }
 | 381 | +            },
 | 382 | +            # Version 3: Using temperature instead of temp
 | 383 | +            {
 | 384 | +                "style": "positional_temperature",
 | 385 | +                "args": (self.model, self.tokenizer, prompt),
 | 386 | +                "kwargs": {
 | 387 | +                    "max_tokens": max_tokens,
 | 388 | +                    "temperature": temperature,
 | 389 | +                    "top_p": top_p,
 | 390 | +                    "repetition_penalty": repetition_penalty,
 | 391 | +                    "verbose": False
 | 392 | +                }
 | 393 | +            },
 | 394 | +            # Version 4: Minimal parameters only
 | 395 | +            {
 | 396 | +                "style": "minimal",
 | 397 | +                "args": (self.model, self.tokenizer, prompt),
 | 398 | +                "kwargs": {
 | 399 | +                    "max_tokens": max_tokens,
 | 400 | +                    "temp": temperature,
 | 401 | +                    "verbose": False
 | 402 | +                }
 | 403 | +            },
 | 404 | +            # Version 5: Just essential parameters
 | 405 | +            {
 | 406 | +                "style": "essential",
 | 407 | +                "args": (self.model, self.tokenizer, prompt),
 | 408 | +                "kwargs": {
 | 409 | +                    "max_tokens": max_tokens
 | 410 | +                }
 | 411 | +            }
 | 412 | +        ]
 | 413 | +
 | 414 | +        last_error = None
 | 415 | +
 | 416 | +        for combo in parameter_combinations:
 | 417 | +            try:
 | 418 | +                logger.debug(f"Trying MLX generation with style: {combo['style']}")
 | 419 | +                response = mlx_generate(*combo["args"], **combo["kwargs"])
 | 420 | +                logger.debug(f"Successfully generated with style: {combo['style']}")
 | 421 | +                return response
 | 422 | +
 | 423 | +            except Exception as e:
 | 424 | +                last_error = e
 | 425 | +                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
 | 426 | +                continue
 | 427 | +
 | 428 | +        # If all combinations failed, raise the last error
 | 429 | +        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
 | 430 | +
 | 431 | +    def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
 | 432 | +        """Format the prompt according to the model's chat template"""
 | 433 | +        if hasattr(self.tokenizer, 'apply_chat_template'):
 | 434 | +            messages = [
 | 435 | +                {"role": "system", "content": system_prompt},
 | 436 | +                {"role": "user", "content": user_prompt}
 | 437 | +            ]
 | 438 | +            try:
 | 439 | +                return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 | 440 | +            except Exception as e:
 | 441 | +                logger.warning(f"Failed to apply chat template: {e}, using fallback")
 | 442 | +                return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
 | 443 | +        else:
 | 444 | +            return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
 | 445 | +
 | 446 | +class MLXManager:
 | 447 | +    """Manager for MLX models and operations"""
 | 448 | +
 | 449 | +    def __init__(self, cache_manager):
 | 450 | +        self.cache_manager = cache_manager
 | 451 | +        self.available = MLX_AVAILABLE and is_apple_silicon()
 | 452 | +
 | 453 | +        if self.available:
 | 454 | +            logger.info("MLX manager initialized - Apple Silicon detected")
 | 455 | +        else:
 | 456 | +            logger.debug("MLX manager not available - requires Apple Silicon and mlx-lm")
 | 457 | +
 | 458 | +    def create_pipeline(self, model_id: str, **kwargs) -> MLXInferencePipeline:
 | 459 | +        """Create an MLX inference pipeline"""
 | 460 | +        if not self.available:
 | 461 | +            raise RuntimeError("MLX not available on this platform")
 | 462 | +
 | 463 | +        config = MLXModelConfig(
 | 464 | +            model_id=model_id,
 | 465 | +            **kwargs
 | 466 | +        )
 | 467 | +
 | 468 | +        return MLXInferencePipeline(config, self.cache_manager)
 | 469 | +
 | 470 | +    def is_mlx_model(self, model_id: str) -> bool:
 | 471 | +        """Check if model should use MLX"""
 | 472 | +        return should_use_mlx(model_id)
 | 473 | +
165 | 474 | class MemoryEfficientAttention(nn.Module):
166 | 475 |     """
167 | 476 |     Memory-efficient attention using linear attention mechanism.

@@ -1286,18 +1595,27 @@ def __init__(self):

1286 | 1595 |         self.device_manager = DeviceManager()
1287 | 1596 |         self.model_manager = ModelManager(self.cache_manager, self.device_manager)
1288 | 1597 |         self.lora_manager = LoRAManager(self.cache_manager)
 | 1598 | +        self.mlx_manager = MLXManager(self.cache_manager)
1289 | 1599 |         self.chat = self.Chat(self)
1290 | 1600 |         self.models = self.Models()
1291 | 1601 |
1292 | | -    def get_pipeline(self, model: str) -> 'InferencePipeline':
1293 | | -        model_config = parse_model_string(model)
1294 | | -        return InferencePipeline(
1295 | | -            model_config,
1296 | | -            self.cache_manager,
1297 | | -            self.device_manager,
1298 | | -            self.model_manager,
1299 | | -            self.lora_manager
1300 | | -        )
 | 1602 | +    def get_pipeline(self, model: str):
 | 1603 | +        """Get inference pipeline - automatically chooses MLX or PyTorch based on model"""
 | 1604 | +        # Check if should use MLX
 | 1605 | +        if self.mlx_manager.available and should_use_mlx(model):
 | 1606 | +            logger.info(f"Using MLX pipeline for model: {model}")
 | 1607 | +            return self.mlx_manager.create_pipeline(model)
 | 1608 | +        else:
 | 1609 | +            # Use existing PyTorch pipeline
 | 1610 | +            logger.info(f"Using PyTorch pipeline for model: {model}")
 | 1611 | +            model_config = parse_model_string(model)
 | 1612 | +            return InferencePipeline(
 | 1613 | +                model_config,
 | 1614 | +                self.cache_manager,
 | 1615 | +                self.device_manager,
 | 1616 | +                self.model_manager,
 | 1617 | +                self.lora_manager
 | 1618 | +            )
1301 | 1619 |
1302 | 1620 |     class Chat:
1303 | 1621 |         """OpenAI-compatible chat interface"""
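
A minimal usage sketch of the routing logic added above, not part of the commit itself. It assumes the file shown is importable as `optillm.inference` (the module path is an assumption) and that it runs on Apple Silicon with `mlx-lm` installed; the return values follow directly from `should_use_mlx` and `suggest_mlx_alternative` as defined in this diff:

```python
# Illustrative sketch only; the optillm.inference module path is assumed.
from optillm.inference import should_use_mlx, suggest_mlx_alternative

# mlx-community models route to the new MLX pipeline (requires Apple Silicon + mlx-lm).
should_use_mlx("mlx-community/Qwen3-4B-4bit")   # True on Apple Silicon, False elsewhere

# Known MPS-problematic models stay on PyTorch but log a recommended MLX alternative.
should_use_mlx("Qwen/Qwen3-4B")                 # False, with a warning in the logs
suggest_mlx_alternative("Qwen/Qwen3-4B")        # "mlx-community/Qwen3-4B-4bit"
```

With the new `get_pipeline`, these same checks decide whether the client returns an `MLXInferencePipeline` or the existing PyTorch `InferencePipeline`.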