Fix bug mps #203

Merged: 5 commits, Jun 24, 2025
optillm/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
import os

# Version information
__version__ = "0.1.15"
__version__ = "0.1.16"

# Get the path to the root optillm.py
spec = util.spec_from_file_location(
optillm/inference.py (336 changes: 327 additions & 9 deletions)
@@ -16,6 +16,8 @@
import time
import threading
import traceback
import platform
import sys

from optillm.cot_decoding import cot_decode
from optillm.entropy_decoding import entropy_decode
@@ -26,6 +28,17 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# MLX Support for Apple Silicon
try:
    import mlx.core as mx
    from mlx_lm import load as mlx_load, generate as mlx_generate
    from mlx_lm.tokenizer_utils import TokenizerWrapper
    MLX_AVAILABLE = True
    logger.info("MLX framework available")
except ImportError:
    MLX_AVAILABLE = False
    logger.debug("MLX framework not available - falling back to PyTorch")

@dataclass
class ModelConfig:
    base_model_id: str
@@ -162,6 +175,302 @@ def calculate_logprobs(
        bytes_per_token=all_bytes
    )

# MLX Support Functions and Classes

def is_apple_silicon() -> bool:
    """Check if running on Apple Silicon"""
    return platform.system() == "Darwin" and platform.machine() == "arm64"

def should_use_mlx(model_id: str) -> bool:
    """Determine if a model should use MLX instead of PyTorch"""
    if not MLX_AVAILABLE or not is_apple_silicon():
        return False

    # Models that should use MLX
    mlx_patterns = [
        "mlx-community/",
        "mlx-"
    ]

    # Known problematic models that should prefer MLX on Apple Silicon
    problematic_models = [
        "Qwen/Qwen3-",
        "google/gemma-3-",
        "google/gemma3-"
    ]

    model_lower = model_id.lower()

    # Direct MLX model detection
    for pattern in mlx_patterns:
        if pattern.lower() in model_lower:
            return True

    # Problematic model detection
    for pattern in problematic_models:
        if pattern.lower() in model_lower:
            logger.warning(f"Model {model_id} detected as potentially problematic with MPS backend")
            suggested_mlx = suggest_mlx_alternative(model_id)
            logger.warning(f"Consider using MLX model: {suggested_mlx}")
            # Don't auto-switch, but recommend
            return False

    return False

def suggest_mlx_alternative(model_id: str) -> str:
    """Suggest MLX alternative for a given model"""
    mlx_alternatives = {
        # Qwen3 models
        "Qwen/Qwen3-0.6B": "mlx-community/Qwen3-0.6B-4bit",
        "Qwen/Qwen3-1.7B": "mlx-community/Qwen3-1.7B-4bit",
        "Qwen/Qwen3-4B": "mlx-community/Qwen3-4B-4bit",
        "Qwen/Qwen3-8B": "mlx-community/Qwen3-8B-4bit",
        "Qwen/Qwen3-14B": "mlx-community/Qwen3-14B-4bit",
        "Qwen/Qwen3-32B": "mlx-community/Qwen3-32B-4bit",

        # Gemma 3 models
        "google/gemma-3-1b-it": "mlx-community/gemma-3-1b-it-4bit",
        "google/gemma-3-4b-it": "mlx-community/gemma-3-4b-it-4bit",
        "google/gemma-3-12b-it": "mlx-community/gemma-3-12b-it-4bit",
        "google/gemma-3-27b-it": "mlx-community/gemma-3-27b-it-4bit",
    }

    return mlx_alternatives.get(model_id, f"mlx-community/{model_id.split('/')[-1]}-4bit")

@dataclass
class MLXModelConfig:
    """Configuration for MLX models"""
    model_id: str
    max_new_tokens: int = 4096
    temperature: float = 0.7
    top_p: float = 0.9
    repetition_penalty: float = 1.0
    enable_prompt_caching: bool = True

class MLXInferencePipeline:
    """MLX-based inference pipeline that mirrors PyTorch pipeline interface"""

    def __init__(self, model_config: MLXModelConfig, cache_manager):
        self.model_config = model_config
        self.cache_manager = cache_manager
        self.last_used = time.time()

        if not MLX_AVAILABLE:
            raise RuntimeError("MLX framework not available. Install with: pip install mlx-lm")

        if not is_apple_silicon():
            raise RuntimeError("MLX framework is only supported on Apple Silicon")

        try:
            logger.info(f"Loading MLX model: {model_config.model_id}")
            self.model, self.tokenizer = self._load_mlx_model(model_config.model_id)
            logger.info("MLX model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MLX model: {str(e)}")
            raise

    def _load_mlx_model(self, model_id: str):
        """Load MLX model and tokenizer with caching"""
        def _load_model():
            start_time = time.time()
            logger.info(f"Loading MLX model: {model_id}")

            try:
                model, tokenizer = mlx_load(model_id)
                load_time = time.time() - start_time
                logger.info(f"MLX model loaded in {load_time:.2f}s")
                return model, tokenizer
            except Exception as e:
                logger.error(f"Error loading MLX model {model_id}: {str(e)}")
                raise

        return self.cache_manager.get_or_load_model(f"mlx_{model_id}", _load_model)

    def generate(
        self,
        prompt: str,
        generation_params: Optional[Dict[str, Any]] = None
    ) -> Tuple[List[str], List[int], List[Optional[Dict]]]:
        """Generate text using MLX"""
        start_time = time.time()

        if generation_params is None:
            generation_params = {}

        # Extract parameters with defaults
        max_tokens = generation_params.get("max_new_tokens", self.model_config.max_new_tokens)
        temperature = generation_params.get("temperature", self.model_config.temperature)
        top_p = generation_params.get("top_p", self.model_config.top_p)
        repetition_penalty = generation_params.get("repetition_penalty", self.model_config.repetition_penalty)
        num_return_sequences = generation_params.get("num_return_sequences", 1)

        # Handle seed
        if generation_params.get("seed") is not None:
            mx.random.seed(generation_params["seed"])

        responses = []
        token_counts = []
        logprobs_results = []

        # Generate multiple sequences if requested
        for _ in range(num_return_sequences):
            try:
                logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")

                # Use robust MLX generation with multiple fallback approaches
                response = self._robust_mlx_generate(
                    prompt, max_tokens, temperature, top_p, repetition_penalty
                )

                responses.append(response)

                # Count tokens (approximate) - check if response is string
                if isinstance(response, str):
                    token_count = len(self.tokenizer.encode(response))
                else:
                    # Sometimes MLX returns just the new tokens, get the actual text
                    token_count = len(response) if hasattr(response, '__len__') else 0
                token_counts.append(token_count)

                # MLX doesn't provide logprobs by default
                logprobs_results.append(None)

            except Exception as e:
                logger.error(f"Error during MLX generation: {str(e)}")
                logger.error(f"MLX generation parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
                responses.append("")
                token_counts.append(0)
                logprobs_results.append(None)

        generation_time = time.time() - start_time
        logger.info(f"MLX generation completed in {generation_time:.2f}s")

        return responses, token_counts, logprobs_results

    def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
        """Robust MLX generation with multiple parameter combinations"""

        # Try different parameter combinations based on MLX-LM version
        parameter_combinations = [
            # Version 1: Current style with positional args and temp
            {
                "style": "positional_temp",
                "args": (self.model, self.tokenizer, prompt),
                "kwargs": {
                    "max_tokens": max_tokens,
                    "temp": temperature,
                    "top_p": top_p,
                    "repetition_penalty": repetition_penalty,
                    "verbose": False
                }
            },
            # Version 2: All keyword arguments with temp
            {
                "style": "keyword_temp",
                "args": (),
                "kwargs": {
                    "model": self.model,
                    "tokenizer": self.tokenizer,
                    "prompt": prompt,
                    "max_tokens": max_tokens,
                    "temp": temperature,
                    "top_p": top_p,
                    "repetition_penalty": repetition_penalty,
                    "verbose": False
                }
            },
            # Version 3: Using temperature instead of temp
            {
                "style": "positional_temperature",
                "args": (self.model, self.tokenizer, prompt),
                "kwargs": {
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "repetition_penalty": repetition_penalty,
                    "verbose": False
                }
            },
            # Version 4: Minimal parameters only
            {
                "style": "minimal",
                "args": (self.model, self.tokenizer, prompt),
                "kwargs": {
                    "max_tokens": max_tokens,
                    "temp": temperature,
                    "verbose": False
                }
            },
            # Version 5: Just essential parameters
            {
                "style": "essential",
                "args": (self.model, self.tokenizer, prompt),
                "kwargs": {
                    "max_tokens": max_tokens
                }
            }
        ]

        last_error = None

        for combo in parameter_combinations:
            try:
                logger.debug(f"Trying MLX generation with style: {combo['style']}")
                response = mlx_generate(*combo["args"], **combo["kwargs"])
                logger.debug(f"Successfully generated with style: {combo['style']}")
                return response

            except Exception as e:
                last_error = e
                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
                continue

        # If all combinations failed, raise the last error
        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")

    def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
        """Format the prompt according to model's chat template"""
        if hasattr(self.tokenizer, 'apply_chat_template'):
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            try:
                return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            except Exception as e:
                logger.warning(f"Failed to apply chat template: {e}, using fallback")
                return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
        else:
            return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"

class MLXManager:
    """Manager for MLX models and operations"""

    def __init__(self, cache_manager):
        self.cache_manager = cache_manager
        self.available = MLX_AVAILABLE and is_apple_silicon()

        if self.available:
            logger.info("MLX manager initialized - Apple Silicon detected")
        else:
            logger.debug("MLX manager not available - requires Apple Silicon and mlx-lm")

    def create_pipeline(self, model_id: str, **kwargs) -> MLXInferencePipeline:
        """Create an MLX inference pipeline"""
        if not self.available:
            raise RuntimeError("MLX not available on this platform")

        config = MLXModelConfig(
            model_id=model_id,
            **kwargs
        )

        return MLXInferencePipeline(config, self.cache_manager)

    def is_mlx_model(self, model_id: str) -> bool:
        """Check if model should use MLX"""
        return should_use_mlx(model_id)

class MemoryEfficientAttention(nn.Module):
"""
Memory-efficient attention using linear attention mechanism.
Expand Down Expand Up @@ -1286,18 +1595,27 @@ def __init__(self):
        self.device_manager = DeviceManager()
        self.model_manager = ModelManager(self.cache_manager, self.device_manager)
        self.lora_manager = LoRAManager(self.cache_manager)
        self.mlx_manager = MLXManager(self.cache_manager)
        self.chat = self.Chat(self)
        self.models = self.Models()

    def get_pipeline(self, model: str) -> 'InferencePipeline':
        model_config = parse_model_string(model)
        return InferencePipeline(
            model_config,
            self.cache_manager,
            self.device_manager,
            self.model_manager,
            self.lora_manager
        )
    def get_pipeline(self, model: str):
        """Get inference pipeline - automatically chooses MLX or PyTorch based on model"""
        # Check if should use MLX
        if self.mlx_manager.available and should_use_mlx(model):
            logger.info(f"Using MLX pipeline for model: {model}")
            return self.mlx_manager.create_pipeline(model)
        else:
            # Use existing PyTorch pipeline
            logger.info(f"Using PyTorch pipeline for model: {model}")
            model_config = parse_model_string(model)
            return InferencePipeline(
                model_config,
                self.cache_manager,
                self.device_manager,
                self.model_manager,
                self.lora_manager
            )

    class Chat:
        """OpenAI-compatible chat interface"""
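For reference, a minimal usage sketch of the routing logic added above. This is not part of the diff; it assumes a checkout where optillm.inference exposes the helpers shown in this file, and the third model ID is purely illustrative.

from optillm.inference import is_apple_silicon, should_use_mlx, suggest_mlx_alternative

for model_id in ["mlx-community/Qwen3-4B-4bit", "Qwen/Qwen3-4B", "meta-llama/Llama-3.2-1B-Instruct"]:
    # should_use_mlx() is True only for mlx-community/mlx-prefixed models running
    # on Apple Silicon with mlx-lm installed; everything else keeps the PyTorch path.
    backend = "MLX" if should_use_mlx(model_id) else "PyTorch"
    print(f"{model_id} -> {backend}")
    if backend == "PyTorch" and is_apple_silicon():
        # Mirrors the warning path above: suggest, but never force, an MLX alternative.
        print("  consider:", suggest_mlx_alternative(model_id))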
requirements.txt (4 changes: 3 additions & 1 deletion)
@@ -28,4 +28,6 @@ cerebras_cloud_sdk
outlines[transformers]
sentencepiece
adaptive-classifier
mcp
mcp
# MLX support for Apple Silicon optimization
mlx-lm>=0.24.0; platform_machine=="arm64" and sys_platform=="darwin"
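The platform gate on the new dependency is a standard PEP 508 environment marker, so pip installs mlx-lm only on Apple Silicon macOS and skips it everywhere else. A minimal sketch (assuming the packaging library is installed, as it is in most Python environments) shows how that marker evaluates on the current machine:

from packaging.markers import Marker

marker = Marker('platform_machine == "arm64" and sys_platform == "darwin"')
# True only on Apple Silicon macOS; elsewhere pip skips the mlx-lm requirement entirely.
print(marker.evaluate())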