From 308066c66357fb088207d759896b3eae6f62ef4d Mon Sep 17 00:00:00 2001
From: Arina Rak
Date: Sun, 9 Jun 2024 16:05:54 +0200
Subject: [PATCH] Custom json grammars (#45)

* adding schema to ebnf conversion script + documentation

* minor format fixes
---
 README.md                                     |   4 +-
 .../grammars/custom_json_grammars/README.md   |  61 ++
 .../grammars/product_catalog.ebnf             |  11 +
 .../grammars/student.ebnf                     |  16 +
 .../json_schema_to_grammar.py                 | 792 ++++++++++++++++++
 .../schemas/product_catalog.json              |  25 +
 .../custom_json_grammars/schemas/student.json |  20 +
 7 files changed, 928 insertions(+), 1 deletion(-)
 create mode 100644 examples/grammars/custom_json_grammars/README.md
 create mode 100644 examples/grammars/custom_json_grammars/grammars/product_catalog.ebnf
 create mode 100644 examples/grammars/custom_json_grammars/grammars/student.ebnf
 create mode 100644 examples/grammars/custom_json_grammars/json_schema_to_grammar.py
 create mode 100644 examples/grammars/custom_json_grammars/schemas/product_catalog.json
 create mode 100644 examples/grammars/custom_json_grammars/schemas/student.json

diff --git a/README.md b/README.md
index 0312ce8..faabc14 100644
--- a/README.md
+++ b/README.md
@@ -142,7 +142,9 @@ More details can be found in this [doc from llama-cpp](https://github.com/ggerga
 Advanced grammar debugging guide can be found [here](docs/debugging_custom_grammars.md)
 
 ### Automatic Grammar Generation
-Here is an awesome tool to generate grammars for you: [Grammar Builder](https://grammar.intrinsiclabs.ai/)
+
+You can use custom grammars to constrain the output of a language model.
+Check out the [documentation](examples/grammars/custom_json_grammars/README.md) on JSON schema-to-grammar conversion to learn how to automatically create custom grammars for complex JSON objects.
 
 ### Grammar Collection
 
diff --git a/examples/grammars/custom_json_grammars/README.md b/examples/grammars/custom_json_grammars/README.md
new file mode 100644
index 0000000..69d8ded
--- /dev/null
+++ b/examples/grammars/custom_json_grammars/README.md
@@ -0,0 +1,61 @@
+# Custom JSON grammars
+
+You can use custom grammars to constrain the output of a language model so that it generates valid JSON objects. This is useful when you want to generate JSON objects for specific applications, such as HTTP requests or shopping carts.
+
+## Quickstart
+
+There are multiple ways to represent JSON schemas.
+We provide recommendations on how to do this for two common formats: TypeScript and JSON Schema.
+
+<details>
+<summary>
+ Example of a TypeScript schema for a Student object
+</summary>
+
+```Typescript
+interface Student {
+    name: string;
+    age: number;
+    is_student: boolean;
+    courses: string[];
+}
+```
+</details>
+ +
+ Example of a JSON schema for a Student object
+</summary>
+
+```json
+{
+  "type": "object",
+  "properties": {
+    "name": {"type": "string"},
+    "age": {"type": "number"},
+    "is_student": {"type": "boolean"},
+    "courses": {
+      "type": "array",
+      "items": {"type": "string"}
+    }
+  }
+}
+```
+</details>
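+
+Once you have a grammar for your schema (the sections below show two ways to generate one), you can plug it into constrained generation. The following is a minimal sketch, assuming the `IncrementalGrammarConstraint` and `GrammarConstrainedLogitsProcessor` classes and import paths shown in this repository's top-level README; `gpt2` is purely an illustrative model choice:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
+
+model_id = "gpt2"  # illustrative; any causal LM supported by transformers should work
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id)
+
+# load the EBNF grammar generated for the Student schema above
+with open("grammars/student.ebnf") as f:
+    grammar_str = f.read()
+grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
+
+# constrain decoding so only strings derivable from the grammar can be sampled
+prompt = "Give me a student record: "
+inputs = tokenizer(prompt, return_tensors="pt")
+outputs = model.generate(
+    **inputs,
+    max_new_tokens=60,
+    logits_processor=[GrammarConstrainedLogitsProcessor(grammar)],
+)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+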
+
+
+### From TypeScript
+
+To generate custom JSON grammars from TypeScript schemas, you can use [this online tool](https://grammar.intrinsiclabs.ai/) or [this TypeScript generator](https://github.com/IntrinsicLabsAI/gbnfgen) from Intrinsic AI. Then simply copy and paste the resulting grammar into a text file and use it with the `IncrementalGrammarConstraint`, as shown in the Quickstart above.
+
+
+### From JSON schemas
+
+Alternatively, you can generate custom JSON grammars from JSON Schema files using the `json_schema_to_grammar.py` script, analogous to [the one in the llama.cpp repository](https://github.com/ggerganov/llama.cpp/blob/ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa/examples/json_schema_to_grammar.py).
+
+
+To generate a grammar from a JSON schema, run the following command:
+
+```bash
+python3 json_schema_to_grammar.py -i schemas/product_catalog.json -o grammars/product_catalog.ebnf
+# optional: fix the property order, e.g. --prop-order productId,productName,price
+```
+This script generates a grammar from a JSON schema file (see example JSON schemas in `/schemas` and the corresponding grammars in `/grammars`). The generated grammar is in Extended Backus-Naur Form (EBNF) and can be used directly with the `IncrementalGrammarConstraint`.
+
+Additional arguments let you specify the property order of the JSON object as well as string formatting parameters; run `python3 json_schema_to_grammar.py --help` for details.
+
diff --git a/examples/grammars/custom_json_grammars/grammars/product_catalog.ebnf b/examples/grammars/custom_json_grammars/grammars/product_catalog.ebnf
new file mode 100644
index 0000000..787f27a
--- /dev/null
+++ b/examples/grammars/custom_json_grammars/grammars/product_catalog.ebnf
@@ -0,0 +1,11 @@
+char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
+decimal-part ::= [0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
+integer ::= ("-"? integral-part) space
+integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
+number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+price-kv ::= "\"price\"" space ":" space number
+productId-kv ::= "\"productId\"" space ":" space integer
+productName-kv ::= "\"productName\"" space ":" space string
+root ::= "{" space productId-kv "," space productName-kv "," space price-kv "}" space
+space ::= " "?
+string ::= "\"" char* "\"" space
diff --git a/examples/grammars/custom_json_grammars/grammars/student.ebnf b/examples/grammars/custom_json_grammars/grammars/student.ebnf
new file mode 100644
index 0000000..136493b
--- /dev/null
+++ b/examples/grammars/custom_json_grammars/grammars/student.ebnf
@@ -0,0 +1,16 @@
+age-kv ::= "\"age\"" space ":" space number
+age-rest ::= ( "," space courses-kv )?
+boolean ::= ("true" | "false") space
+char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
+courses ::= "[" space (string ("," space string)*)? "]" space
+courses-kv ::= "\"courses\"" space ":" space courses
+decimal-part ::= [0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
+integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
+is-student-kv ::= "\"is_student\"" space ":" space boolean
+is-student-rest ::= ( "," space name-kv )? 
name-rest +name-kv ::= "\"name\"" space ":" space string +name-rest ::= ( "," space age-kv )? age-rest +number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space +root ::= "{" space (is-student-kv is-student-rest | name-kv name-rest | age-kv age-rest | courses-kv )? "}" space +space ::= " "? +string ::= "\"" char* "\"" space diff --git a/examples/grammars/custom_json_grammars/json_schema_to_grammar.py b/examples/grammars/custom_json_grammars/json_schema_to_grammar.py new file mode 100644 index 0000000..a2562c9 --- /dev/null +++ b/examples/grammars/custom_json_grammars/json_schema_to_grammar.py @@ -0,0 +1,792 @@ +#!/usr/bin/env python3 +import argparse +import itertools +import json +import re +import requests +from typing import Any, List, Set, Tuple, Union + + +# adapted from https://github.com/ggerganov/llama.cpp/blob/ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa/examples/json_schema_to_grammar.py + + +def _build_repetition( + item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False +): + if not separator_rule: + if min_items == 0 and max_items == 1: + return f"{item_rule}?" + elif min_items == 1 and max_items is None: + return f"{item_rule}+" + + result = "" + + if min_items > 0: + if item_rule_is_literal and separator_rule is None: + result = '"' + (item_rule[1:-1] * min_items) + '"' + else: + result = (f" {separator_rule} " if separator_rule else " ").join( + [item_rule] * min_items + ) + + def opt_repetitions(up_to_n, prefix_with_sep=False): + """ + - n=4, no sep: '(a (a (a (a)?)?)?)?' + - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' + - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' + """ + + content = ( + f"{separator_rule} {item_rule}" + if prefix_with_sep and separator_rule + else item_rule + ) + if up_to_n == 0: + return "" + elif up_to_n == 1: + return f"({content})?" + elif separator_rule and not prefix_with_sep: + return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?" + else: + return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n) + + if min_items > 0 and max_items != min_items: + result += " " + + if max_items is not None: + result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) + else: + item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' + + if min_items == 0 and separator_rule: + result = f"({item_rule} {item_operator}*)?" + else: + result += f"{item_operator}*" + + return result + + +class BuiltinRule: + def __init__(self, content: str, deps: list = None): + self.content = content + self.deps = deps or [] + + +_up_to_15_digits = _build_repetition("[0-9]", 0, 15) + +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + +PRIMITIVE_RULES = { + "boolean": BuiltinRule('("true" | "false") space', []), + "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []), + "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []), + "number": BuiltinRule( + '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', + ["integral-part", "decimal-part"], + ), + "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]), + "value": BuiltinRule( + "object | array | string | number | boolean | null", + ["object", "array", "string", "number", "boolean", "null"], + ), + "object": BuiltinRule( + '"{" space ( string ":" space value ("," space string ":" space value)* )? 
"}" space', + ["string", "value"], + ), + "array": BuiltinRule( + '"[" space ( value ("," space value)* )? "]" space', ["value"] + ), + "uuid": BuiltinRule( + r'"\"" ' + + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12]) + + r' "\"" space', + [], + ), + "char": BuiltinRule( + r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', + [], + ), + "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]), + "null": BuiltinRule('"null" space', []), +} + +# TODO: support "uri", "email" string formats +STRING_FORMAT_RULES = { + "date": BuiltinRule( + '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )', + [], + ), + "time": BuiltinRule( + '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', + [], + ), + "date-time": BuiltinRule('date "T" time', ["date", "time"]), + "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]), + "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]), + "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]), +} + +DOTALL = "[\\U00000000-\\U0010FFFF]" +DOT = "[^\\x0A\\x0D]" + +RESERVED_NAMES = set( + ["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()] +) + +INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') +GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"', "-": "\\-", "]": "\\]"} + +NON_LITERAL_SET = set("|.()[]{}*+?") +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?") + + +class SchemaConverter: + def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): + self._prop_order = prop_order + self._allow_fetch = allow_fetch + self._dotall = dotall + self._raw_pattern = raw_pattern + self._rules = { + "space": SPACE_RULE, + } + self._refs = {} + self._refs_being_resolved = set() + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal + ) + return f'"{escaped}"' + + def not_literal( + self, literal: str, dotall: bool = True, maybe_escaped_underscores=False + ) -> str: + """ + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + """ + assert len(literal) > 0, "Empty literal not supported" + + def recurse(i: int): + c = literal[i] + if maybe_escaped_underscores and c == "_": + yield f"[^{c}\\\\]" + yield " | " + yield f'"\\\\"? "{c}"' + else: + yield f"[^{c}]" + if i < len(literal) - 1: + yield " | " + yield self._format_literal(c) + yield " (" + yield from recurse(i + 1) + yield ")?" + + return "".join(("(", *recurse(0), ")")) + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while ( + f"{esc_name}{i}" in self._rules + and self._rules[f"{esc_name}{i}"] != rule + ): + i += 1 + key = f"{esc_name}{i}" + self._rules[key] = rule + return key + + def resolve_refs(self, schema: dict, url: str): + """ + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. 
+        """
+
+        def visit(n: dict):
+            if isinstance(n, list):
+                return [visit(x) for x in n]
+            elif isinstance(n, dict):
+                ref = n.get("$ref")
+                if ref is not None and ref not in self._refs:
+                    if ref.startswith("https://"):
+                        assert (
+                            self._allow_fetch
+                        ), "Fetching remote schemas is not allowed (use --allow-fetch to force)"
+                        import requests
+
+                        frag_split = ref.split("#")
+                        base_url = frag_split[0]
+
+                        target = self._refs.get(base_url)
+                        if target is None:
+                            target = self.resolve_refs(
+                                requests.get(ref).json(), base_url
+                            )
+                            self._refs[base_url] = target
+
+                        if len(frag_split) == 1 or frag_split[-1] == "":
+                            return target
+                    elif ref.startswith("#/"):
+                        target = schema
+                        ref = f"{url}{ref}"
+                        n["$ref"] = ref
+                    else:
+                        raise ValueError(f"Unsupported ref {ref}")
+
+                    for sel in ref.split("#")[-1].split("/")[1:]:
+                        assert (
+                            target is not None and sel in target
+                        ), f"Error resolving ref {ref}: {sel} not in {target}"
+                        target = target[sel]
+
+                    self._refs[ref] = target
+                else:
+                    for v in n.values():
+                        visit(v)
+
+            return n
+
+        return visit(schema)
+
+    def _generate_union_rule(self, name, alt_schemas):
+        return " | ".join(
+            (
+                self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
+                for i, alt_schema in enumerate(alt_schemas)
+            )
+        )
+
+    def _visit_pattern(self, pattern, name):
+        """
+        Transforms a regular expression pattern into a GBNF rule.
+
+        Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
+        Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
+
+        Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
+
+        Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
+        we define sub-rules to keep the output lean.
+        """
+
+        assert pattern.startswith("^") and pattern.endswith(
+            "$"
+        ), 'Pattern must start with "^" and end with "$"'
+        pattern = pattern[1:-1]
+        sub_rule_ids = {}
+
+        i = 0
+        length = len(pattern)
+
+        def to_rule(s: Tuple[str, bool]) -> str:
+            (txt, is_literal) = s
+            return '"' + txt + '"' if is_literal else txt
+
+        def transform() -> Tuple[str, bool]:
+            """
+            Parse a unit at index i (advancing it), and return its string representation
+            and whether it's a literal.
+            """
+            nonlocal i
+            nonlocal pattern
+            nonlocal sub_rule_ids
+
+            start = i
+            # For each component of this sequence, store its string representation and whether it's a literal.
+            # We only need a flat structure here to apply repetition operators to the last item, and
+            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
+            # (GBNF's syntax is luckily very close to regular expressions!)
+            seq: list[Tuple[str, bool]] = []
+
+            def get_dot():
+                if self._dotall:
+                    rule = DOTALL
+                else:
+                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
+                    rule = DOT
+                return self._add_rule("dot", rule)
+
+            def join_seq():
+                nonlocal seq
+                ret = []
+                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
+                    if is_literal:
+                        ret.append(("".join(x[0] for x in g), True))
+                    else:
+                        ret.extend(g)
+                if len(ret) == 1:
+                    return ret[0]
+                return (" ".join(to_rule(x) for x in ret), False)
+
+            while i < length:
+                c = pattern[i]
+                if c == ".":
+                    seq.append((get_dot(), False))
+                    i += 1
+                elif c == "(":
+                    i += 1
+                    if i < length:
+                        assert (
+                            pattern[i] != "?" 
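+                            # "(?..." constructs (lookaheads, non-capturing groups) are
+                            # unsupported, as noted in the docstring above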
+ ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' + seq.append((f"({to_rule(transform())})", False)) + elif c == ")": + i += 1 + assert ( + start > 0 and pattern[start - 1] == "(" + ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}" + return join_seq() + elif c == "[": + square_brackets = c + i += 1 + while i < length and pattern[i] != "]": + if pattern[i] == "\\": + square_brackets += pattern[i : i + 2] + i += 2 + else: + square_brackets += pattern[i] + i += 1 + assert ( + i < length + ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}" + square_brackets += "]" + i += 1 + seq.append((square_brackets, False)) + elif c == "|": + seq.append(("|", False)) + i += 1 + elif c in ("*", "+", "?"): + seq[-1] = (to_rule(seq[-1]) + c, False) + i += 1 + elif c == "{": + curly_brackets = c + i += 1 + while i < length and pattern[i] != "}": + curly_brackets += pattern[i] + i += 1 + assert ( + i < length + ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}" + curly_brackets += "}" + i += 1 + nums = [s.strip() for s in curly_brackets[1:-1].split(",")] + min_times = 0 + max_times = None + try: + if len(nums) == 1: + min_times = int(nums[0]) + max_times = min_times + else: + assert len(nums) == 2 + min_times = int(nums[0]) if nums[0] else 0 + max_times = int(nums[1]) if nums[1] else None + except ValueError: + raise ValueError( + f"Invalid quantifier {curly_brackets} in /{pattern}/" + ) + + (sub, sub_is_literal) = seq[-1] + + if not sub_is_literal: + id = sub_rule_ids.get(sub) + if id is None: + id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub) + sub_rule_ids[sub] = id + sub = id + + seq[-1] = ( + _build_repetition( + f'"{sub}"' if sub_is_literal else sub, + min_times, + max_times, + item_rule_is_literal=sub_is_literal, + ), + False, + ) + else: + literal = "" + while i < length: + if pattern[i] == "\\" and i < length - 1: + next = pattern[i + 1] + if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: + i += 1 + literal += pattern[i] + i += 1 + else: + literal += pattern[i : i + 2] + i += 2 + elif pattern[i] == '"' and not self._raw_pattern: + literal += '\\"' + i += 1 + elif pattern[i] not in NON_LITERAL_SET and ( + i == length - 1 + or literal == "" + or pattern[i + 1] == "." 
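+                            # don't absorb a character that is immediately followed by a
+                            # repetition operator: the operator must apply to it alone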
+ or pattern[i + 1] not in NON_LITERAL_SET + ): + literal += pattern[i] + i += 1 + else: + break + if literal: + seq.append((literal, True)) + + return join_seq() + + return self._add_rule( + name, + ( + to_rule(transform()) + if self._raw_pattern + else '"\\"" ' + to_rule(transform()) + ' "\\"" space' + ), + ) + + def _resolve_ref(self, ref): + ref_name = ref.split("/")[-1] + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + + def _generate_constant_rule(self, value): + return self._format_literal(json.dumps(value)) + + def visit(self, schema, name): + schema_type = schema.get("type") + schema_format = schema.get("format") + rule_name = name + "-" if name in RESERVED_NAMES else name or "root" + + ref = schema.get("$ref") + if ref is not None: + return self._add_rule(rule_name, self._resolve_ref(ref)) + + elif "oneOf" in schema or "anyOf" in schema: + return self._add_rule( + rule_name, + self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]), + ) + + elif isinstance(schema_type, list): + return self._add_rule( + rule_name, + self._generate_union_rule(name, [{"type": t} for t in schema_type]), + ) + + elif "const" in schema: + return self._add_rule( + rule_name, self._generate_constant_rule(schema["const"]) + ) + + elif "enum" in schema: + rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"])) + return self._add_rule(rule_name, rule) + + elif schema_type in (None, "object") and ( + "properties" in schema + or ( + "additionalProperties" in schema + and schema["additionalProperties"] is not True + ) + ): + required = set(schema.get("required", [])) + properties = list(schema.get("properties", {}).items()) + return self._add_rule( + rule_name, + self._build_object_rule( + properties, required, name, schema.get("additionalProperties") + ), + ) + + elif schema_type in (None, "object") and "allOf" in schema: + required = set() + properties = [] + hybrid_name = name + + def add_component(comp_schema, is_required): + ref = comp_schema.get("$ref") + if ref is not None: + comp_schema = self._refs[ref] + + if "properties" in comp_schema: + for prop_name, prop_schema in comp_schema["properties"].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + + for t in schema["allOf"]: + if "anyOf" in t: + for tt in t["anyOf"]: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + return self._add_rule( + rule_name, + self._build_object_rule( + properties, required, hybrid_name, additional_properties=[] + ), + ) + + elif schema_type in (None, "array") and ( + "items" in schema or "prefixItems" in schema + ): + items = schema.get("items") or schema["prefixItems"] + if isinstance(items, list): + return self._add_rule( + rule_name, + '"[" space ' + + ' "," space '.join( + self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') + for i, item in enumerate(items) + ) + + ' "]" space', + ) + else: + item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') + min_items = schema.get("minItems", 0) + max_items = schema.get("maxItems") + return self._add_rule( + rule_name, + '"[" space ' + + _build_repetition( + item_rule_name, min_items, max_items, separator_rule='"," space' + ) + + ' "]" space', + ) + + elif schema_type in (None, "string") and "pattern" in schema: + return 
self._visit_pattern(schema["pattern"], rule_name) + + elif schema_type in (None, "string") and re.match( + r"^uuid[1-5]?$", schema_format or "" + ): + return self._add_primitive( + "root" if rule_name == "root" else schema_format, + PRIMITIVE_RULES["uuid"], + ) + + elif ( + schema_type in (None, "string") + and f"{schema_format}-string" in STRING_FORMAT_RULES + ): + prim_name = f"{schema_format}-string" + return self._add_rule( + rule_name, + self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]), + ) + + elif schema_type == "string" and ( + "minLength" in schema or "maxLength" in schema + ): + char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"]) + min_len = schema.get("minLength", 0) + max_len = schema.get("maxLength") + + return self._add_rule( + rule_name, + r'"\"" ' + + _build_repetition(char_rule, min_len, max_len) + + r' "\"" space', + ) + + elif (schema_type == "object") or (len(schema) == 0): + return self._add_rule( + rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"]) + ) + + else: + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" + # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return self._add_primitive( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], + ) + + def _add_primitive(self, name: str, rule: BuiltinRule): + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f"Rule {dep} not known" + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_object_rule( + self, + properties: List[Tuple[str, Any]], + required: Set[str], + name: str, + additional_properties: Union[bool, Any], + ): + prop_order = self._prop_order + # sort by position in prop_order (if specified) then by original order + sorted_props = [ + kv[0] + for _, kv in sorted( + enumerate(properties), + key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]), + ) + ] + + prop_kv_rule_names = {} + for prop_name, prop_schema in properties: + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) + prop_kv_rule_names[prop_name] = self._add_rule( + f'{name}{"-" if name else ""}{prop_name}-kv', + rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}', + ) + required_props = [k for k in sorted_props if k in required] + optional_props = [k for k in sorted_props if k not in required] + + if additional_properties == True or isinstance(additional_properties, dict): + sub_name = f'{name}{"-" if name else ""}additional' + value_rule = self.visit( + {} if additional_properties == True else additional_properties, + f"{sub_name}-value", + ) + prop_kv_rule_names["*"] = self._add_rule( + f"{sub_name}-kv", + self._add_primitive("string", PRIMITIVE_RULES["string"]) + + f' ":" space {value_rule}', + ) + optional_props.append("*") + + rule = '"{" space ' + rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) + + if optional_props: + rule += " (" + if required_props: + rule += ' "," space ( ' + + def get_recursive_refs(ks, first_is_optional): + [k, *rest] = ks + kv_rule_name = prop_kv_rule_names[k] + if k == "*": + res = self._add_rule( + f'{name}{"-" if name else ""}additional-kvs', + f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*", + ) + elif first_is_optional: + res = f'( "," space {kv_rule_name} )?' 
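+                # each subsequent optional property is wrapped as ( "," space kv )?,
+                # chaining into the *-rest rules visible in the example grammars above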
+ else: + res = kv_rule_name + if len(rest) > 0: + res += " " + self._add_rule( + f'{name}{"-" if name else ""}{k}-rest', + get_recursive_refs(rest, first_is_optional=True), + ) + return res + + rule += " | ".join( + get_recursive_refs(optional_props[i:], first_is_optional=False) + for i in range(len(optional_props)) + ) + if required_props: + rule += " )" + rule += " )?" + + rule += ' "}" space' + + return rule + + def format_grammar(self): + return "\n".join( + f"{name} ::= {rule}" + for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) + ) + + +def parse_args(): + parser = argparse.ArgumentParser( + description=""" + Generates an EBNF grammar that produces JSON conforming to a given JSON schema. + Only a subset of JSON schema features are supported; more may be added in the future. + """, + ) + parser.add_argument( + "-i", + "--input-schema-path", + help="Path or link (if fetching is allowed) to a JSON schema file", + type=str, + required=True, + ) + parser.add_argument( + "-o", + "--output-grammar-path", + help="Path to the corresponding EBNF grammar file", + type=str, + required=True, + ) + parser.add_argument( + "--prop-order", + default=[], + type=lambda s: s.split(","), + help=""" + Comma-separated property names defining the order of precedence for object properties; + properties not specified here are given lower precedence than those that are, and + are kept in their original order from the schema. Required properties are always + given precedence over optional properties. + """, + ) + parser.add_argument( + "--allow-fetch", + action="store_true", + default=False, + help="Whether to allow fetching referenced schemas over HTTPS", + ) + parser.add_argument( + "--dotall", + action="store_true", + default=False, + help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns', + ) + parser.add_argument( + "--raw-pattern", + action="store_true", + default=False, + help="Treats string patterns as raw patterns w/o quotes (or quote escapes)", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.input_schema_path.startswith("https://"): + url = args.input_schema_path + schema = requests.get(url).json() + else: + url = f"file://{args.input_schema_path}" + with open(args.input_schema_path) as f: + schema = json.load(f) + + converter = SchemaConverter( + prop_order={name: idx for idx, name in enumerate(args.prop_order)}, + allow_fetch=args.allow_fetch, + dotall=args.dotall, + raw_pattern=args.raw_pattern, + ) + schema = converter.resolve_refs(schema, url) + + converter.visit(schema, "") + with open(args.output_grammar_path, "w") as f: + f.write(converter.format_grammar()) + + +if __name__ == "__main__": + main() diff --git a/examples/grammars/custom_json_grammars/schemas/product_catalog.json b/examples/grammars/custom_json_grammars/schemas/product_catalog.json new file mode 100644 index 0000000..8ec9bb9 --- /dev/null +++ b/examples/grammars/custom_json_grammars/schemas/product_catalog.json @@ -0,0 +1,25 @@ +{ + "title": "Product", + "description": "A product from my catalog", + "type": "object", + "properties": { + "productId": { + "description": "The unique identifier for a product", + "type": "integer" + }, + "productName": { + "description": "Name of the product", + "type": "string" + }, + "price": { + "description": "The price of the product", + "type": "number", + "exclusiveMinimum": 0 + } + }, + "required": [ + "productId", + "productName", + "price" + ] +} \ No newline at end of 
file diff --git a/examples/grammars/custom_json_grammars/schemas/student.json b/examples/grammars/custom_json_grammars/schemas/student.json new file mode 100644 index 0000000..9fed468 --- /dev/null +++ b/examples/grammars/custom_json_grammars/schemas/student.json @@ -0,0 +1,20 @@ +{ + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "number" + }, + "is_student": { + "type": "boolean" + }, + "courses": { + "type": "array", + "items": { + "type": "string" + } + } + } +} \ No newline at end of file