1
+ """Test the JSON plugin for compatibility with outlines>=1.1.0"""
2
+
3
+ import unittest
4
+ from unittest .mock import Mock , patch , MagicMock
5
+ import json
6
+ from typing import Dict , Any
7
+
8
+ # Mock the dependencies before importing the plugin
9
+ import sys
10
+ sys .modules ['torch' ] = MagicMock ()
11
+ sys .modules ['transformers' ] = MagicMock ()
12
+ sys .modules ['outlines' ] = MagicMock ()
13
+ sys .modules ['pydantic' ] = MagicMock ()
14
+
15
+ # Import after mocking
16
+ from optillm .plugins .json_plugin import JSONGenerator , extract_schema_from_response_format , run
17
+
18
+
19
+ class TestJSONPlugin (unittest .TestCase ):
20
+ """Test cases for the JSON plugin with new outlines API."""
21
+
22
+ def setUp (self ):
23
+ """Set up test fixtures."""
24
+ # Sample JSON schemas for testing
25
+ self .simple_schema = json .dumps ({
26
+ "type" : "object" ,
27
+ "properties" : {
28
+ "name" : {"type" : "string" },
29
+ "age" : {"type" : "integer" },
30
+ "active" : {"type" : "boolean" }
31
+ },
32
+ "required" : ["name" , "age" ]
33
+ })
34
+
35
+ self .complex_schema = json .dumps ({
36
+ "type" : "object" ,
37
+ "properties" : {
38
+ "id" : {"type" : "integer" },
39
+ "email" : {"type" : "string" },
40
+ "score" : {"type" : "number" },
41
+ "tags" : {"type" : "array" },
42
+ "metadata" : {"type" : "object" }
43
+ },
44
+ "required" : ["id" , "email" ]
45
+ })
46
+
47
+ @patch ('optillm.plugins.json_plugin.outlines.from_transformers' )
48
+ @patch ('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained' )
49
+ def test_json_generator_init (self , mock_tokenizer , mock_from_transformers ):
50
+ """Test JSONGenerator initialization with new API."""
51
+ # Mock the model and tokenizer
52
+ mock_model = Mock ()
53
+ mock_from_transformers .return_value = mock_model
54
+ mock_tokenizer .return_value = Mock ()
55
+
56
+ # Initialize JSONGenerator
57
+ generator = JSONGenerator ()
58
+
59
+ # Verify initialization
60
+ mock_from_transformers .assert_called_once ()
61
+ mock_tokenizer .assert_called_once ()
62
+ self .assertIsNotNone (generator .model )
63
+ self .assertIsNotNone (generator .tokenizer )
64
+
65
+ @patch ('optillm.plugins.json_plugin.create_model' )
66
+ def test_parse_json_schema_to_pydantic (self , mock_create_model ):
67
+ """Test JSON schema to Pydantic model conversion."""
68
+ # Mock Pydantic model creation
69
+ mock_model_class = Mock ()
70
+ mock_create_model .return_value = mock_model_class
71
+
72
+ # Create generator with mocked dependencies
73
+ generator = JSONGenerator .__new__ (JSONGenerator )
74
+
75
+ # Test simple schema parsing
76
+ result = generator .parse_json_schema_to_pydantic (self .simple_schema )
77
+
78
+ # Verify create_model was called with correct fields
79
+ mock_create_model .assert_called_once ()
80
+ call_args = mock_create_model .call_args
81
+ self .assertEqual (call_args [0 ][0 ], 'DynamicModel' )
82
+
83
+ # Check fields
84
+ fields = call_args [1 ]
85
+ self .assertIn ('name' , fields )
86
+ self .assertIn ('age' , fields )
87
+ self .assertIn ('active' , fields )
88
+
89
+ @patch ('optillm.plugins.json_plugin.outlines.from_transformers' )
90
+ @patch ('optillm.plugins.json_plugin.AutoTokenizer.from_pretrained' )
91
+ def test_generate_json_new_api (self , mock_tokenizer , mock_from_transformers ):
92
+ """Test JSON generation with new outlines API."""
93
+ # Create mock Pydantic instance with model_dump method
94
+ mock_result = Mock ()
95
+ mock_result .model_dump .return_value = {"name" : "Test" , "age" : 25 }
96
+
97
+ # Mock the model to return our result
98
+ mock_model = Mock ()
99
+ mock_model .return_value = mock_result
100
+ mock_from_transformers .return_value = mock_model
101
+
102
+ # Initialize generator
103
+ generator = JSONGenerator ()
104
+
105
+ # Test generation
106
+ prompt = "Create a person named Test who is 25 years old"
107
+ result = generator .generate_json (prompt , self .simple_schema )
108
+
109
+ # Verify the result
110
+ self .assertEqual (result , {"name" : "Test" , "age" : 25 })
111
+ mock_model .assert_called_once ()
112
+
113
+ def test_extract_schema_from_response_format (self ):
114
+ """Test schema extraction from OpenAI response format."""
115
+ # Test with OpenAI format
116
+ response_format = {
117
+ "type" : "json_schema" ,
118
+ "json_schema" : {
119
+ "name" : "test_schema" ,
120
+ "schema" : {
121
+ "type" : "object" ,
122
+ "properties" : {
123
+ "test" : {"type" : "string" }
124
+ }
125
+ }
126
+ }
127
+ }
128
+
129
+ result = extract_schema_from_response_format (response_format )
130
+ self .assertIsNotNone (result )
131
+
132
+ # Verify it's valid JSON
133
+ schema = json .loads (result )
134
+ self .assertEqual (schema ["type" ], "object" )
135
+ self .assertIn ("test" , schema ["properties" ])
136
+
137
+ @patch ('optillm.plugins.json_plugin.JSONGenerator' )
138
+ def test_run_function_with_schema (self , mock_json_generator_class ):
139
+ """Test the main run function with a valid schema."""
140
+ # Mock JSONGenerator instance
141
+ mock_generator = Mock ()
142
+ mock_generator .generate_json .return_value = {"result" : "test" }
143
+ mock_generator .count_tokens .return_value = 10
144
+ mock_json_generator_class .return_value = mock_generator
145
+
146
+ # Mock client
147
+ mock_client = Mock ()
148
+
149
+ # Test configuration
150
+ request_config = {
151
+ "response_format" : {
152
+ "type" : "json_schema" ,
153
+ "json_schema" : {
154
+ "schema" : {
155
+ "type" : "object" ,
156
+ "properties" : {
157
+ "result" : {"type" : "string" }
158
+ }
159
+ }
160
+ }
161
+ }
162
+ }
163
+
164
+ # Run the plugin
165
+ result , tokens = run (
166
+ "System prompt" ,
167
+ "Generate a test result" ,
168
+ mock_client ,
169
+ "test-model" ,
170
+ request_config
171
+ )
172
+
173
+ # Verify results
174
+ self .assertIn ("result" , result )
175
+ self .assertEqual (tokens , 10 )
176
+ mock_generator .generate_json .assert_called_once ()
177
+
178
+ def test_run_function_without_schema (self ):
179
+ """Test the main run function without a schema (fallback)."""
180
+ # Mock client and response
181
+ mock_response = Mock ()
182
+ mock_response .choices = [Mock (message = Mock (content = "Regular response" ))]
183
+ mock_response .usage .completion_tokens = 5
184
+
185
+ mock_client = Mock ()
186
+ mock_client .chat .completions .create .return_value = mock_response
187
+
188
+ # Run without schema
189
+ result , tokens = run (
190
+ "System prompt" ,
191
+ "Test query" ,
192
+ mock_client ,
193
+ "test-model" ,
194
+ {}
195
+ )
196
+
197
+ # Verify fallback behavior
198
+ self .assertEqual (result , "Regular response" )
199
+ self .assertEqual (tokens , 5 )
200
+ mock_client .chat .completions .create .assert_called_once ()
201
+
202
+ @patch ('optillm.plugins.json_plugin.JSONGenerator' )
203
+ def test_error_handling (self , mock_json_generator_class ):
204
+ """Test error handling and fallback."""
205
+ # Mock generator that raises an error
206
+ mock_generator = Mock ()
207
+ mock_generator .generate_json .side_effect = Exception ("Test error" )
208
+ mock_json_generator_class .return_value = mock_generator
209
+
210
+ # Mock client for fallback
211
+ mock_response = Mock ()
212
+ mock_response .choices = [Mock (message = Mock (content = "Fallback response" ))]
213
+ mock_response .usage .completion_tokens = 8
214
+
215
+ mock_client = Mock ()
216
+ mock_client .chat .completions .create .return_value = mock_response
217
+
218
+ # Test configuration with schema
219
+ request_config = {
220
+ "response_format" : {
221
+ "type" : "json_schema" ,
222
+ "json_schema" : {
223
+ "schema" : {"type" : "object" }
224
+ }
225
+ }
226
+ }
227
+
228
+ # Run and expect fallback
229
+ result , tokens = run (
230
+ "System prompt" ,
231
+ "Test query" ,
232
+ mock_client ,
233
+ "test-model" ,
234
+ request_config
235
+ )
236
+
237
+ # Verify fallback was used
238
+ self .assertEqual (result , "Fallback response" )
239
+ self .assertEqual (tokens , 8 )
240
+ mock_client .chat .completions .create .assert_called_once ()
241
+
242
+
243
+ if __name__ == '__main__' :
244
+ unittest .main ()
0 commit comments