| 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | # /// script |
| 4 | # requires-python = ">=3.10" |
| 5 | # dependencies = [ |
| 6 | # "data-designer", |
| 7 | # ] |
| 8 | # /// |
| 9 | """Text-to-SQL Code Generation Recipe |
| 10 | |
| 11 | Generate synthetic instruction-SQL pairs for database tasks across different |
| 12 | industries, complexity levels, and SQL concepts. Each record includes an |
| 13 | instruction, database context (CREATE TABLE + sample data), generated SQL query, |
| 14 | code validation, and judge evaluation. |
| 15 | |
| 16 | Prerequisites: |
| 17 | - OPENAI_API_KEY environment variable for OpenAI provider model aliases (default model alias is "openai-text"). |
| 18 | - NVIDIA_API_KEY environment variable for NVIDIA provider model aliases. |
| 19 | |
| 20 | Run: |
| 21 | # Basic usage (generates 5 records by default) |
| 22 | uv run text_to_sql.py |
| 23 | |
| 24 | # For help message and available options |
| 25 | uv run text_to_sql.py --help |
| 26 | """ |
| 27 | |
| 28 | from pathlib import Path |
| 29 | |
| 30 | import data_designer.config as dd |
| 31 | from data_designer.interface import DataDesigner, DatasetCreationResults |
| 32 | |
| 33 | |
def build_config(model_alias: str) -> dd.DataDesignerConfigBuilder:
    """Build the Data Designer configuration for the text-to-SQL recipe.

    Columns are added in this order:

    1. Samplers: ``industry_sector`` -> ``topic``, ``sql_complexity`` ->
       ``sql_concept``, then ``sql_task_type`` and ``instruction_phrase``.
    2. LLM columns: ``sql_prompt`` (text), then ``sql_context`` and ``sql``
       (ANSI SQL code). Their prompt templates reference the columns above.
    3. Quality columns: ``code_validity_result`` (code validator over ``sql``)
       and ``code_judge_result`` (LLM judge scored with ``sql_scoring``).

    Args:
        model_alias: Model alias shared by every LLM-backed column
            (``sql_prompt``, ``sql_context``, ``sql`` and the judge).

    Returns:
        A ``dd.DataDesignerConfigBuilder`` with all of the columns above added.
    """

    def category(name, values):
        # Categorical sampler drawing from a flat list of values.
        return dd.SamplerColumnConfig(
            name=name,
            sampler_type=dd.SamplerType.CATEGORY,
            params=dd.CategorySamplerParams(values=values),
        )

    def subcategory(name, parent, values_by_parent):
        # Subcategory sampler; ``values_by_parent`` is keyed by the categories
        # of the ``parent`` column.
        return dd.SamplerColumnConfig(
            name=name,
            sampler_type=dd.SamplerType.SUBCATEGORY,
            params=dd.SubcategorySamplerParams(category=parent, values=values_by_parent),
        )

    sql_dialect = dd.CodeLang.SQL_ANSI

    columns = [
        category("industry_sector", ["Healthcare", "Finance", "Technology"]),
        subcategory(
            "topic",
            "industry_sector",
            {
                "Healthcare": [
                    "Electronic Health Records (EHR) Systems",
                    "Telemedicine Platforms",
                    "AI-Powered Diagnostic Tools",
                ],
                "Finance": [
                    "Fraud Detection Software",
                    "Automated Trading Systems",
                    "Personal Finance Apps",
                ],
                "Technology": [
                    "Cloud Computing Platforms",
                    "Artificial Intelligence and Machine Learning Platforms",
                    "DevOps and CI/CD Tools",
                ],
            },
        ),
        category("sql_complexity", ["Beginner", "Intermediate", "Advanced"]),
        subcategory(
            "sql_concept",
            "sql_complexity",
            {
                "Beginner": [
                    "Basic SELECT Statements",
                    "WHERE Clauses",
                    "Basic JOINs",
                    "INSERT, UPDATE, DELETE",
                ],
                "Intermediate": [
                    "Aggregation Functions",
                    "Multiple JOINs",
                    "Subqueries",
                    "Views",
                ],
                "Advanced": [
                    "Window Functions",
                    "Common Table Expressions (CTEs)",
                    "Stored Procedures",
                    "Query Optimization",
                ],
            },
        ),
        category(
            "sql_task_type",
            [
                "Data Retrieval",
                "Data Manipulation",
                "Analytics and Reporting",
                "Data Transformation",
            ],
        ),
        category(
            "instruction_phrase",
            [
                "Write an SQL query that",
                "Create an SQL statement to",
                "Develop an SQL query to",
                "Can you write SQL that",
                "Formulate an SQL query that",
            ],
        ),
        dd.LLMTextColumnConfig(
            name="sql_prompt",
            model_alias=model_alias,
            system_prompt="You are an expert at generating clear and specific SQL tasks.",
            prompt=SQL_PROMPT_TEXT,
        ),
        dd.LLMCodeColumnConfig(
            name="sql_context",
            model_alias=model_alias,
            code_lang=sql_dialect,
            system_prompt=(
                "You are an expert SQL database designer who creates clean, efficient, and "
                "well-structured database schemas."
            ),
            prompt=SQL_CONTEXT_TEXT,
        ),
        dd.LLMCodeColumnConfig(
            name="sql",
            model_alias=model_alias,
            code_lang=sql_dialect,
            system_prompt="You are an expert SQL programmer who writes clean, efficient, and well-structured queries.",
            prompt=SQL_CODE_TEXT,
        ),
        # Validates the generated ``sql`` column with the same dialect it was generated in.
        dd.ValidationColumnConfig(
            name="code_validity_result",
            validator_type=dd.ValidatorType.CODE,
            target_columns=["sql"],
            validator_params=dd.CodeValidatorParams(code_lang=sql_dialect),
            batch_size=100,
        ),
        dd.LLMJudgeColumnConfig(
            name="code_judge_result",
            model_alias=model_alias,
            prompt=TEXT_TO_SQL_JUDGE_TEMPLATE,
            scores=sql_scoring,
        ),
    ]

    builder = dd.DataDesignerConfigBuilder()
    for column in columns:
        builder.add_column(column)
    return builder
