From c5f956849e3f95d2ab53fde473b79587738b1303 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sun, 27 Jul 2025 15:55:23 +0800 Subject: [PATCH 1/2] Refactor JSON plugin for outlines>=1.1.0 and add tests Updated the JSON plugin to use the new outlines API, including dynamic Pydantic model creation from JSON schema and updated model initialization. Added a comprehensive test suite for the plugin covering initialization, schema parsing, JSON generation, schema extraction, main run logic, and error handling. --- optillm/plugins/json_plugin.py | 71 ++++++++-- tests/test_json_plugin.py | 244 +++++++++++++++++++++++++++++++++ 2 files changed, 302 insertions(+), 13 deletions(-) create mode 100644 tests/test_json_plugin.py diff --git a/optillm/plugins/json_plugin.py b/optillm/plugins/json_plugin.py index 5a17763b..0263de1a 100644 --- a/optillm/plugins/json_plugin.py +++ b/optillm/plugins/json_plugin.py @@ -1,9 +1,10 @@ from typing import Tuple, Dict, Any, Optional import logging -from outlines import models, generate +import outlines import json import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from pydantic import BaseModel, create_model +from transformers import AutoTokenizer # Plugin identifier SLUG = "json" @@ -26,11 +27,9 @@ def __init__(self, model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" self.device = self.get_device() logger.info(f"Using device: {self.device}") try: - llm = AutoModelForCausalLM.from_pretrained(model_name) - llm.to(self.device) - tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = models.Transformers(llm, tokenizer) - self.tokenizer = tokenizer + # Initialize the model using the new outlines API + self.model = outlines.from_transformers(model_name, device=str(self.device)) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) logger.info(f"Successfully loaded model: {model_name}") except Exception as e: logger.error(f"Error loading model: {str(e)}") @@ -45,17 +44,63 @@ def count_tokens(self, text: str) -> int: logger.error(f"Error counting tokens: {str(e)}") return 0 + def parse_json_schema_to_pydantic(self, schema_str: str) -> type[BaseModel]: + """Convert JSON schema string to Pydantic model.""" + try: + schema_dict = json.loads(schema_str) + + # Extract properties and required fields + properties = schema_dict.get('properties', {}) + required = schema_dict.get('required', []) + + # Build field definitions for Pydantic + fields = {} + for field_name, field_def in properties.items(): + field_type = str # Default to string + + # Map JSON schema types to Python types + if field_def.get('type') == 'integer': + field_type = int + elif field_def.get('type') == 'number': + field_type = float + elif field_def.get('type') == 'boolean': + field_type = bool + elif field_def.get('type') == 'array': + field_type = list + elif field_def.get('type') == 'object': + field_type = dict + + # Check if field is required + if field_name in required: + fields[field_name] = (field_type, ...) + else: + fields[field_name] = (Optional[field_type], None) + + # Create dynamic Pydantic model + return create_model('DynamicModel', **fields) + + except Exception as e: + logger.error(f"Error parsing JSON schema: {str(e)}") + raise + def generate_json(self, prompt: str, schema: str) -> Dict[str, Any]: """Generate JSON based on the provided schema and prompt.""" try: - # Create JSON generator with the schema - generator = generate.json(self.model, schema) - logger.info("Created JSON generator with schema") + # Parse JSON schema to Pydantic model + pydantic_model = self.parse_json_schema_to_pydantic(schema) + logger.info("Parsed JSON schema to Pydantic model") - # Generate JSON response - result = generator(prompt) + # Generate JSON response using the new API + result = self.model(prompt, pydantic_model) logger.info("Successfully generated JSON response") - return result + + # Convert Pydantic model instance to dict + if hasattr(result, 'model_dump'): + return result.model_dump() + elif hasattr(result, 'dict'): + return result.dict() + else: + return dict(result) except Exception as e: logger.error(f"Error generating JSON: {str(e)}") diff --git a/tests/test_json_plugin.py b/tests/test_json_plugin.py new file mode 100644 index 00000000..c746869c --- /dev/null +++ b/tests/test_json_plugin.py @@ -0,0 +1,244 @@ +"""Test the JSON plugin for compatibility with outlines>=1.1.0""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +import json +from typing import Dict, Any + +# Mock the dependencies before importing the plugin +import sys +sys.modules['torch'] = MagicMock() +sys.modules['transformers'] = MagicMock() +sys.modules['outlines'] = MagicMock() +sys.modules['pydantic'] = MagicMock() + +# Import after mocking +from optillm.plugins.json_plugin import JSONGenerator, extract_schema_from_response_format, run + + +class TestJSONPlugin(unittest.TestCase): + """Test cases for the JSON plugin with new outlines API.""" + + def setUp(self): + """Set up test fixtures.""" + # Sample JSON schemas for testing + self.simple_schema = json.dumps({ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "active": {"type": "boolean"} + }, + "required": ["name", "age"] + }) + + self.complex_schema = json.dumps({ + "type": "object", + "properties": { + "id": {"type": "integer"}, + "email": {"type": "string"}, + "score": {"type": "number"}, + "tags": {"type": "array"}, + "metadata": {"type": "object"} + }, + "required": ["id", "email"] + }) + + @patch('optillm.plugins.json_plugin.outlines.from_transformers') + @patch('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained') + def test_json_generator_init(self, mock_tokenizer, mock_from_transformers): + """Test JSONGenerator initialization with new API.""" + # Mock the model and tokenizer + mock_model = Mock() + mock_from_transformers.return_value = mock_model + mock_tokenizer.return_value = Mock() + + # Initialize JSONGenerator + generator = JSONGenerator() + + # Verify initialization + mock_from_transformers.assert_called_once() + mock_tokenizer.assert_called_once() + self.assertIsNotNone(generator.model) + self.assertIsNotNone(generator.tokenizer) + + @patch('optillm.plugins.json_plugin.create_model') + def test_parse_json_schema_to_pydantic(self, mock_create_model): + """Test JSON schema to Pydantic model conversion.""" + # Mock Pydantic model creation + mock_model_class = Mock() + mock_create_model.return_value = mock_model_class + + # Create generator with mocked dependencies + generator = JSONGenerator.__new__(JSONGenerator) + + # Test simple schema parsing + result = generator.parse_json_schema_to_pydantic(self.simple_schema) + + # Verify create_model was called with correct fields + mock_create_model.assert_called_once() + call_args = mock_create_model.call_args + self.assertEqual(call_args[0][0], 'DynamicModel') + + # Check fields + fields = call_args[1] + self.assertIn('name', fields) + self.assertIn('age', fields) + self.assertIn('active', fields) + + @patch('optillm.plugins.json_plugin.outlines.from_transformers') + @patch('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained') + def test_generate_json_new_api(self, mock_tokenizer, mock_from_transformers): + """Test JSON generation with new outlines API.""" + # Create mock Pydantic instance with model_dump method + mock_result = Mock() + mock_result.model_dump.return_value = {"name": "Test", "age": 25} + + # Mock the model to return our result + mock_model = Mock() + mock_model.return_value = mock_result + mock_from_transformers.return_value = mock_model + + # Initialize generator + generator = JSONGenerator() + + # Test generation + prompt = "Create a person named Test who is 25 years old" + result = generator.generate_json(prompt, self.simple_schema) + + # Verify the result + self.assertEqual(result, {"name": "Test", "age": 25}) + mock_model.assert_called_once() + + def test_extract_schema_from_response_format(self): + """Test schema extraction from OpenAI response format.""" + # Test with OpenAI format + response_format = { + "type": "json_schema", + "json_schema": { + "name": "test_schema", + "schema": { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + } + } + + result = extract_schema_from_response_format(response_format) + self.assertIsNotNone(result) + + # Verify it's valid JSON + schema = json.loads(result) + self.assertEqual(schema["type"], "object") + self.assertIn("test", schema["properties"]) + + @patch('optillm.plugins.json_plugin.JSONGenerator') + def test_run_function_with_schema(self, mock_json_generator_class): + """Test the main run function with a valid schema.""" + # Mock JSONGenerator instance + mock_generator = Mock() + mock_generator.generate_json.return_value = {"result": "test"} + mock_generator.count_tokens.return_value = 10 + mock_json_generator_class.return_value = mock_generator + + # Mock client + mock_client = Mock() + + # Test configuration + request_config = { + "response_format": { + "type": "json_schema", + "json_schema": { + "schema": { + "type": "object", + "properties": { + "result": {"type": "string"} + } + } + } + } + } + + # Run the plugin + result, tokens = run( + "System prompt", + "Generate a test result", + mock_client, + "test-model", + request_config + ) + + # Verify results + self.assertIn("result", result) + self.assertEqual(tokens, 10) + mock_generator.generate_json.assert_called_once() + + def test_run_function_without_schema(self): + """Test the main run function without a schema (fallback).""" + # Mock client and response + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="Regular response"))] + mock_response.usage.completion_tokens = 5 + + mock_client = Mock() + mock_client.chat.completions.create.return_value = mock_response + + # Run without schema + result, tokens = run( + "System prompt", + "Test query", + mock_client, + "test-model", + {} + ) + + # Verify fallback behavior + self.assertEqual(result, "Regular response") + self.assertEqual(tokens, 5) + mock_client.chat.completions.create.assert_called_once() + + @patch('optillm.plugins.json_plugin.JSONGenerator') + def test_error_handling(self, mock_json_generator_class): + """Test error handling and fallback.""" + # Mock generator that raises an error + mock_generator = Mock() + mock_generator.generate_json.side_effect = Exception("Test error") + mock_json_generator_class.return_value = mock_generator + + # Mock client for fallback + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="Fallback response"))] + mock_response.usage.completion_tokens = 8 + + mock_client = Mock() + mock_client.chat.completions.create.return_value = mock_response + + # Test configuration with schema + request_config = { + "response_format": { + "type": "json_schema", + "json_schema": { + "schema": {"type": "object"} + } + } + } + + # Run and expect fallback + result, tokens = run( + "System prompt", + "Test query", + mock_client, + "test-model", + request_config + ) + + # Verify fallback was used + self.assertEqual(result, "Fallback response") + self.assertEqual(tokens, 8) + mock_client.chat.completions.create.assert_called_once() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 5fe43ea894d046852656b48b62ea861390beed75 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sun, 27 Jul 2025 15:56:11 +0800 Subject: [PATCH 2/2] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 30cb4879..6b1d7edc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "optillm" -version = "0.1.24" +version = "0.1.25" description = "An optimizing inference proxy for LLMs." readme = "README.md" license = "Apache-2.0"