Skip to content

Commit

Permalink
add accept.py in examples, which is used to check if a sentence is ac…
Browse files Browse the repository at this point in the history
…cepted by a grammar; fix `geo_query.ebnf` path in test_geo_query.py (#33)
  • Loading branch information
Saibo-creator authored Apr 13, 2024
1 parent 69c3795 commit 9511ec9
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 12 deletions.
54 changes: 54 additions & 0 deletions examples/accept.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
import logging

from transformers_cfg.parser import parse_ebnf
from transformers_cfg.recognizer import StringRecognizer

logging.basicConfig(level=logging.DEBUG)


def main(args):
    """Check whether a sentence (or sentence prefix) is accepted by a grammar.

    Reads the EBNF grammar file named on the command line, builds a string
    recognizer rooted at the grammar's ``root`` rule, and prints the boolean
    result of the acceptance check.
    """
    # Read the grammar definition from disk.
    with open(args.grammar_file_path, "r") as grammar_file:
        grammar_text = grammar_file.read()

    # Build a recognizer anchored at the "root" rule of the parsed grammar.
    parsed = parse_ebnf(grammar_text)
    root_id = parsed.symbol_table["root"]
    recognizer = StringRecognizer(parsed.grammar_encoding, root_id)

    # "prefix" mode accepts partial sentences; any other mode requires a
    # complete sentence to match the grammar.
    check = (
        recognizer._accept_prefix
        if args.mode == "prefix"
        else recognizer._accept_string
    )
    print(check(args.sentence))


if __name__ == "__main__":
    # Command-line entry point for the grammar acceptance checker.
    # FIX: the description and the --sentence help text were copy-pasted from
    # the generation example; they now describe what this script actually does.
    parser = argparse.ArgumentParser(
        description="Check whether a sentence is accepted by an EBNF grammar."
    )
    parser.add_argument(
        "-g",
        "--grammar_file_path",
        type=str,
        required=True,
        help="Path to the grammar file (supports both relative and absolute paths)",
    )
    parser.add_argument(
        "-s",
        "--sentence",
        type=str,
        required=True,
        help="Sentence (or sentence prefix) to check against the grammar",
    )
    parser.add_argument(
        "-m",
        "--mode",
        type=str,
        choices=["prefix", "sentence"],
        default="prefix",
        help="Mode of operation, "
        "prefix mode accepts a prefix string, sentence mode only accepts a full sentence",
    )

    args = parser.parse_args()
    main(args)
73 changes: 73 additions & 0 deletions examples/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
import logging

logging.basicConfig(level=logging.DEBUG)


def main(args):
    """Generate grammar-constrained text with a causal language model.

    Loads the tokenizer/model named by ``args.model_id``, wraps the EBNF
    grammar file in a grammar-constrained logits processor, and greedily
    decodes up to 20 new tokens from ``args.prefix_prompt``.
    """
    # Load model and tokenizer; reuse EOS as the pad token so padding works.
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(args.model_id)

    # Read the grammar and turn it into a logits processor.
    with open(args.grammar_file_path, "r") as grammar_file:
        grammar_text = grammar_file.read()
    constraint = IncrementalGrammarConstraint(
        grammar_text, "root", tokenizer, unicode=False
    )
    processor = GrammarConstrainedLogitsProcessor(constraint)

    # Tokenize the prompt (no special tokens; padded batch of one).
    encoded = tokenizer(
        args.prefix_prompt,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
    )
    input_ids = encoded["input_ids"]

    # Greedy decoding under the grammar constraint.
    output = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=20,
        logits_processor=[processor],
        repetition_penalty=1.1,
        num_return_sequences=1,
    )

    # Decode, dropping special tokens, and show the generations.
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)


if __name__ == "__main__":
    # Command-line entry point for grammar-constrained generation.
    parser = argparse.ArgumentParser(
        description="Generate text with grammar constraints."
    )
    # FIX: --model_id previously declared both required=True and
    # default="gpt2"; a required argument never uses its default, so the
    # default was dead. The flag is now optional and genuinely defaults
    # to "gpt2".
    parser.add_argument(
        "-m",
        "--model_id",
        type=str,
        default="gpt2",
        help="Model identifier for loading the tokenizer and model (default: gpt2)",
    )
    parser.add_argument(
        "-g",
        "--grammar_file_path",
        type=str,
        required=True,
        help="Path to the grammar file (supports both relative and absolute paths)",
    )
    parser.add_argument(
        "-p",
        "--prefix_prompt",
        type=str,
        required=True,
        help="Prefix prompt for generation",
    )

    args = parser.parse_args()
    main(args)