| 199 | |
| 200 | |
def create_dataset(
    config_builder: dd.DataDesignerConfigBuilder,
    num_records: int,
    artifact_path: Path | str | None = None,
) -> DatasetCreationResults:
    """Generate a dataset from ``config_builder`` with Data Designer.

    Args:
        config_builder: Fully configured builder, for example the one returned
            by ``build_config``.
        num_records: Number of records requested from ``DataDesigner.create``.
        artifact_path: Passed through unchanged (including ``None``) to the
            ``DataDesigner`` constructor.

    Returns:
        The ``DatasetCreationResults`` produced by ``DataDesigner.create``.
    """
    designer = DataDesigner(artifact_path=artifact_path)
    return designer.create(config_builder, num_records=num_records)
| 209 | |
| 210 | |
# Prompt for the "sql_prompt" column. The {{...}} placeholders name columns
# sampled in build_config. "instruction_phrase" holds ONE sampled phrase, so the
# prompt asks the model to open with that exact phrase rather than choose among
# several.
SQL_PROMPT_TEXT = (
    "Generate an instruction to create SQL code that solves a specific problem.\n"
    "Each instruction should begin with the following phrase: {{instruction_phrase}}.\n\n"
    "Important Guidelines:\n"
    "* Industry Relevance: Ensure the instruction pertains to the {{industry_sector}} sector and {{topic}} topic.\n"
    "* SQL Complexity: Tailor the instruction to the {{sql_complexity}} level. Utilize relevant {{sql_concept}} "
    "where appropriate to match the complexity level.\n"
    "* Task Type: The instruction should involve a {{sql_task_type}} task.\n"
    "* Clarity and Specificity: Make the problem statement clear and unambiguous. Provide sufficient context to "
    "understand the requirements without being overly verbose.\n"
    "* Response Formatting: Do not include any markers such as ### Response ### in the instruction.\n"
)
| 223 | |
# Prompt for the "sql_context" column: asks for CREATE TABLE statements plus
# sample INSERT rows that fit the instruction in the "sql_prompt" column.
# Lines are joined with "\n"; note there is no trailing newline.
SQL_CONTEXT_TEXT = "\n".join(
    [
        "Generate the SQL for creating database tables that would be relevant for the following instruction:",
        "Instruction: {{sql_prompt}}",
        "",
        "Important Guidelines:",
        "* Relevance: Ensure all tables are directly related to the {{industry_sector}} sector and {{topic}} topic.",
        "* Completeness: Include all essential columns with appropriate data types, primary/foreign keys, and necessary constraints.",
        "* Realism: Use realistic table structures typical for the specified industry.",
        "* Executable SQL: Provide complete CREATE TABLE statements that can be run without modification.",
        "* Consistency: Use consistent naming conventions (e.g., snake_case for table and column names).",
        "* Sample Data: Include INSERT statements with sample data that makes sense for the tables (at least 5-10 rows per table).",
    ]
)
| 235 | |
# Prompt for the "sql" column: asks for a query answering "sql_prompt" against
# the schema and sample data generated in "sql_context". The final empty element
# gives the prompt a trailing newline.
SQL_CODE_TEXT = "\n".join(
    [
        "Write SQL code for the following instruction based on the provided database context:",
        "Instruction: {{sql_prompt}}",
        "",
        "Database Context:",
        "{{sql_context}}",
        "",
        "Important Guidelines:",
        "* Code Quality: Your SQL should be clean, complete, self-contained and accurate.",
        "* Code Validity: Please ensure that your SQL code is executable and does not contain any errors.",
        "* Context: Base your query on the provided database context. Only reference tables and columns that exist in the context.",
        "* Complexity & Concepts: The SQL should be written at a {{sql_complexity}} level, making use of concepts such as {{sql_concept}}.",
        "* Task Type: Ensure your solution implements the appropriate {{sql_task_type}} operation.",
        "* Comments: Include brief comments explaining the key parts of your query.",
        "",
    ]
)
| 251 | |
| 252 | |
# Prompt for the "code_judge_result" LLM judge. It uses {{ ... }} placeholders
# for the instruction, the generated schema context, and the generated SQL.
# NOTE(review): the text refers to a rubric "below", but no rubric is written
# here. Presumably Data Designer supplies it from the `scores` passed to the
# judge column; confirm against the library.
TEXT_TO_SQL_JUDGE_TEMPLATE = (
    "You are an expert in SQL with deep knowledge of relational modeling, query semantics,\n"
    "and performance tuning across common dialects (e.g., PostgreSQL, MySQL, SQLite, SQL Server).\n"
    "You think critically about correctness, readability, and efficiency.\n"
    "\n"
    "Use the SQL Query Quality Rubric below to score the **Generated SQL Query** based on the INSTRUCTIONS.\n"
    "\n"
    "#### INSTRUCTIONS\n"
    "The Generated SQL Query should be a valid response to the Natural Language Prompt below\n"
    "\n"
    "Natural Language Prompt:\n"
    "{{ sql_prompt }}\n"
    "\n"
    "Database Context:\n"
    "{{ sql_context }}\n"
    "\n"
    "Generated SQL Query\n"
    "{{ sql }}\n"
)
| 272 | |
| 273 | |
# Rubric passed to the LLM judge column ("code_judge_result") in build_config.
# Each dd.Score maps integer levels 0-4 to a description of that level; 4 is the
# best outcome and 0 the worst.
SQL_SCORING = [
    dd.Score(
        name="Relevance",
        description="Adherence to INSTRUCTIONS and CONTEXT",
        options={
            4: "Perfectly meets all specified requirements.",
            3: "Meets most requirements with minor deviations.",
            2: "Moderate deviation from the instructions.",
            1: "Significant deviations from the instructions.",
            0: "Does not adhere to the instructions.",
        },
    ),
    dd.Score(
        name="SQL Correctness",
        description="Syntax and semantic correctness; returns the intended result",
        options={
            4: "Valid SQL with correct joins, filters, grouping/aggregation, and NULL handling; produces the intended result set under the stated/implicit dialect.",
            3: "Generally correct with minor issues (e.g., edge-case NULLs, minor grouping detail) but still likely yields the intended result.",
            2: "Partially correct; noticeable semantic mistakes (joins, grouping, filters) that may change results or fail in edge cases.",
            1: "Largely incorrect; major semantic or syntactic errors likely causing failure or wrong results.",
            0: "Invalid SQL or unrelated to the task; will not run or cannot produce a meaningful result.",
        },
    ),
    dd.Score(
        name="Readability",
        description="Formatting, clarity, and maintainability",
        options={
            4: "Cleanly formatted (keywords/clauses consistently styled), clear structure (CTEs/subqueries where helpful), meaningful table/column aliases, and concise.",
            3: "Generally readable with consistent formatting and understandable aliases; could be organized slightly better.",
            2: "Somewhat readable but inconsistent formatting or confusing aliasing; structure is harder to follow.",
            1: "Poorly formatted and hard to read; unclear structure and aliasing.",
            0: "Unreadable or chaotic; no meaningful structure or styling.",
        },
    ),
    dd.Score(
        name="Efficiency",
        description="Query performance best practices",
        options={
            4: "Uses sargable predicates, appropriate joins, selective filters early, avoids SELECT *, unnecessary DISTINCT, and wasteful subqueries; likely to use indexes effectively.",
            3: "Mostly efficient; minor opportunities for improvement (e.g., simplifying expressions, reducing data early).",
            2: "Moderate inefficiencies (e.g., non-sargable filters, unnecessary nested subqueries, broad SELECT *).",
            1: "Notably inefficient patterns likely causing large scans or poor plans.",
            0: "Highly inefficient; ignores basic best practices and likely to perform very poorly.",
        },
    ),
]

# Backward-compatible lowercase alias: build_config (and any external importer)
# refers to this name, so it must keep pointing at the same list.
sql_scoring = SQL_SCORING
| 320 | |
| 321 | if __name__ == "__main__": |
| 322 | from argparse import ArgumentParser |
| 323 | |
| 324 | parser = ArgumentParser() |
| 325 | parser.add_argument("--model-alias", type=str, default="openai-text") |
| 326 | parser.add_argument("--num-records", type=int, default=5) |
| 327 | parser.add_argument("--artifact-path", type=str, default=None) |
| 328 | args = parser.parse_args() |
| 329 | |
| 330 | config_builder = build_config(model_alias=args.model_alias) |
| 331 | results = create_dataset(config_builder, num_records=args.num_records, artifact_path=args.artifact_path) |
| 332 | |
| 333 | print(f"Dataset saved to: {results.artifact_storage.final_dataset_path}") |
| 334 | |
| 335 | results.load_analysis().to_report() |