
Commit 2e4c0da

Merge pull request #203 from codelion/fix-bug-mps
Fix bug mps
2 parents 333e752 + c27a095 commit 2e4c0da

File tree

4 files changed: +334 −12 lines

optillm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import os
 
 # Version information
-__version__ = "0.1.15"
+__version__ = "0.1.16"
 
 # Get the path to the root optillm.py
 spec = util.spec_from_file_location(

optillm/inference.py

Lines changed: 327 additions & 9 deletions
@@ -16,6 +16,8 @@
 import time
 import threading
 import traceback
+import platform
+import sys
 
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode
@@ -26,6 +28,17 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# MLX Support for Apple Silicon
+try:
+    import mlx.core as mx
+    from mlx_lm import load as mlx_load, generate as mlx_generate
+    from mlx_lm.tokenizer_utils import TokenizerWrapper
+    MLX_AVAILABLE = True
+    logger.info("MLX framework available")
+except ImportError:
+    MLX_AVAILABLE = False
+    logger.debug("MLX framework not available - falling back to PyTorch")
+
 @dataclass
 class ModelConfig:
     base_model_id: str
@@ -162,6 +175,302 @@ def calculate_logprobs(
         bytes_per_token=all_bytes
     )
 
+# MLX Support Functions and Classes
+
+def is_apple_silicon() -> bool:
+    """Check if running on Apple Silicon"""
+    return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+def should_use_mlx(model_id: str) -> bool:
+    """Determine if a model should use MLX instead of PyTorch"""
+    if not MLX_AVAILABLE or not is_apple_silicon():
+        return False
+
+    # Models that should use MLX
+    mlx_patterns = [
+        "mlx-community/",
+        "mlx-"
+    ]
+
+    # Known problematic models that should prefer MLX on Apple Silicon
+    problematic_models = [
+        "Qwen/Qwen3-",
+        "google/gemma-3-",
+        "google/gemma3-"
+    ]
+
+    model_lower = model_id.lower()
+
+    # Direct MLX model detection
+    for pattern in mlx_patterns:
+        if pattern.lower() in model_lower:
+            return True
+
+    # Problematic model detection
+    for pattern in problematic_models:
+        if pattern.lower() in model_lower:
+            logger.warning(f"Model {model_id} detected as potentially problematic with MPS backend")
+            suggested_mlx = suggest_mlx_alternative(model_id)
+            logger.warning(f"Consider using MLX model: {suggested_mlx}")
+            # Don't auto-switch, but recommend
+            return False
+
+    return False
+
+def suggest_mlx_alternative(model_id: str) -> str:
+    """Suggest MLX alternative for a given model"""
+    mlx_alternatives = {
+        # Qwen3 models
+        "Qwen/Qwen3-0.6B": "mlx-community/Qwen3-0.6B-4bit",
+        "Qwen/Qwen3-1.7B": "mlx-community/Qwen3-1.7B-4bit",
+        "Qwen/Qwen3-4B": "mlx-community/Qwen3-4B-4bit",
+        "Qwen/Qwen3-8B": "mlx-community/Qwen3-8B-4bit",
+        "Qwen/Qwen3-14B": "mlx-community/Qwen3-14B-4bit",
+        "Qwen/Qwen3-32B": "mlx-community/Qwen3-32B-4bit",
+
+        # Gemma 3 models
+        "google/gemma-3-1b-it": "mlx-community/gemma-3-1b-it-4bit",
+        "google/gemma-3-4b-it": "mlx-community/gemma-3-4b-it-4bit",
+        "google/gemma-3-12b-it": "mlx-community/gemma-3-12b-it-4bit",
+        "google/gemma-3-27b-it": "mlx-community/gemma-3-27b-it-4bit",
+    }
+
+    return mlx_alternatives.get(model_id, f"mlx-community/{model_id.split('/')[-1]}-4bit")
+
+@dataclass
+class MLXModelConfig:
+    """Configuration for MLX models"""
+    model_id: str
+    max_new_tokens: int = 4096
+    temperature: float = 0.7
+    top_p: float = 0.9
+    repetition_penalty: float = 1.0
+    enable_prompt_caching: bool = True
+
+class MLXInferencePipeline:
+    """MLX-based inference pipeline that mirrors PyTorch pipeline interface"""
+
+    def __init__(self, model_config: MLXModelConfig, cache_manager):
+        self.model_config = model_config
+        self.cache_manager = cache_manager
+        self.last_used = time.time()
+
+        if not MLX_AVAILABLE:
+            raise RuntimeError("MLX framework not available. Install with: pip install mlx-lm")
+
+        if not is_apple_silicon():
+            raise RuntimeError("MLX framework is only supported on Apple Silicon")
+
+        try:
+            logger.info(f"Loading MLX model: {model_config.model_id}")
+            self.model, self.tokenizer = self._load_mlx_model(model_config.model_id)
+            logger.info("MLX model loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load MLX model: {str(e)}")
+            raise
+
+    def _load_mlx_model(self, model_id: str):
+        """Load MLX model and tokenizer with caching"""
+        def _load_model():
+            start_time = time.time()
+            logger.info(f"Loading MLX model: {model_id}")
+
+            try:
+                model, tokenizer = mlx_load(model_id)
+                load_time = time.time() - start_time
+                logger.info(f"MLX model loaded in {load_time:.2f}s")
+                return model, tokenizer
+            except Exception as e:
+                logger.error(f"Error loading MLX model {model_id}: {str(e)}")
+                raise
+
+        return self.cache_manager.get_or_load_model(f"mlx_{model_id}", _load_model)
+
+    def generate(
+        self,
+        prompt: str,
+        generation_params: Optional[Dict[str, Any]] = None
+    ) -> Tuple[List[str], List[int], List[Optional[Dict]]]:
+        """Generate text using MLX"""
+        start_time = time.time()
+
+        if generation_params is None:
+            generation_params = {}
+
+        # Extract parameters with defaults
+        max_tokens = generation_params.get("max_new_tokens", self.model_config.max_new_tokens)
+        temperature = generation_params.get("temperature", self.model_config.temperature)
+        top_p = generation_params.get("top_p", self.model_config.top_p)
+        repetition_penalty = generation_params.get("repetition_penalty", self.model_config.repetition_penalty)
+        num_return_sequences = generation_params.get("num_return_sequences", 1)
+
+        # Handle seed
+        if generation_params.get("seed") is not None:
+            mx.random.seed(generation_params["seed"])
+
+        responses = []
+        token_counts = []
+        logprobs_results = []
+
+        # Generate multiple sequences if requested
+        for _ in range(num_return_sequences):
+            try:
+                logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")
+
+                # Use robust MLX generation with multiple fallback approaches
+                response = self._robust_mlx_generate(
+                    prompt, max_tokens, temperature, top_p, repetition_penalty
+                )
+
+                responses.append(response)
+
+                # Count tokens (approximate) - check if response is string
+                if isinstance(response, str):
+                    token_count = len(self.tokenizer.encode(response))
+                else:
+                    # Sometimes MLX returns just the new tokens, get the actual text
+                    token_count = len(response) if hasattr(response, '__len__') else 0
+                token_counts.append(token_count)
+
+                # MLX doesn't provide logprobs by default
+                logprobs_results.append(None)
+
+            except Exception as e:
+                logger.error(f"Error during MLX generation: {str(e)}")
+                logger.error(f"MLX generation parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
+                responses.append("")
+                token_counts.append(0)
+                logprobs_results.append(None)
+
+        generation_time = time.time() - start_time
+        logger.info(f"MLX generation completed in {generation_time:.2f}s")
+
+        return responses, token_counts, logprobs_results
+
+    def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
+        """Robust MLX generation with multiple parameter combinations"""
+
+        # Try different parameter combinations based on MLX-LM version
+        parameter_combinations = [
+            # Version 1: Current style with positional args and temp
+            {
+                "style": "positional_temp",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 2: All keyword arguments with temp
+            {
+                "style": "keyword_temp",
+                "args": (),
+                "kwargs": {
+                    "model": self.model,
+                    "tokenizer": self.tokenizer,
+                    "prompt": prompt,
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 3: Using temperature instead of temp
+            {
+                "style": "positional_temperature",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 4: Minimal parameters only
+            {
+                "style": "minimal",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "verbose": False
+                }
+            },
+            # Version 5: Just essential parameters
+            {
+                "style": "essential",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens
+                }
+            }
+        ]
+
+        last_error = None
+
+        for combo in parameter_combinations:
+            try:
+                logger.debug(f"Trying MLX generation with style: {combo['style']}")
+                response = mlx_generate(*combo["args"], **combo["kwargs"])
+                logger.debug(f"Successfully generated with style: {combo['style']}")
+                return response
+
+            except Exception as e:
+                last_error = e
+                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
+                continue
+
+        # If all combinations failed, raise the last error
+        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
+
+    def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
+        """Format the prompt according to model's chat template"""
+        if hasattr(self.tokenizer, 'apply_chat_template'):
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt}
+            ]
+            try:
+                return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            except Exception as e:
+                logger.warning(f"Failed to apply chat template: {e}, using fallback")
+                return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
+        else:
+            return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
+
+class MLXManager:
+    """Manager for MLX models and operations"""
+
+    def __init__(self, cache_manager):
+        self.cache_manager = cache_manager
+        self.available = MLX_AVAILABLE and is_apple_silicon()
+
+        if self.available:
+            logger.info("MLX manager initialized - Apple Silicon detected")
+        else:
+            logger.debug("MLX manager not available - requires Apple Silicon and mlx-lm")
+
+    def create_pipeline(self, model_id: str, **kwargs) -> MLXInferencePipeline:
+        """Create an MLX inference pipeline"""
+        if not self.available:
+            raise RuntimeError("MLX not available on this platform")
+
+        config = MLXModelConfig(
+            model_id=model_id,
+            **kwargs
+        )
+
+        return MLXInferencePipeline(config, self.cache_manager)
+
+    def is_mlx_model(self, model_id: str) -> bool:
+        """Check if model should use MLX"""
+        return should_use_mlx(model_id)
+
 class MemoryEfficientAttention(nn.Module):
     """
     Memory-efficient attention using linear attention mechanism.
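Taken together, the MLX helpers added in this hunk implement a conservative routing policy: model IDs that are clearly MLX builds (the "mlx-community/" or "mlx-" prefixes) are sent to MLX, while known MPS-problematic families (Qwen3, Gemma 3) stay on PyTorch but log a recommendation for the 4-bit MLX conversion. A minimal sketch of the expected behaviour, assuming the helpers are imported from optillm.inference on an Apple Silicon machine with mlx-lm installed:

# Sketch only: illustrates the routing helpers added above, assuming an
# Apple Silicon machine where mlx-lm is installed (MLX_AVAILABLE is True).
from optillm.inference import should_use_mlx, suggest_mlx_alternative

should_use_mlx("mlx-community/Qwen3-4B-4bit")   # True  -> handled by MLX
should_use_mlx("Qwen/Qwen3-4B")                 # False -> stays on PyTorch, but logs a
                                                #          warning suggesting the MLX build
suggest_mlx_alternative("Qwen/Qwen3-4B")        # "mlx-community/Qwen3-4B-4bit"
suggest_mlx_alternative("org/SomeModel")        # "mlx-community/SomeModel-4bit" (generic fallback;
                                                # "org/SomeModel" is a made-up ID for illustration)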
@@ -1286,18 +1595,27 @@ def __init__(self):
         self.device_manager = DeviceManager()
         self.model_manager = ModelManager(self.cache_manager, self.device_manager)
         self.lora_manager = LoRAManager(self.cache_manager)
+        self.mlx_manager = MLXManager(self.cache_manager)
         self.chat = self.Chat(self)
         self.models = self.Models()
 
-    def get_pipeline(self, model: str) -> 'InferencePipeline':
-        model_config = parse_model_string(model)
-        return InferencePipeline(
-            model_config,
-            self.cache_manager,
-            self.device_manager,
-            self.model_manager,
-            self.lora_manager
-        )
+    def get_pipeline(self, model: str):
+        """Get inference pipeline - automatically chooses MLX or PyTorch based on model"""
+        # Check if should use MLX
+        if self.mlx_manager.available and should_use_mlx(model):
+            logger.info(f"Using MLX pipeline for model: {model}")
+            return self.mlx_manager.create_pipeline(model)
+        else:
+            # Use existing PyTorch pipeline
+            logger.info(f"Using PyTorch pipeline for model: {model}")
+            model_config = parse_model_string(model)
+            return InferencePipeline(
+                model_config,
+                self.cache_manager,
+                self.device_manager,
+                self.model_manager,
+                self.lora_manager
+            )
 
     class Chat:
         """OpenAI-compatible chat interface"""

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -28,4 +28,6 @@ cerebras_cloud_sdk
 outlines[transformers]
 sentencepiece
 adaptive-classifier
-mcp
+mcp
+# MLX support for Apple Silicon optimization
+mlx-lm>=0.24.0; platform_machine=="arm64" and sys_platform=="darwin"
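The new requirement uses a PEP 508 environment marker, so pip installs mlx-lm only on Apple Silicon macOS; on every other platform the line is skipped and the ImportError guard in inference.py keeps the PyTorch path. A quick sketch of the values the marker checks (these are what an Apple Silicon Mac reports):

# The marker compares against the same values exposed by the standard library:
import platform, sys
platform.machine()   # "arm64"  on Apple Silicon -> platform_machine == "arm64"
sys.platform         # "darwin" on macOS         -> sys_platform == "darwin"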
