diff --git a/optillm/__init__.py b/optillm/__init__.py
index c7e63bc..24870a6 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -2,7 +2,7 @@ import os
 
 # Version information
-__version__ = "0.1.19"
+__version__ = "0.1.20"
 
 # Get the path to the root optillm.py
 spec = util.spec_from_file_location(
diff --git a/optillm/bon.py b/optillm/bon.py
index 8ee752a..3da7d14 100644
--- a/optillm/bon.py
+++ b/optillm/bon.py
@@ -10,16 +10,45 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     completions = []
 
-    response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        max_tokens=4096,
-        n=n,
-        temperature=1
-    )
-    completions = [choice.message.content for choice in response.choices]
-    logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}")
-    bon_completion_tokens += response.usage.completion_tokens
+    try:
+        # Try to generate n completions in a single API call using n parameter
+        response = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            max_tokens=4096,
+            n=n,
+            temperature=1
+        )
+        completions = [choice.message.content for choice in response.choices]
+        logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
+        bon_completion_tokens += response.usage.completion_tokens
+
+    except Exception as e:
+        logger.warning(f"n parameter not supported by provider: {str(e)}")
+        logger.info(f"Falling back to generating {n} completions one by one")
+
+        # Fallback: Generate completions one by one in a loop
+        for i in range(n):
+            try:
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    max_tokens=4096,
+                    temperature=1
+                )
+                completions.append(response.choices[0].message.content)
+                bon_completion_tokens += response.usage.completion_tokens
+                logger.debug(f"Generated completion {i+1}/{n}")
+
+            except Exception as fallback_error:
+                logger.error(f"Error generating completion {i+1}: {str(fallback_error)}")
+                continue
+
+        if not completions:
+            logger.error("Failed to generate any completions")
+            return "Error: Could not generate any completions", 0
+
+        logger.info(f"Generated {len(completions)} completions using fallback method. 
Total tokens used: {bon_completion_tokens}") # Rate the completions rating_messages = messages.copy() diff --git a/optillm/inference.py b/optillm/inference.py index 206b357..94a003d 100644 --- a/optillm/inference.py +++ b/optillm/inference.py @@ -22,6 +22,7 @@ from optillm.cot_decoding import cot_decode from optillm.entropy_decoding import entropy_decode from optillm.thinkdeeper import thinkdeeper_decode +from optillm.thinkdeeper_mlx import thinkdeeper_decode_mlx from optillm.autothink import autothink_decode # Configure logging @@ -33,6 +34,7 @@ import mlx.core as mx from mlx_lm import load as mlx_load, generate as mlx_generate from mlx_lm.tokenizer_utils import TokenizerWrapper + from mlx_lm.sample_utils import make_sampler MLX_AVAILABLE = True logger.info("MLX framework available") except ImportError: @@ -349,85 +351,46 @@ def generate( return responses, token_counts, logprobs_results def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str: - """Robust MLX generation with multiple parameter combinations""" - - # Try different parameter combinations based on MLX-LM version - parameter_combinations = [ - # Version 1: Current style with positional args and temp - { - "style": "positional_temp", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens, - "temp": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "verbose": False - } - }, - # Version 2: All keyword arguments with temp - { - "style": "keyword_temp", - "args": (), - "kwargs": { - "model": self.model, - "tokenizer": self.tokenizer, - "prompt": prompt, - "max_tokens": max_tokens, - "temp": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "verbose": False - } - }, - # Version 3: Using temperature instead of temp - { - "style": "positional_temperature", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "verbose": False - } - }, - # Version 4: Minimal parameters only - { - "style": "minimal", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens, - "temp": temperature, - "verbose": False - } - }, - # Version 5: Just essential parameters - { - "style": "essential", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens - } - } - ] + """Robust MLX generation using sampler approach""" - last_error = None - - for combo in parameter_combinations: + try: + # Create sampler with generation parameters + sampler = make_sampler( + temp=temperature, + top_p=top_p, + min_p=0.0, # Default min_p + min_tokens_to_keep=1 # Default min_tokens_to_keep + ) + + # Generate using the sampler + response = mlx_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=max_tokens, + sampler=sampler, + verbose=False + ) + + return response + + except Exception as e: + logger.error(f"MLX generation with sampler failed: {str(e)}") + + # Fallback: Try minimal parameters without sampler try: - logger.debug(f"Trying MLX generation with style: {combo['style']}") - response = mlx_generate(*combo["args"], **combo["kwargs"]) - logger.debug(f"Successfully generated with style: {combo['style']}") + logger.debug("Attempting MLX generation without sampler") + response = mlx_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=max_tokens, + verbose=False + ) return response - - except Exception as 
e: - last_error = e - logger.debug(f"Failed with style {combo['style']}: {str(e)}") - continue - - # If all combinations failed, raise the last error - raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}") + except Exception as fallback_e: + logger.error(f"MLX fallback generation also failed: {str(fallback_e)}") + raise def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str: """Format the prompt according to model's chat template""" @@ -1691,37 +1654,47 @@ def create( if decoding: logger.info(f"Using specialized decoding approach: {decoding}") - # Ensure model is in eval mode and on correct device - pipeline.current_model.eval() - device = pipeline.current_model.device + # Check if this decoding approach is supported for MLX + mlx_unsupported_decodings = ["cot_decoding", "entropy_decoding", "autothink"] + if isinstance(pipeline, MLXInferencePipeline) and decoding in mlx_unsupported_decodings: + logger.warning(f"{decoding} is not supported for MLX models. Falling back to standard generation.") + decoding = None + + if decoding: + # For PyTorch pipelines, ensure model is in eval mode and get device + # MLX pipelines handle this differently + if not isinstance(pipeline, MLXInferencePipeline): + pipeline.current_model.eval() + device = pipeline.current_model.device + else: + device = None # MLX doesn't use torch devices if decoding == "cot_decoding": # Use directly available parameters for CoT - cot_params = { - "k": k, - "num_beams": num_beams, - "max_new_tokens": max_tokens if max_tokens is not None else 512, - "temperature": temperature, - "top_p": top_p, - "repetition_penalty": 1.0, - "length_penalty": length_penalty, - "no_repeat_ngram_size": no_repeat_ngram_size, - "early_stopping": early_stopping, - "aggregate_paths": aggregate_paths, - } - - result, confidence = cot_decode( - pipeline.current_model, - pipeline.tokenizer, - messages, - **cot_params - ) - responses = [result] - logprobs_results = [{"confidence_score": confidence} if confidence is not None else None] - completion_tokens = len(pipeline.tokenizer.encode(result)) + cot_params = { + "k": k, + "num_beams": num_beams, + "max_new_tokens": max_tokens if max_tokens is not None else 512, + "temperature": temperature, + "top_p": top_p, + "repetition_penalty": 1.0, + "length_penalty": length_penalty, + "no_repeat_ngram_size": no_repeat_ngram_size, + "early_stopping": early_stopping, + "aggregate_paths": aggregate_paths, + } + + result, confidence = cot_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + **cot_params + ) + responses = [result] + logprobs_results = [{"confidence_score": confidence} if confidence is not None else None] + completion_tokens = len(pipeline.tokenizer.encode(result)) elif decoding == "entropy_decoding": - # Ensure model is using full precision original_dtype = pipeline.current_model.dtype pipeline.current_model = pipeline.current_model.to(torch.float32) @@ -1778,43 +1751,66 @@ def create( } thinkdeeper_config.update(custom_config) - result = thinkdeeper_decode( - pipeline.current_model, - pipeline.tokenizer, - messages, - thinkdeeper_config + # Check if we're using MLX pipeline + if isinstance(pipeline, MLXInferencePipeline): + logger.info("Using MLX ThinkDeeper implementation") + + # Ensure we have enough tokens for thinking + response + user_max_tokens = max_tokens if max_tokens is not None else 512 + total_tokens_needed = max_thinking_tokens + 512 # thinking + response buffer + adjusted_max_tokens = max(user_max_tokens, 
total_tokens_needed) + + # Add max_tokens to thinkdeeper config + thinkdeeper_config_with_tokens = thinkdeeper_config.copy() + thinkdeeper_config_with_tokens["max_tokens"] = adjusted_max_tokens + + logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}") + + result = thinkdeeper_decode_mlx( + pipeline.model, + pipeline.tokenizer, + messages, + thinkdeeper_config_with_tokens + ) + else: + logger.info("Using PyTorch ThinkDeeper implementation") + result = thinkdeeper_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + thinkdeeper_config ) responses = [result] logprobs_results = [None] completion_tokens = len(pipeline.tokenizer.encode(result)) elif decoding == "autothink": # Get steering dataset configuration - steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") - target_layer = kwargs.get("target_layer", 19) - - # Prepare AutoThink configuration - autothink_config = { - "steering_dataset": steering_dataset, - "target_layer": target_layer, - "pattern_strengths": kwargs.get("pattern_strengths", { - "depth_and_thoroughness": 2.5, - "numerical_accuracy": 2.0, - "self_correction": 3.0, - "exploration": 2.0, - "organization": 1.5 - }) - } - - # Process with AutoThink - result = autothink_decode( - pipeline.current_model, - pipeline.tokenizer, - messages, - autothink_config - ) - responses = [result] - logprobs_results = [None] - completion_tokens = len(pipeline.tokenizer.encode(result)) + steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") + target_layer = kwargs.get("target_layer", 19) + + # Prepare AutoThink configuration + autothink_config = { + "steering_dataset": steering_dataset, + "target_layer": target_layer, + "pattern_strengths": kwargs.get("pattern_strengths", { + "depth_and_thoroughness": 2.5, + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5 + }) + } + + # Process with AutoThink + result = autothink_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + autothink_config + ) + responses = [result] + logprobs_results = [None] + completion_tokens = len(pipeline.tokenizer.encode(result)) else: raise ValueError(f"Unknown specialized decoding approach: {decoding}") diff --git a/optillm/moa.py b/optillm/moa.py index 5306e25..21d5e10 100644 --- a/optillm/moa.py +++ b/optillm/moa.py @@ -8,19 +8,61 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str completions = [] logger.debug(f"Generating initial completions for query: {initial_query}") - response = client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": initial_query} - ], - max_tokens=4096, - n=3, - temperature=1 - ) - completions = [choice.message.content for choice in response.choices] - moa_completion_tokens += response.usage.completion_tokens - logger.info(f"Generated {len(completions)} initial completions. 
Tokens used: {response.usage.completion_tokens}") + + try: + # Try to generate 3 completions in a single API call using n parameter + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + max_tokens=4096, + n=3, + temperature=1 + ) + completions = [choice.message.content for choice in response.choices] + moa_completion_tokens += response.usage.completion_tokens + logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}") + + except Exception as e: + logger.warning(f"n parameter not supported by provider: {str(e)}") + logger.info("Falling back to generating 3 completions one by one") + + # Fallback: Generate 3 completions one by one in a loop + completions = [] + for i in range(3): + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + max_tokens=4096, + temperature=1 + ) + completions.append(response.choices[0].message.content) + moa_completion_tokens += response.usage.completion_tokens + logger.debug(f"Generated completion {i+1}/3") + + except Exception as fallback_error: + logger.error(f"Error generating completion {i+1}: {str(fallback_error)}") + continue + + if not completions: + logger.error("Failed to generate any completions") + return "Error: Could not generate any completions", 0 + + logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {moa_completion_tokens}") + + # Handle case where fewer than 3 completions were generated + if len(completions) < 3: + original_count = len(completions) + # Pad with the first completion to ensure we have 3 + while len(completions) < 3: + completions.append(completions[0]) + logger.warning(f"Only generated {original_count} unique completions, padded to 3 for critique") logger.debug("Preparing critique prompt") critique_prompt = f""" diff --git a/optillm/plugins/majority_voting_plugin.py b/optillm/plugins/majority_voting_plugin.py new file mode 100644 index 0000000..311072b --- /dev/null +++ b/optillm/plugins/majority_voting_plugin.py @@ -0,0 +1,293 @@ +""" +Majority Voting Plugin for OptILLM + +This plugin implements a majority voting approach where k candidate solutions +are generated and the most frequent answer is selected. This is particularly +effective for problems with discrete answers (math, coding, multiple choice). + +The plugin uses the OpenAI API's n parameter to generate multiple responses +efficiently in a single API call. +""" + +import re +import logging +from typing import Tuple, Dict, Any, List, Optional +from collections import Counter +import json + +logger = logging.getLogger(__name__) + +# Plugin identifier +SLUG = "majority_voting" + +# Default number of candidates to generate +DEFAULT_K = 6 + +# Default temperature for candidate generation +DEFAULT_TEMPERATURE = 0.6 + +def extract_answer(text: str) -> Optional[str]: + """ + Extract the answer from a response text. + + This function looks for common answer patterns in the response: + 1. Text after "Answer:" or "Final Answer:" + 2. Text within \\boxed{} (LaTeX format) + 3. Numbers at the end of the response + 4. 
The last line if it's short (likely the answer) + + Args: + text: The response text to extract answer from + + Returns: + The extracted answer or None if no clear answer found + """ + # Remove any trailing whitespace + text = text.strip() + + # Pattern 1: Look for LaTeX boxed format first (handle both \boxed and \\boxed) + boxed_match = re.search(r'\\{1,2}boxed\{([^}]+)\}', text) + if boxed_match: + answer = boxed_match.group(1).strip() + logger.debug(f"Extracted boxed answer: {answer}") + return answer + + # Pattern 2: Look for "Answer:" or "Final Answer:" patterns + answer_patterns = [ + r'(?:final\s+)?answer\s*[:=]\s*(.+?)(?:\n|$)', + r'(?:the\s+)?(?:final\s+)?answer\s+is\s*[:=]?\s*(.+?)(?:\n|$)', + r'(?:therefore|thus|so)\s*,?\s*(.+?)(?:\n|$)' + ] + + for pattern in answer_patterns: + match = re.search(pattern, text, re.IGNORECASE) + if match: + answer = match.group(1).strip() + # Clean up the answer + answer = answer.rstrip('.,;') + if answer: + logger.debug(f"Extracted answer using pattern: {answer}") + return answer + + # Pattern 3: Look for standalone numbers (useful for math problems) + # Check the last few lines for a number + lines = text.split('\n') + for line in reversed(lines[-3:]): # Check last 3 lines + line = line.strip() + # Match numbers (including decimals, fractions, negative numbers) + number_match = re.match(r'^-?\d+\.?\d*$|^-?\d+/\d+$', line) + if number_match: + logger.debug(f"Extracted number answer: {line}") + return line + + # Pattern 4: For multiple choice, look for single letter answers + # Check this before the generic last line check + mc_patterns = [ + r'(?:the\s+)?(?:correct\s+)?(?:answer|option)\s+is\s+([A-E])(?:\b|$)', + r'(?:choose|select|pick)\s+(?:option\s+)?([A-E])(?:\b|$)', + r'\b([A-E])\s*\)\s*[A-Za-z]+.*is\s+(?:the\s+)?(?:correct|right)', + r'^([A-E])$', # Just a letter on its own line + ] + + for pattern in mc_patterns: + mc_match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE) + if mc_match: + answer = mc_match.group(1).upper() + logger.debug(f"Extracted multiple choice answer: {answer}") + return answer + + # Pattern 5: If the last line is short (< 50 chars), it might be the answer + if lines: + last_line = lines[-1].strip() + if last_line and len(last_line) < 50 and not last_line.endswith(':'): + logger.debug(f"Using last line as answer: {last_line}") + return last_line + + logger.warning("Could not extract a clear answer from the response") + return None + +def normalize_answer(answer: str) -> str: + """ + Normalize an answer for comparison. + + This helps ensure that equivalent answers are treated as the same: + - Converts to lowercase + - Removes extra whitespace + - Removes quotes + - Normalizes number formats + + Args: + answer: The answer to normalize + + Returns: + The normalized answer + """ + # Convert to lowercase + answer = answer.lower().strip() + + # Remove quotes + answer = answer.strip('"\'') + + # Normalize whitespace + answer = ' '.join(answer.split()) + + # Try to normalize numbers + try: + # Check if it's a float + if '.' 
in answer: + num = float(answer) + # Format to remove trailing zeros + answer = f"{num:g}" + else: + # Try integer + num = int(answer) + answer = str(num) + except ValueError: + # Not a number, keep as is + pass + + # Handle yes/no variations + if answer in ['yes', 'yeah', 'yep', 'true', 'correct']: + answer = 'yes' + elif answer in ['no', 'nope', 'false', 'incorrect']: + answer = 'no' + + return answer + +def run( + system_prompt: str, + initial_query: str, + client, + model: str, + request_config: Dict[str, Any] = None +) -> Tuple[str, int]: + """ + Main entry point for the majority voting plugin. + + Generates k candidate solutions and returns the most frequent answer. + + Args: + system_prompt: System prompt for the model + initial_query: User's query + client: OpenAI-compatible client instance + model: Model identifier + request_config: Additional configuration parameters + + Returns: + Tuple of (response_text, completion_tokens_used) + """ + logger.info("Starting majority voting process") + + # Extract parameters from request_config + k = DEFAULT_K + temperature = DEFAULT_TEMPERATURE + + if request_config: + k = request_config.get('k', DEFAULT_K) + # Allow overriding temperature if needed + temperature = request_config.get('temperature', DEFAULT_TEMPERATURE) + # Respect max_tokens if provided + max_tokens = request_config.get('max_tokens', 4096) + else: + max_tokens = 4096 + + logger.info(f"Generating {k} candidates with temperature={temperature}") + + # Prepare messages + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ] + + try: + # Generate k candidates in a single API call using n parameter + response = client.chat.completions.create( + model=model, + messages=messages, + n=k, + temperature=temperature, + max_tokens=max_tokens + ) + + # Extract all candidate responses + candidates = [choice.message.content for choice in response.choices] + total_tokens = response.usage.completion_tokens + + logger.info(f"Generated {len(candidates)} candidates using n parameter. Tokens used: {total_tokens}") + + except Exception as e: + logger.warning(f"n parameter not supported by provider: {str(e)}") + logger.info(f"Falling back to generating {k} candidates one by one") + + # Fallback: Generate candidates one by one in a loop + candidates = [] + total_tokens = 0 + + for i in range(k): + try: + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens + ) + candidates.append(response.choices[0].message.content) + total_tokens += response.usage.completion_tokens + logger.debug(f"Generated candidate {i+1}/{k}") + + except Exception as fallback_error: + logger.error(f"Error generating candidate {i+1}: {str(fallback_error)}") + continue + + if not candidates: + logger.error("Failed to generate any candidates") + return "Error: Could not generate any candidates", 0 + + logger.info(f"Generated {len(candidates)} candidates using fallback method. 
Total tokens used: {total_tokens}") + + # Extract answers from each candidate + answers = [] + answer_to_response = {} # Map normalized answers to full responses + + for i, candidate in enumerate(candidates): + answer = extract_answer(candidate) + if answer: + normalized = normalize_answer(answer) + answers.append(normalized) + # Keep the first full response for each unique answer + if normalized not in answer_to_response: + answer_to_response[normalized] = candidate + logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})") + else: + logger.warning(f"Could not extract answer from candidate {i+1}") + + if not answers: + logger.warning("No answers could be extracted from any candidate") + # Return the first candidate as fallback + return candidates[0] if candidates else "Error: No candidates generated", total_tokens + + # Count answer frequencies + answer_counts = Counter(answers) + logger.info(f"Answer distribution: {dict(answer_counts)}") + + # Get the most common answer + most_common_answer, count = answer_counts.most_common(1)[0] + confidence = count / len(answers) + + logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)") + + # Get the full response corresponding to the most common answer + winning_response = answer_to_response.get(most_common_answer, candidates[0]) + + # Log voting summary to console instead of adding to response + logger.info("Majority Voting Summary:") + logger.info(f" - Generated {len(candidates)} candidates") + logger.info(f" - Most common answer: {most_common_answer}") + logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)") + + if len(answer_counts) > 1: + other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer] + logger.info(f" - Other answers: {', '.join(other_answers)}") + + # Return only the full response from the winning answer + return winning_response, total_tokens \ No newline at end of file diff --git a/optillm/thinkdeeper_mlx.py b/optillm/thinkdeeper_mlx.py new file mode 100644 index 0000000..043e287 --- /dev/null +++ b/optillm/thinkdeeper_mlx.py @@ -0,0 +1,327 @@ +""" +MLX-compatible implementation of ThinkDeeper +Provides the same functionality as the PyTorch version but adapted for MLX framework +""" + +import random +from typing import Tuple, Dict, Any, List +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +try: + import mlx.core as mx + from mlx_lm import generate as mlx_generate + from mlx_lm.sample_utils import make_sampler + MLX_AVAILABLE = True +except ImportError: + MLX_AVAILABLE = False + +DEFAULT_CONFIG = { + "min_thinking_tokens": 1024, + "max_thinking_tokens": 4196, + "max_thoughts": 64, + "prefill": "", + "start_think_token": "", + "end_think_token": "", + "thought_switch_tokens": [ + "Wait,", + "Alternatively,", + ], +} + +class MLXThinkDeeperProcessor: + def __init__(self, config: Dict[str, Any], tokenizer, model): + self.config = {**DEFAULT_CONFIG, **config} + self.tokenizer = tokenizer + self.model = model + + # Get token IDs for think markers + start_tokens = self.tokenizer.encode(self.config['start_think_token']) + end_tokens = self.tokenizer.encode(self.config['end_think_token']) + self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1] + self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1] + + # Store thought switch markers as token sequences + 
self.thought_switch_sequences = [] + for phrase in self.config["thought_switch_tokens"]: + # Encode without adding special tokens to get exact sequence + token_ids = self.tokenizer.encode(phrase, add_special_tokens=False) + self.thought_switch_sequences.append(token_ids) + + # Track thought switches + self.thought_count = 0 + self.current_sequence = [] # Track recent tokens for sequence matching + self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences) if self.thought_switch_sequences else 5 + + # Track total tokens for budget management + self.total_tokens_generated = 0 + self.max_total_tokens = config.get('max_tokens', 8192) # Default to 8192 if not specified + + def is_thought_switch(self, token: int) -> bool: + """Check if adding this token creates a thought switch sequence.""" + # Add new token to current sequence + self.current_sequence.append(token) + + # Keep only the most recent tokens that could match our sequences + if len(self.current_sequence) > self.max_sequence_length: + self.current_sequence = self.current_sequence[-self.max_sequence_length:] + + # Check if current sequence ends with any thought switch sequence + for sequence in self.thought_switch_sequences: + if len(sequence) <= len(self.current_sequence) and \ + self.current_sequence[-len(sequence):] == sequence: + return True + + return False + + def reasoning_effort(self, messages) -> str: + """Generate response with ThinkDeeper's controlled thinking process using MLX""" + + # Prepare the messages with thinking token + thinking_messages = messages.copy() + thinking_messages.append({ + "role": "assistant", + "content": f"{self.config['start_think_token']}\n{self.config['prefill']}" + }) + + # Convert messages to prompt using tokenizer + if hasattr(self.tokenizer, 'apply_chat_template'): + prompt = self.tokenizer.apply_chat_template( + thinking_messages, + continue_final_message=False, # This was causing MLX failures! 
+ tokenize=False, + add_generation_prompt=True # Standard generation prompt + ) + else: + # Fallback: simple concatenation + prompt = "" + for msg in thinking_messages: + prompt += f"{msg['role']}: {msg['content']}\n" + + + # Initialize tracking variables + n_thinking_tokens = 0 + seen_end_think = False + response_chunks = [] + + # Use MLX generation with custom token-by-token control + # Since MLX doesn't support token-by-token generation like PyTorch, + # we'll use a different approach: generate in chunks and check for markers + + current_prompt = prompt + max_chunk_size = 150 # Increase chunk size - MLX may work better with larger chunks + consecutive_empty_chunks = 0 + max_empty_chunks = 3 # Allow up to 3 consecutive empty chunks before stopping + + while (n_thinking_tokens < self.config["max_thinking_tokens"] and + self.thought_count < self.config["max_thoughts"] and + self.total_tokens_generated < self.max_total_tokens - 512): # Reserve 512 tokens for final response + try: + # Generate a small chunk of tokens + chunk_response = self._generate_chunk( + current_prompt, + max_tokens=min(max_chunk_size, self.config["max_thinking_tokens"] - n_thinking_tokens), + temperature=0.6 + ) + + if not chunk_response or chunk_response.strip() == "": + consecutive_empty_chunks += 1 + + if consecutive_empty_chunks >= max_empty_chunks: + break + + # Try with different parameters for next attempt + max_chunk_size = min(max_chunk_size + 50, 300) # Increase chunk size more aggressively + continue + else: + # Reset empty chunk counter on successful generation + consecutive_empty_chunks = 0 + max_chunk_size = 150 # Reset chunk size + + # Update token counts + chunk_tokens = len(self.tokenizer.encode(chunk_response)) + self.total_tokens_generated += chunk_tokens + + # Check for end think token in the chunk + if self.config['end_think_token'] in chunk_response: + # Split at the end think token + parts = chunk_response.split(self.config['end_think_token'], 1) + before_end = parts[0] + after_end = parts[1] if len(parts) > 1 else "" + + response_chunks.append(before_end) + n_thinking_tokens += len(self.tokenizer.encode(before_end)) + + # Check if we've reached minimum thinking tokens + if n_thinking_tokens < self.config["min_thinking_tokens"]: + # Insert thought transition instead of ending + transition = random.choice(self.config["thought_switch_tokens"]) + response_chunks.append(transition) + current_prompt += before_end + transition + n_thinking_tokens += len(self.tokenizer.encode(transition)) + self.thought_count += 1 + continue + else: + # Natural end - add the end token and continue for conclusion + response_chunks.append(self.config['end_think_token']) + current_prompt += before_end + self.config['end_think_token'] + seen_end_think = True + + # Generate conclusion after thinking + if after_end.strip(): + response_chunks.append(after_end) + else: + conclusion = self._generate_chunk(current_prompt, max_tokens=200, temperature=0.3) + if conclusion: + response_chunks.append(conclusion) + break + else: + # No end think token found, add the chunk and continue + response_chunks.append(chunk_response) + current_prompt += chunk_response + n_thinking_tokens += len(self.tokenizer.encode(chunk_response)) + + # Check for thought switch patterns in the chunk + for phrase in self.config["thought_switch_tokens"]: + if phrase in chunk_response: + self.thought_count += 1 + break + + # Safety check to avoid infinite loops + if len(response_chunks) > 100: + logger.warning("Too many chunks generated, stopping to avoid 
infinite loop") + break + + except Exception as e: + logger.error(f"Error during MLX chunk generation: {str(e)}") + break + + # Enforce minimum thinking tokens if not reached + if not seen_end_think and n_thinking_tokens < self.config["min_thinking_tokens"]: + while n_thinking_tokens < self.config["min_thinking_tokens"] and self.thought_count < self.config["max_thoughts"]: + # Add transition and continue thinking + transition = random.choice(self.config["thought_switch_tokens"]) + response_chunks.append(f" {transition} ") + current_prompt += f" {transition} " + + # Generate more thinking content + additional_thinking = self._generate_chunk( + current_prompt, + max_tokens=min(200, self.config["min_thinking_tokens"] - n_thinking_tokens + 100), + temperature=0.6 + ) + + if additional_thinking and additional_thinking.strip(): + response_chunks.append(additional_thinking) + current_prompt += additional_thinking + additional_tokens = len(self.tokenizer.encode(additional_thinking)) + n_thinking_tokens += additional_tokens + self.thought_count += 1 + else: + # If generation fails, break to avoid infinite loop + break + + # If we haven't seen end think token, force it + if not seen_end_think: + response_chunks.append(self.config['end_think_token']) + + # Add a brief conclusion + try: + conclusion = self._generate_chunk( + current_prompt + self.config['end_think_token'], + max_tokens=100, + temperature=0.3 + ) + if conclusion: + response_chunks.append(conclusion) + except Exception as e: + logger.error(f"Error generating conclusion: {str(e)}") + + # Join all chunks and create final response + response_content = "".join(response_chunks) + full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response_content}" + + return full_response + + def _generate_chunk(self, prompt: str, max_tokens: int, temperature: float) -> str: + """Generate a small chunk of text using MLX with proper sampler""" + try: + # Let MLX fail naturally to identify the real issue + + # Create sampler with specified thinkdeeper parameters + sampler = make_sampler( + temp=temperature, + top_p=0.95, + top_k=20, + min_p=0.0, + min_tokens_to_keep=3 + ) + + # Use mlx_generate with the sampler + # Ensure we have minimum tokens to generate - larger minimum for better MLX performance + actual_max_tokens = max(max_tokens, 30) # At least 30 tokens for better generation + + response = mlx_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=actual_max_tokens, + sampler=sampler, + verbose=False + ) + + # MLX generate might return just the generated tokens or the full text + # Check if response starts with the prompt + if response: + if response.startswith(prompt): + # Response includes the prompt, extract new content + new_content = response[len(prompt):] + else: + # Response is just the generated tokens + new_content = response + + if new_content.strip(): # Only return non-empty content + return new_content + + return "" + + except Exception as e: + logger.error(f"Error in MLX chunk generation: {str(e)}") + return "" + +def thinkdeeper_decode_mlx( + model, + tokenizer, + messages: List[Dict[str, str]], + request_config: Dict[str, Any] = None +) -> str: + """MLX-compatible ThinkDeeper processing function""" + logger.info("Starting MLX ThinkDeeper processing") + + if not MLX_AVAILABLE: + raise RuntimeError("MLX framework not available for ThinkDeeper processing") + + # Extract config from request_config if provided + config = DEFAULT_CONFIG.copy() + if request_config: + # Update only valid keys from 
DEFAULT_CONFIG + for key in DEFAULT_CONFIG: + if key in request_config: + config[key] = request_config[key] + + # Also handle max_tokens which is not in DEFAULT_CONFIG + if 'max_tokens' in request_config: + config['max_tokens'] = request_config['max_tokens'] + + logger.info(f"MLX ThinkDeeper using config: {config}") + + try: + processor = MLXThinkDeeperProcessor(config, tokenizer, model) + response = processor.reasoning_effort(messages) + return response + + except Exception as e: + logger.error(f"Error in MLX ThinkDeeper processing: {str(e)}") + raise \ No newline at end of file diff --git a/scripts/eval_aime_benchmark.py b/scripts/eval_aime_benchmark.py index a854f83..ac61c35 100644 --- a/scripts/eval_aime_benchmark.py +++ b/scripts/eval_aime_benchmark.py @@ -256,7 +256,7 @@ def analyze_logits_probs(logprobs_data: List[Dict]) -> Dict: "token_count": len(token_entropies) } -def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> Union[str, List[Dict]]: +def get_llm_response(problem: str, model: str, analyze_logits: bool = False, extra_body: dict = None) -> Union[str, List[Dict]]: """ Get response from the LLM for a given problem. If multiple choices are returned, formats them as attempt dictionaries. @@ -276,18 +276,16 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> kwargs["logprobs"] = True kwargs["top_logprobs"] = 3 + # Add extra_body if provided + if extra_body: + kwargs["extra_body"] = extra_body + response = client.with_options(timeout=1000.0).chat.completions.create( model=model, messages=[ {"role": "user", "content": SYSTEM_PROMPT + problem} ], max_tokens=8192, - # extra_body={ - # "decoding": "thinkdeeper", - # "min_thinking_tokens" : 0, - # "max_thinking_tokens" : 8000, - # "max_thoughts": 100, - # }, **kwargs ) @@ -333,7 +331,7 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> logger.error(f"Error getting LLM response: {e}") return "" -def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = False, analyze_logits: bool = False) -> List[Dict]: +def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = False, analyze_logits: bool = False, extra_body: dict = None) -> List[Dict]: """ Make n attempts to solve a problem and return all responses and predictions. 
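For context, here is a minimal sketch of how the new extra_body plumbing in get_llm_response() might be exercised end to end. It is an illustration only: the base URL, API key, model id, and prompt are assumptions, not taken from the repository; the extra_body keys mirror the thinkdeeper entries defined in TEST_TIME_COMPUTE_APPROACHES later in this patch.

from openai import OpenAI

# Hypothetical client pointed at a locally running optillm proxy (URL/key are assumptions).
client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")

# get_llm_response() now forwards this dict verbatim via kwargs["extra_body"].
extra_body = {
    "decoding": "thinkdeeper",
    "min_thinking_tokens": 2048,
    "max_thinking_tokens": 2560,
    "max_tokens": 3072,
}

response = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",  # placeholder model id
    messages=[{"role": "user", "content": "Solve: 2 + 2"}],
    max_tokens=8192,
    extra_body=extra_body,
)
print(response.choices[0].message.content)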
@@ -351,7 +349,7 @@ def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = F remaining_attempts = n while remaining_attempts > 0: - response = get_llm_response(problem, model, analyze_logits) + response = get_llm_response(problem, model, analyze_logits, extra_body) # If response is already formatted as attempts if isinstance(response, list): @@ -774,7 +772,7 @@ def save_raw_response(filename: str, problem_id: int, response_data: Dict): return response_id -def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_logits: bool = False): +def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_logits: bool = False, test_time_compute: bool = False, approach_name: str = None, extra_body: dict = None): """Main evaluation function that handles gaps in processed indexes.""" os.makedirs("results", exist_ok=True) @@ -784,6 +782,8 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo suffix_parts.append("thought_analysis") if analyze_logits: suffix_parts.append("logit_analysis") + if approach_name: + suffix_parts.append(approach_name) suffix = "_" + "_".join(suffix_parts) if suffix_parts else "" results_file = f"results/evaluation_results_{model.replace('/', '_')}_pass_at_{n_attempts}{suffix}.json" @@ -804,7 +804,7 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo correct_answer = int(item['answer']) # Make n attempts for each problem - attempts = make_n_attempts(problem_text, model, n_attempts, analyze_thoughts, analyze_logits) + attempts = make_n_attempts(problem_text, model, n_attempts, analyze_thoughts, analyze_logits, extra_body) is_correct, first_correct = evaluate_pass_at_n(attempts, correct_answer) result = { @@ -826,6 +826,51 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo parser.add_argument("--n", type=int, default=1, help="Number of attempts per problem (for pass@n evaluation)") parser.add_argument("--analyze-thoughts", action="store_true", help="Analyze thinking patterns in responses") parser.add_argument("--analyze-logits", action="store_true", help="Analyze token probability distributions") + parser.add_argument("--test-time-compute", action="store_true", help="Evaluate test-time compute scaling approaches") args = parser.parse_args() - main(args.model, args.n, args.analyze_thoughts, args.analyze_logits) \ No newline at end of file + if args.test_time_compute: + # Define test-time compute approaches with same config as eval_optillmbench.py + TEST_TIME_COMPUTE_APPROACHES = [ + # Baseline + ("none", "Baseline without any optimization", {}), + + # Sequential test-time compute using thinkdeeper with controlled thinking budgets + ("thinkdeeper_2k", "ThinkDeeper with 2K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 2048, + "max_thinking_tokens": 2560, # min + 512 for flexibility + "max_tokens": 3072 # Total budget: max_thinking_tokens + 512 + }), + ("thinkdeeper_4k", "ThinkDeeper with 4K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 4096, + "max_thinking_tokens": 4608, # min + 512 for flexibility + "max_tokens": 5120 # Total budget: max_thinking_tokens + 512 + }), + ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 8192, + "max_thinking_tokens": 8704, # min + 512 for flexibility + "max_tokens": 9216 # Total budget: max_thinking_tokens + 512 + }), + + # Parallel test-time compute using majority voting with 
different k values + ("majority_voting_3", "Majority Voting with k=3", {"k": 3}), + ("majority_voting_6", "Majority Voting with k=6", {"k": 6}), + ("majority_voting_9", "Majority Voting with k=9", {"k": 9}), + ] + + # Run evaluation for each approach + for approach_slug, approach_name, extra_body in TEST_TIME_COMPUTE_APPROACHES: + print(f"\n{'=' * 80}") + print(f"Evaluating: {approach_name}") + print(f"Model: {args.model}") + print(f"Approach: {approach_slug}") + print(f"Extra body: {extra_body}") + print(f"{'=' * 80}\n") + + main(args.model, args.n, args.analyze_thoughts, args.analyze_logits, + test_time_compute=True, approach_name=approach_slug, extra_body=extra_body) + else: + main(args.model, args.n, args.analyze_thoughts, args.analyze_logits) \ No newline at end of file diff --git a/scripts/eval_optillmbench.py b/scripts/eval_optillmbench.py index ed59d35..58eac41 100644 --- a/scripts/eval_optillmbench.py +++ b/scripts/eval_optillmbench.py @@ -21,19 +21,50 @@ logger = logging.getLogger(__name__) # Define the approaches to test -# Each approach is (name, description) +# Each approach is (name, description, extra_body_params) APPROACHES = [ - ("none", "Baseline without any optimization"), - ("leap", "LEAP Approach"), - ("rto", "Round Trip Optimization"), - ("cot_reflection", "Chain of Thought with Reflection"), - ("self_consistency", "Self Consistency Check"), - ("plansearch", "Planning with Search"), - ("re2", "ReRead Approach"), - ("z3", "Z3 Solver for Mathematical Problems"), - ("coc", "Chain of Code"), - ("executecode" , "Execute Code"), - ("spl", "System Prompt Learning") + ("none", "Baseline without any optimization", {}), + ("leap", "LEAP Approach", {}), + ("rto", "Round Trip Optimization", {}), + ("cot_reflection", "Chain of Thought with Reflection", {}), + ("self_consistency", "Self Consistency Check", {}), + ("plansearch", "Planning with Search", {}), + ("re2", "ReRead Approach", {}), + ("z3", "Z3 Solver for Mathematical Problems", {}), + ("coc", "Chain of Code", {}), + ("executecode" , "Execute Code", {}), + ("spl", "System Prompt Learning", {}) +] + +# Define test-time compute approaches for sequential and parallel scaling +TEST_TIME_COMPUTE_APPROACHES = [ + # Baseline + ("none", "Baseline without any optimization", {}), + + # Sequential test-time compute using thinkdeeper with controlled thinking budgets + ("thinkdeeper_2k", "ThinkDeeper with 2K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 2048, + "max_thinking_tokens": 2560, # min + 512 for flexibility + "max_tokens": 3072 # Total budget: max_thinking_tokens + 512 + }), + ("thinkdeeper_4k", "ThinkDeeper with 4K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 4096, + "max_thinking_tokens": 4608, # min + 512 for flexibility + "max_tokens": 5120 # Total budget: max_thinking_tokens + 512 + }), + ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 8192, + "max_thinking_tokens": 8704, # min + 512 for flexibility + "max_tokens": 9216 # Total budget: max_thinking_tokens + 512 + }), + + # Parallel test-time compute using majority voting with different k values + ("majority_voting_3", "Majority Voting with k=3", {"k": 3}), + ("majority_voting_6", "Majority Voting with k=6", {"k": 6}), + ("majority_voting_9", "Majority Voting with k=9", {"k": 9}), ] def load_optillm_bench() -> datasets.Dataset: @@ -265,6 +296,7 @@ def evaluate_model( model: str, dataset: datasets.Dataset, approach: str, + approach_extra_body: Dict[str, 
Any] = None, max_samples: int = None ) -> Tuple[Dict[str, float], List[Dict[str, Any]]]: """ @@ -286,8 +318,18 @@ def evaluate_model( # Prepare the dataset examples = dataset if max_samples is None else dataset.select(range(max_samples)) - # Create model name with approach - full_model_name = f"{approach}-{model}" if approach != "none" else model + # Create model name with approach - handle special cases + if approach == "none": + full_model_name = model + elif approach.startswith("thinkdeeper_"): + # For thinkdeeper, use base model name (decoding is passed in extra_body) + full_model_name = model + elif approach.startswith("majority_voting_"): + # For majority voting, use majority_voting prefix + full_model_name = f"majority_voting-{model}" + else: + # Standard approach prefix + full_model_name = f"{approach}-{model}" for example in tqdm(examples, desc=f"Evaluating {approach}"): try: @@ -297,6 +339,11 @@ def evaluate_model( # Record start time start_time = time.time() + # Prepare extra_body parameters + extra_body = {"spl_learning": False} + if approach_extra_body: + extra_body.update(approach_extra_body) + # Make API call response = client.chat.completions.create( model=full_model_name, @@ -306,7 +353,7 @@ def evaluate_model( ], temperature=0.2, max_tokens=4096, - extra_body= {"spl_learning": False}, + extra_body=extra_body, ) # Calculate time taken @@ -407,14 +454,20 @@ def save_results(metrics: Dict[str, float], detailed_results: List[Dict[str, Any logger.info(f"Results saved to {base_filename}_*") -def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str): +def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str, is_test_time_compute: bool = False): """Generate a comprehensive report comparing all approaches.""" report = [] # Header - report.append("# OptiLLM Bench Evaluation Report") + report_title = "OptiLLM Bench Test-Time Compute Evaluation Report" if is_test_time_compute else "OptiLLM Bench Evaluation Report" + report.append(f"# {report_title}") report.append(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + if is_test_time_compute: + report.append("This report evaluates test-time compute scaling approaches:") + report.append("- **Sequential scaling**: ThinkDeeper with varying thinking token budgets") + report.append("- **Parallel scaling**: Majority voting with varying k values\n") + # Overall Results Table report.append("## Overall Results") headers = ["Approach", "Accuracy", "Avg Time (s)", "Total Time (s)"] @@ -469,6 +522,8 @@ def main(): help="Directory to save results") parser.add_argument("--approaches", nargs="+", help="Specific approaches to evaluate (default: all)") + parser.add_argument("--test-time-compute", action="store_true", + help="Evaluate test-time compute approaches (sequential and parallel scaling)") parser.add_argument("--debug", action="store_true", help="Enable debug logging") args = parser.parse_args() @@ -494,44 +549,54 @@ def main(): dataset = load_optillm_bench() # Determine which approaches to evaluate - approaches_to_test = ( - [a[0] for a in APPROACHES if a[0] in args.approaches] - if args.approaches - else [a[0] for a in APPROACHES] - ) + if args.test_time_compute: + # Use test-time compute approaches + approaches_config = TEST_TIME_COMPUTE_APPROACHES + if args.approaches: + # Filter test-time compute approaches if specific ones are requested + approaches_config = [a for a in TEST_TIME_COMPUTE_APPROACHES if a[0] in args.approaches] + else: + # Use standard approaches + if 
args.approaches:
+            approaches_config = [a for a in APPROACHES if a[0] in args.approaches]
+        else:
+            approaches_config = APPROACHES
 
     # Store all metrics for final report
     all_metrics = {}
 
     # Evaluate each approach
-    for approach in approaches_to_test:
-        logger.info(f"Evaluating approach: {approach}")
+    for approach_name, description, extra_body_params in approaches_config:
+        logger.info(f"Evaluating approach: {approach_name} - {description}")
+        if extra_body_params:
+            logger.info(f"Extra parameters: {extra_body_params}")
 
         try:
             metrics, detailed_results = evaluate_model(
                 client,
                 args.model,
                 dataset,
-                approach,
+                approach_name,
+                extra_body_params,
                 args.max_samples
             )
 
-            all_metrics[approach] = metrics
+            all_metrics[approach_name] = metrics
 
             # Save results for this approach
-            save_results(metrics, detailed_results, args.model, approach,
+            save_results(metrics, detailed_results, args.model, approach_name,
                         args.output_dir)
 
-            logger.info(f"Completed evaluation for {approach}")
+            logger.info(f"Completed evaluation for {approach_name}")
             logger.info(f"Accuracy: {metrics['accuracy']*100:.2f}%")
             logger.info(f"Average time per sample: {metrics['average_time']:.2f}s")
 
         except Exception as e:
-            logger.error(f"Error evaluating approach {approach}: {e}")
+            logger.error(f"Error evaluating approach {approach_name}: {e}")
             continue
 
     # Generate final report
-    generate_report(all_metrics, args.output_dir)
+    generate_report(all_metrics, args.output_dir, args.test_time_compute)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/setup.py b/setup.py
index d164df5..73a48e3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ setup(
     name="optillm",
-    version="0.1.19",
+    version="0.1.20",
     packages=find_packages(include=['optillm', 'optillm.*']),  # This ensures all subpackages are included
     py_modules=['optillm'],
     package_data={
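The same "use n, fall back to sequential calls" pattern now appears in bon.py, moa.py, and the majority_voting plugin. As a distilled sketch of that shared logic (the helper name generate_candidates is hypothetical and not part of this patch):

import logging

logger = logging.getLogger(__name__)

def generate_candidates(client, model, messages, k, max_tokens=4096, temperature=1.0):
    """Hypothetical helper: request k completions via n=k, else fall back to k single calls."""
    completions, tokens = [], 0
    try:
        # Fast path: one API call with the n parameter
        resp = client.chat.completions.create(
            model=model, messages=messages,
            max_tokens=max_tokens, n=k, temperature=temperature,
        )
        completions = [c.message.content for c in resp.choices]
        tokens = resp.usage.completion_tokens
    except Exception as exc:
        logger.warning(f"n parameter not supported by provider: {exc}")
        # Fallback: k sequential calls, skipping individual failures
        for i in range(k):
            try:
                resp = client.chat.completions.create(
                    model=model, messages=messages,
                    max_tokens=max_tokens, temperature=temperature,
                )
                completions.append(resp.choices[0].message.content)
                tokens += resp.usage.completion_tokens
            except Exception as fallback_exc:
                logger.error(f"Error generating completion {i + 1}: {fallback_exc}")
    return completions, tokens

Per-call token accounting is kept in both paths so the completion_tokens total reported by each approach stays accurate regardless of which path the provider supports.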