Skip to content

Feat upgrade outlines #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions optillm/plugins/json_plugin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Tuple, Dict, Any, Optional
import logging
from outlines import models, generate
import outlines
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pydantic import BaseModel, create_model
from transformers import AutoTokenizer

# Plugin identifier
SLUG = "json"
Expand All @@ -26,11 +27,9 @@ def __init__(self, model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
self.device = self.get_device()
logger.info(f"Using device: {self.device}")
try:
llm = AutoModelForCausalLM.from_pretrained(model_name)
llm.to(self.device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = models.Transformers(llm, tokenizer)
self.tokenizer = tokenizer
# Initialize the model using the new outlines API
self.model = outlines.from_transformers(model_name, device=str(self.device))
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Successfully loaded model: {model_name}")
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
Expand All @@ -45,17 +44,63 @@ def count_tokens(self, text: str) -> int:
logger.error(f"Error counting tokens: {str(e)}")
return 0

def parse_json_schema_to_pydantic(self, schema_str: str) -> type[BaseModel]:
    """Build a dynamic Pydantic model class from a JSON schema string.

    Only the top-level ``properties`` and ``required`` entries of the schema
    are consulted. Each property's ``type`` is mapped to a plain Python type;
    any type that is missing or not recognised is treated as ``str``.

    Args:
        schema_str: JSON-encoded schema text.

    Returns:
        The class produced by ``create_model('DynamicModel', ...)``.

    Raises:
        Exception: Any error raised while decoding or walking the schema is
            logged and then re-raised unchanged.
    """
    # JSON schema type name -> Python type. Names not listed map to str.
    python_types = {
        'integer': int,
        'number': float,
        'boolean': bool,
        'array': list,
        'object': dict,
    }

    try:
        schema = json.loads(schema_str)
        property_specs = schema.get('properties', {})
        required_names = schema.get('required', [])

        model_fields = {}
        for name, spec in property_specs.items():
            declared = spec.get('type')
            # Look up only string type names: a non-string value (for
            # example a list of alternatives) is unhashable or unknown and
            # must fall back to str rather than raise.
            if isinstance(declared, str):
                python_type = python_types.get(declared, str)
            else:
                python_type = str

            # Required fields get the Ellipsis marker; the rest become
            # Optional with a default of None.
            model_fields[name] = (
                (python_type, ...)
                if name in required_names
                else (Optional[python_type], None)
            )

        return create_model('DynamicModel', **model_fields)

    except Exception as e:
        logger.error(f"Error parsing JSON schema: {str(e)}")
        raise

def generate_json(self, prompt: str, schema: str) -> Dict[str, Any]:
"""Generate JSON based on the provided schema and prompt."""
try:
# Create JSON generator with the schema
generator = generate.json(self.model, schema)
logger.info("Created JSON generator with schema")
# Parse JSON schema to Pydantic model
pydantic_model = self.parse_json_schema_to_pydantic(schema)
logger.info("Parsed JSON schema to Pydantic model")

# Generate JSON response
result = generator(prompt)
# Generate JSON response using the new API
result = self.model(prompt, pydantic_model)
logger.info("Successfully generated JSON response")
return result

# Convert Pydantic model instance to dict
if hasattr(result, 'model_dump'):
return result.model_dump()
elif hasattr(result, 'dict'):
return result.dict()
else:
return dict(result)

except Exception as e:
logger.error(f"Error generating JSON: {str(e)}")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "optillm"
version = "0.1.24"
version = "0.1.25"
description = "An optimizing inference proxy for LLMs."
readme = "README.md"
license = "Apache-2.0"
Expand Down
244 changes: 244 additions & 0 deletions tests/test_json_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
"""Test the JSON plugin for compatibility with outlines>=1.1.0"""

import unittest
from unittest.mock import Mock, patch, MagicMock
import json
from typing import Dict, Any

# Stub out the plugin's heavy third-party dependencies BEFORE importing it.
# json_plugin imports torch, transformers, outlines and pydantic at module
# level, so placing MagicMock objects in sys.modules makes those imports
# resolve to the stubs instead of the real packages.
# NOTE(review): nothing in this file restores these sys.modules entries, so
# the stubs appear to stay installed for the rest of the process - confirm
# this does not affect other test modules run in the same session.
import sys
sys.modules['torch'] = MagicMock()
sys.modules['transformers'] = MagicMock()
sys.modules['outlines'] = MagicMock()
sys.modules['pydantic'] = MagicMock()

# Order matters: this import must come after the stubs above are installed,
# otherwise the plugin would bind to the real packages.
from optillm.plugins.json_plugin import JSONGenerator, extract_schema_from_response_format, run


class TestJSONPlugin(unittest.TestCase):
    """Test cases for the JSON plugin with the new outlines API.

    The plugin's third-party dependencies are replaced by mocks, so these
    tests verify wiring and data flow (what gets called, with what, and what
    is returned) rather than real model output.
    """

    def setUp(self):
        """Create the JSON schema strings shared by the tests."""
        # Three scalar fields, two of them required.
        self.simple_schema = json.dumps({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
                "active": {"type": "boolean"}
            },
            "required": ["name", "age"]
        })

        # Covers every JSON schema type the plugin maps explicitly.
        self.complex_schema = json.dumps({
            "type": "object",
            "properties": {
                "id": {"type": "integer"},
                "email": {"type": "string"},
                "score": {"type": "number"},
                "tags": {"type": "array"},
                "metadata": {"type": "object"}
            },
            "required": ["id", "email"]
        })

    @staticmethod
    def _make_chat_client(content, completion_tokens):
        """Return a mock client whose chat completion yields *content*."""
        response = Mock()
        response.choices = [Mock(message=Mock(content=content))]
        response.usage.completion_tokens = completion_tokens

        client = Mock()
        client.chat.completions.create.return_value = response
        return client

    @patch('optillm.plugins.json_plugin.outlines.from_transformers')
    @patch('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained')
    def test_json_generator_init(self, mock_tokenizer, mock_from_transformers):
        """JSONGenerator stores the loaded model and tokenizer."""
        mock_model = Mock()
        mock_from_transformers.return_value = mock_model
        mock_tokenizer_instance = Mock()
        mock_tokenizer.return_value = mock_tokenizer_instance

        generator = JSONGenerator()

        mock_from_transformers.assert_called_once()
        mock_tokenizer.assert_called_once()
        # Identity checks instead of assertIsNotNone: a Mock-backed
        # attribute is never None, so that assertion could not fail.
        self.assertIs(generator.model, mock_model)
        self.assertIs(generator.tokenizer, mock_tokenizer_instance)

    @patch('optillm.plugins.json_plugin.create_model')
    def test_parse_json_schema_to_pydantic(self, mock_create_model):
        """Simple schema fields map to the right types and requiredness."""
        from typing import Optional

        mock_model_class = Mock()
        mock_create_model.return_value = mock_model_class

        # Bypass __init__ so no model loading is attempted.
        generator = JSONGenerator.__new__(JSONGenerator)

        result = generator.parse_json_schema_to_pydantic(self.simple_schema)

        # The method must hand back whatever create_model produced.
        self.assertIs(result, mock_model_class)

        mock_create_model.assert_called_once()
        call_args = mock_create_model.call_args
        self.assertEqual(call_args[0][0], 'DynamicModel')

        # Required fields carry the Ellipsis marker; optional ones are
        # wrapped in Optional with a None default.
        fields = call_args[1]
        self.assertEqual(fields, {
            'name': (str, ...),
            'age': (int, ...),
            'active': (Optional[bool], None),
        })

    @patch('optillm.plugins.json_plugin.create_model')
    def test_parse_complex_schema_field_types(self, mock_create_model):
        """Number, array and object types map to float, list and dict."""
        from typing import Optional

        generator = JSONGenerator.__new__(JSONGenerator)

        generator.parse_json_schema_to_pydantic(self.complex_schema)

        mock_create_model.assert_called_once()
        fields = mock_create_model.call_args[1]
        self.assertEqual(fields, {
            'id': (int, ...),
            'email': (str, ...),
            'score': (Optional[float], None),
            'tags': (Optional[list], None),
            'metadata': (Optional[dict], None),
        })

    @patch('optillm.plugins.json_plugin.outlines.from_transformers')
    @patch('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained')
    def test_generate_json_new_api(self, mock_tokenizer, mock_from_transformers):
        """generate_json calls the model and returns its dumped result."""
        # Stand-in for a Pydantic instance exposing model_dump().
        mock_result = Mock()
        mock_result.model_dump.return_value = {"name": "Test", "age": 25}

        mock_model = Mock()
        mock_model.return_value = mock_result
        mock_from_transformers.return_value = mock_model

        generator = JSONGenerator()

        prompt = "Create a person named Test who is 25 years old"
        result = generator.generate_json(prompt, self.simple_schema)

        self.assertEqual(result, {"name": "Test", "age": 25})
        mock_model.assert_called_once()
        # The prompt is passed through as the first positional argument.
        self.assertEqual(mock_model.call_args[0][0], prompt)

    def test_extract_schema_from_response_format(self):
        """An OpenAI-style response_format yields the schema as JSON text."""
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "test_schema",
                "schema": {
                    "type": "object",
                    "properties": {
                        "test": {"type": "string"}
                    }
                }
            }
        }

        result = extract_schema_from_response_format(response_format)
        self.assertIsNotNone(result)

        # The extracted value must itself be valid JSON.
        schema = json.loads(result)
        self.assertEqual(schema["type"], "object")
        self.assertIn("test", schema["properties"])

    @patch('optillm.plugins.json_plugin.JSONGenerator')
    def test_run_function_with_schema(self, mock_json_generator_class):
        """run() uses JSONGenerator when the request carries a schema."""
        mock_generator = Mock()
        mock_generator.generate_json.return_value = {"result": "test"}
        mock_generator.count_tokens.return_value = 10
        mock_json_generator_class.return_value = mock_generator

        mock_client = Mock()

        request_config = {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "schema": {
                        "type": "object",
                        "properties": {
                            "result": {"type": "string"}
                        }
                    }
                }
            }
        }

        result, tokens = run(
            "System prompt",
            "Generate a test result",
            mock_client,
            "test-model",
            request_config
        )

        self.assertIn("result", result)
        self.assertEqual(tokens, 10)
        mock_generator.generate_json.assert_called_once()

    def test_run_function_without_schema(self):
        """run() falls back to the chat client when no schema is given."""
        mock_client = self._make_chat_client("Regular response", 5)

        result, tokens = run(
            "System prompt",
            "Test query",
            mock_client,
            "test-model",
            {}
        )

        self.assertEqual(result, "Regular response")
        self.assertEqual(tokens, 5)
        mock_client.chat.completions.create.assert_called_once()

    @patch('optillm.plugins.json_plugin.JSONGenerator')
    def test_error_handling(self, mock_json_generator_class):
        """run() falls back to the chat client when generation fails."""
        mock_generator = Mock()
        mock_generator.generate_json.side_effect = Exception("Test error")
        mock_json_generator_class.return_value = mock_generator

        mock_client = self._make_chat_client("Fallback response", 8)

        request_config = {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "schema": {"type": "object"}
                }
            }
        }

        result, tokens = run(
            "System prompt",
            "Test query",
            mock_client,
            "test-model",
            request_config
        )

        self.assertEqual(result, "Fallback response")
        self.assertEqual(tokens, 8)
        mock_client.chat.completions.create.assert_called_once()


if __name__ == '__main__':
    # Run the test suite when this file is executed directly.
    unittest.main()