37 changes: 25 additions & 12 deletions tests/test_string_recognizer/test_geo_query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@

from unittest import TestCase

import logging
from transformers_cfg.parser import parse_ebnf
from transformers_cfg.recognizer import StringRecognizer
from dataclasses import dataclass


@dataclass
class GeoQueryTestCase:
name: str
Expand All @@ -18,10 +18,18 @@ class GeoQueryTestCase:
GeoQueryTestCase("river", "answer(river(all))"),
GeoQueryTestCase("state", "answer(loc_1(major(river(all))))"),
GeoQueryTestCase("next_to_2", "answer(state(next_to_2(stateid('texas'))))"),
GeoQueryTestCase("intersection", "answer(intersection(state(next_to_2(stateid('texas'))), loc_1(major(river(all)))))"),
GeoQueryTestCase(
"intersection",
"answer(intersection(state(next_to_2(stateid('texas'))), loc_1(major(river(all)))))",
),
GeoQueryTestCase("space in name", "answer(population_1(stateid('new york')))"),
GeoQueryTestCase("exclude", "answer(count(exclude(river(all), traverse_2(state(loc_1(capital(cityid('albany', _))))))))"),
GeoQueryTestCase("city_id_with_state", "answer(population_1(cityid('washington', 'dc')))")
GeoQueryTestCase(
"exclude",
"answer(count(exclude(river(all), traverse_2(state(loc_1(capital(cityid('albany', _))))))))",
),
GeoQueryTestCase(
"city_id_with_state", "answer(population_1(cityid('washington', 'dc')))"
),
]

valid_geo_query_prefixes = [
Expand All @@ -35,19 +43,22 @@ class GeoQueryTestCase:
GeoQueryTestCase("unexisting_function", "answer(population_2(stateid('hawaii')))"),
GeoQueryTestCase("empty_operator", "answer(highest(place(loc_2())))"),
GeoQueryTestCase("empty_paranthesis", "()"),
GeoQueryTestCase("missing_argument", "answer(intersection(state(next_to_2(stateid('texas'))), )"),
GeoQueryTestCase(
"missing_argument", "answer(intersection(state(next_to_2(stateid('texas'))), )"
),
]



class Test_parsing_geo_query_object(TestCase):
def setUp(self):
with open(f"examples/grammars/SMILES/geo_query.ebnf", "r") as file:
with open(f"examples/grammars/geo_query.ebnf", "r") as file:
input_text = file.read()
parsed_grammar = parse_ebnf(input_text)
start_rule_id = parsed_grammar.symbol_table["root"]
self.recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
print('SetUp successfull!', flush=True)
self.recognizer = StringRecognizer(
parsed_grammar.grammar_encoding, start_rule_id
)
print("SetUp successfull!", flush=True)

def test_valid_sentence(self):

Expand All @@ -57,7 +68,9 @@ def test_valid_sentence(self):
self.recognizer._accept_string(geo_query_test_case.geo_query),
msg=f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}",
)
for geo_query_test_case in valid_geo_query_prefixes + invalid_geo_query_sentences:
for geo_query_test_case in (
valid_geo_query_prefixes + invalid_geo_query_sentences
):
self.assertEqual(
False,
self.recognizer._accept_string(geo_query_test_case.geo_query),
Expand All @@ -70,11 +83,11 @@ def test_valid_prefixes(self):
True,
self.recognizer._accept_prefix(geo_query_test_case.geo_query),
msg=f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}",
)
)

for geo_query_test_case in invalid_geo_query_sentences:
self.assertEqual(
False,
self.recognizer._accept_prefix(geo_query_test_case.geo_query),
msg=f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}"
msg=f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}",
)

0 comments on commit 9511ec9

Please sign in to comment.