
Commit b64d891

Merge pull request #218 from codelion/feat-upgrade-outlines
Feat upgrade outlines
2 parents 0297900 + 5fe43ea commit b64d891

3 files changed (+303, -14 lines)


optillm/plugins/json_plugin.py

Lines changed: 58 additions & 13 deletions
@@ -1,9 +1,10 @@
 from typing import Tuple, Dict, Any, Optional
 import logging
-from outlines import models, generate
+import outlines
 import json
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from pydantic import BaseModel, create_model
+from transformers import AutoTokenizer
 
 # Plugin identifier
 SLUG = "json"
@@ -26,11 +27,9 @@ def __init__(self, model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
         self.device = self.get_device()
         logger.info(f"Using device: {self.device}")
         try:
-            llm = AutoModelForCausalLM.from_pretrained(model_name)
-            llm.to(self.device)
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.model = models.Transformers(llm, tokenizer)
-            self.tokenizer = tokenizer
+            # Initialize the model using the new outlines API
+            self.model = outlines.from_transformers(model_name, device=str(self.device))
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
             logger.info(f"Successfully loaded model: {model_name}")
         except Exception as e:
             logger.error(f"Error loading model: {str(e)}")
@@ -45,17 +44,63 @@ def count_tokens(self, text: str) -> int:
             logger.error(f"Error counting tokens: {str(e)}")
             return 0
 
+    def parse_json_schema_to_pydantic(self, schema_str: str) -> type[BaseModel]:
+        """Convert JSON schema string to Pydantic model."""
+        try:
+            schema_dict = json.loads(schema_str)
+
+            # Extract properties and required fields
+            properties = schema_dict.get('properties', {})
+            required = schema_dict.get('required', [])
+
+            # Build field definitions for Pydantic
+            fields = {}
+            for field_name, field_def in properties.items():
+                field_type = str  # Default to string
+
+                # Map JSON schema types to Python types
+                if field_def.get('type') == 'integer':
+                    field_type = int
+                elif field_def.get('type') == 'number':
+                    field_type = float
+                elif field_def.get('type') == 'boolean':
+                    field_type = bool
+                elif field_def.get('type') == 'array':
+                    field_type = list
+                elif field_def.get('type') == 'object':
+                    field_type = dict
+
+                # Check if field is required
+                if field_name in required:
+                    fields[field_name] = (field_type, ...)
+                else:
+                    fields[field_name] = (Optional[field_type], None)
+
+            # Create dynamic Pydantic model
+            return create_model('DynamicModel', **fields)
+
+        except Exception as e:
+            logger.error(f"Error parsing JSON schema: {str(e)}")
+            raise
+
     def generate_json(self, prompt: str, schema: str) -> Dict[str, Any]:
         """Generate JSON based on the provided schema and prompt."""
         try:
-            # Create JSON generator with the schema
-            generator = generate.json(self.model, schema)
-            logger.info("Created JSON generator with schema")
+            # Parse JSON schema to Pydantic model
+            pydantic_model = self.parse_json_schema_to_pydantic(schema)
+            logger.info("Parsed JSON schema to Pydantic model")
 
-            # Generate JSON response
-            result = generator(prompt)
+            # Generate JSON response using the new API
+            result = self.model(prompt, pydantic_model)
             logger.info("Successfully generated JSON response")
-            return result
+
+            # Convert Pydantic model instance to dict
+            if hasattr(result, 'model_dump'):
+                return result.model_dump()
+            elif hasattr(result, 'dict'):
+                return result.dict()
+            else:
+                return dict(result)
 
         except Exception as e:
             logger.error(f"Error generating JSON: {str(e)}")

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "optillm"
-version = "0.1.24"
+version = "0.1.25"
 description = "An optimizing inference proxy for LLMs."
 readme = "README.md"
 license = "Apache-2.0"

tests/test_json_plugin.py

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
+"""Test the JSON plugin for compatibility with outlines>=1.1.0"""
+
+import unittest
+from unittest.mock import Mock, patch, MagicMock
+import json
+from typing import Dict, Any
+
+# Mock the dependencies before importing the plugin
+import sys
+sys.modules['torch'] = MagicMock()
+sys.modules['transformers'] = MagicMock()
+sys.modules['outlines'] = MagicMock()
+sys.modules['pydantic'] = MagicMock()
+
+# Import after mocking
+from optillm.plugins.json_plugin import JSONGenerator, extract_schema_from_response_format, run
+
+
+class TestJSONPlugin(unittest.TestCase):
+    """Test cases for the JSON plugin with new outlines API."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        # Sample JSON schemas for testing
+        self.simple_schema = json.dumps({
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "age": {"type": "integer"},
+                "active": {"type": "boolean"}
+            },
+            "required": ["name", "age"]
+        })
+
+        self.complex_schema = json.dumps({
+            "type": "object",
+            "properties": {
+                "id": {"type": "integer"},
+                "email": {"type": "string"},
+                "score": {"type": "number"},
+                "tags": {"type": "array"},
+                "metadata": {"type": "object"}
+            },
+            "required": ["id", "email"]
+        })
+
+    @patch('optillm.plugins.json_plugin.outlines.from_transformers')
+    @patch('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained')
+    def test_json_generator_init(self, mock_tokenizer, mock_from_transformers):
+        """Test JSONGenerator initialization with new API."""
+        # Mock the model and tokenizer
+        mock_model = Mock()
+        mock_from_transformers.return_value = mock_model
+        mock_tokenizer.return_value = Mock()
+
+        # Initialize JSONGenerator
+        generator = JSONGenerator()
+
+        # Verify initialization
+        mock_from_transformers.assert_called_once()
+        mock_tokenizer.assert_called_once()
+        self.assertIsNotNone(generator.model)
+        self.assertIsNotNone(generator.tokenizer)
+
+    @patch('optillm.plugins.json_plugin.create_model')
+    def test_parse_json_schema_to_pydantic(self, mock_create_model):
+        """Test JSON schema to Pydantic model conversion."""
+        # Mock Pydantic model creation
+        mock_model_class = Mock()
+        mock_create_model.return_value = mock_model_class
+
+        # Create generator with mocked dependencies
+        generator = JSONGenerator.__new__(JSONGenerator)
+
+        # Test simple schema parsing
+        result = generator.parse_json_schema_to_pydantic(self.simple_schema)
+
+        # Verify create_model was called with correct fields
+        mock_create_model.assert_called_once()
+        call_args = mock_create_model.call_args
+        self.assertEqual(call_args[0][0], 'DynamicModel')
+
+        # Check fields
+        fields = call_args[1]
+        self.assertIn('name', fields)
+        self.assertIn('age', fields)
+        self.assertIn('active', fields)
+
+    @patch('optillm.plugins.json_plugin.outlines.from_transformers')
+    @patch('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained')
+    def test_generate_json_new_api(self, mock_tokenizer, mock_from_transformers):
+        """Test JSON generation with new outlines API."""
+        # Create mock Pydantic instance with model_dump method
+        mock_result = Mock()
+        mock_result.model_dump.return_value = {"name": "Test", "age": 25}
+
+        # Mock the model to return our result
+        mock_model = Mock()
+        mock_model.return_value = mock_result
+        mock_from_transformers.return_value = mock_model
+
+        # Initialize generator
+        generator = JSONGenerator()
+
+        # Test generation
+        prompt = "Create a person named Test who is 25 years old"
+        result = generator.generate_json(prompt, self.simple_schema)
+
+        # Verify the result
+        self.assertEqual(result, {"name": "Test", "age": 25})
+        mock_model.assert_called_once()
+
+    def test_extract_schema_from_response_format(self):
+        """Test schema extraction from OpenAI response format."""
+        # Test with OpenAI format
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "test_schema",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "test": {"type": "string"}
+                    }
+                }
+            }
+        }
+
+        result = extract_schema_from_response_format(response_format)
+        self.assertIsNotNone(result)
+
+        # Verify it's valid JSON
+        schema = json.loads(result)
+        self.assertEqual(schema["type"], "object")
+        self.assertIn("test", schema["properties"])
+
+    @patch('optillm.plugins.json_plugin.JSONGenerator')
+    def test_run_function_with_schema(self, mock_json_generator_class):
+        """Test the main run function with a valid schema."""
+        # Mock JSONGenerator instance
+        mock_generator = Mock()
+        mock_generator.generate_json.return_value = {"result": "test"}
+        mock_generator.count_tokens.return_value = 10
+        mock_json_generator_class.return_value = mock_generator
+
+        # Mock client
+        mock_client = Mock()
+
+        # Test configuration
+        request_config = {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "schema": {
+                        "type": "object",
+                        "properties": {
+                            "result": {"type": "string"}
+                        }
+                    }
+                }
+            }
+        }
+
+        # Run the plugin
+        result, tokens = run(
+            "System prompt",
+            "Generate a test result",
+            mock_client,
+            "test-model",
+            request_config
+        )
+
+        # Verify results
+        self.assertIn("result", result)
+        self.assertEqual(tokens, 10)
+        mock_generator.generate_json.assert_called_once()
+
+    def test_run_function_without_schema(self):
+        """Test the main run function without a schema (fallback)."""
+        # Mock client and response
+        mock_response = Mock()
+        mock_response.choices = [Mock(message=Mock(content="Regular response"))]
+        mock_response.usage.completion_tokens = 5
+
+        mock_client = Mock()
+        mock_client.chat.completions.create.return_value = mock_response
+
+        # Run without schema
+        result, tokens = run(
+            "System prompt",
+            "Test query",
+            mock_client,
+            "test-model",
+            {}
+        )
+
+        # Verify fallback behavior
+        self.assertEqual(result, "Regular response")
+        self.assertEqual(tokens, 5)
+        mock_client.chat.completions.create.assert_called_once()
+
+    @patch('optillm.plugins.json_plugin.JSONGenerator')
+    def test_error_handling(self, mock_json_generator_class):
+        """Test error handling and fallback."""
+        # Mock generator that raises an error
+        mock_generator = Mock()
+        mock_generator.generate_json.side_effect = Exception("Test error")
+        mock_json_generator_class.return_value = mock_generator
+
+        # Mock client for fallback
+        mock_response = Mock()
+        mock_response.choices = [Mock(message=Mock(content="Fallback response"))]
+        mock_response.usage.completion_tokens = 8
+
+        mock_client = Mock()
+        mock_client.chat.completions.create.return_value = mock_response
+
+        # Test configuration with schema
+        request_config = {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "schema": {"type": "object"}
+                }
+            }
+        }
+
+        # Run and expect fallback
+        result, tokens = run(
+            "System prompt",
+            "Test query",
+            mock_client,
+            "test-model",
+            request_config
+        )
+
+        # Verify fallback was used
+        self.assertEqual(result, "Fallback response")
+        self.assertEqual(tokens, 8)
+        mock_client.chat.completions.create.assert_called_once()
+
+
+if __name__ == '__main__':
+    unittest.main()
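The tests above also pin down the OpenAI-style response_format shape the plugin consumes. As a quick illustration (not plugin code, and with hypothetical field names), the JSON string the extraction helper hands to generate_json corresponds to the inner "schema" object:

import json

# OpenAI-style structured-output request format, as constructed in the tests above.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "person",
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name"],
        },
    },
}

# Illustrative equivalent of what extract_schema_from_response_format returns:
# the inner schema, serialized back to a JSON string.
schema_str = json.dumps(response_format["json_schema"]["schema"])
print(json.loads(schema_str)["properties"])  # -> {'name': {...}, 'age': {...}}

When no such schema is supplied, the tests confirm that run() falls back to a plain client.chat.completions.create call.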
