diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py index 0a3ae1fcebc..090aa5500f9 100644 --- a/chdb/dbapi/connections.py +++ b/chdb/dbapi/connections.py @@ -17,6 +17,7 @@ class Connection(object): Accepts several arguments: :param cursorclass: Custom cursor class to use. + :param path: Optional folder path to store database files on disk. See `Connection `_ in the specification. @@ -25,7 +26,7 @@ class Connection(object): _closed = False _session = None - def __init__(self, cursorclass=Cursor): + def __init__(self, cursorclass=Cursor, path=None): self._resp = None @@ -37,11 +38,11 @@ def __init__(self, cursorclass=Cursor): self._result = None self._affected_rows = 0 - self.connect() + self.connect(path) - def connect(self): + def connect(self, path=None): from chdb import session as chs - self._session = chs.Session() + self._session = chs.Session(path) self._closed = False self._execute_command("select 1;") self._read_query_result() diff --git a/tests/test_dbapi_persistence.py b/tests/test_dbapi_persistence.py new file mode 100644 index 00000000000..ebd21dc0c2e --- /dev/null +++ b/tests/test_dbapi_persistence.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import unittest +from chdb import dbapi + +test_state_dir = ".state_tmp_auxten_dbapi" + +class TestDBAPIPersistence(unittest.TestCase): + def test_persistence(self): + conn = dbapi.connect(path=test_state_dir) + cur = conn.cursor() + cur.execute("CREATE DATABASE e ENGINE = Atomic;") + cur.execute("CREATE TABLE e.hi (a String primary key, b Int32) Engine = MergeTree ORDER BY a;") + cur.execute("INSERT INTO e.hi (a, b) VALUES (%s, %s);", ["he", 32]) + + cur.close() + conn.close() + + conn2 = dbapi.connect(path=test_state_dir) + cur2 = conn2.cursor() + cur2.execute('SELECT * FROM e.hi;') + row = cur2.fetchone() + self.assertEqual(('he', 32), row) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_issue31.py b/tests/test_issue31.py index b5b5733ea4b..508b43335f4 100644 --- a/tests/test_issue31.py +++ b/tests/test_issue31.py @@ -10,7 +10,7 @@ from timeout_decorator import timeout -csv_url = "https://media.githubusercontent.com/media/datablist/sample-csv-files/main/files/organizations/organizations-2000000.zip" +csv_url = "https://github.com/chdb-io/chdb/files/14662379/organizations-500000.zip" # download csv file, and unzip it @@ -33,13 +33,13 @@ def download_and_extract(url, save_path): def payload(): now = time.time() res = chdb.query( - 'select Name, count(*) cnt from file("organizations-2000000.csv", CSVWithNames) group by Name order by cnt desc, Name asc limit 10000', + 'select Name, count(*) cnt from file("organizations-500000.csv", CSVWithNames) group by Name order by cnt desc, Name asc limit 10000', "CSV", ) # calculate md5 of the result hash_out = hashlib.md5(res.bytes()).hexdigest() print("output length: ", len(res)) - if hash_out != "423570bd700ba230ccd2b720b7976626": + if hash_out != "626be2713f7a26b266d7160f7172ab43": print(res.bytes().decode("utf-8")) raise Exception(f"md5 not match {hash_out}") used_time = time.time() - now @@ -71,11 +71,11 @@ def handler(signum, frame): class TestAggOnCSVSpeed(unittest.TestCase): def setUp(self): - download_and_extract(csv_url, "organizations-2000000.zip") + download_and_extract(csv_url, "organizations-500000.zip") def tearDown(self): - os.remove("organizations-2000000.csv") - os.remove("organizations-2000000.zip") + os.remove("organizations-500000.csv") + os.remove("organizations-500000.zip") def _test_agg(self, arg=None): payload()