From defd3f20562f7f57d0bb5a572c6d4d3cab9d305d Mon Sep 17 00:00:00 2001 From: Nevin Date: Sat, 23 Dec 2023 07:47:21 -0700 Subject: [PATCH 1/6] Allow path in dbapi connect Just as you can pass a `path` argument to a `Session` (enabling persistency), this will allow the `connect` function to have persistency. --- chdb/dbapi/connections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py index 0d719894fe3..95b1e58a0ac 100644 --- a/chdb/dbapi/connections.py +++ b/chdb/dbapi/connections.py @@ -39,9 +39,9 @@ def __init__(self, cursorclass=Cursor): self.connect() - def connect(self): + def connect(self, path=None): from chdb import session as chs - self._session = chs.Session() + self._session = chs.Session(path) self._closed = False self._execute_command("select 1;") self._read_query_result() From 3852de6fab6fdc7ec7a090b5f4579c201b4641b8 Mon Sep 17 00:00:00 2001 From: Nevin Date: Sat, 23 Dec 2023 08:02:51 -0700 Subject: [PATCH 2/6] Allow path to be passed in Connection --- chdb/dbapi/connections.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py index 95b1e58a0ac..ff709a274aa 100644 --- a/chdb/dbapi/connections.py +++ b/chdb/dbapi/connections.py @@ -17,6 +17,7 @@ class Connection(object): Accepts several arguments: :param cursorclass: Custom cursor class to use. + :param path: Optional folder path to store database files on disk. See `Connection `_ in the specification. @@ -25,7 +26,7 @@ class Connection(object): _closed = False _session = None - def __init__(self, cursorclass=Cursor): + def __init__(self, cursorclass=Cursor, path=None): self._resp = None @@ -37,7 +38,7 @@ def __init__(self, cursorclass=Cursor): self._result = None self._affected_rows = 0 - self.connect() + self.connect(path) def connect(self, path=None): from chdb import session as chs From 7e77689d8acaf0b13d284844c862b3acc8ff37bd Mon Sep 17 00:00:00 2001 From: Nevin Date: Tue, 19 Mar 2024 14:32:06 -0700 Subject: [PATCH 3/6] add test --- tests/test_dbapi_persistence.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/test_dbapi_persistence.py diff --git a/tests/test_dbapi_persistence.py b/tests/test_dbapi_persistence.py new file mode 100644 index 00000000000..5775340e5a5 --- /dev/null +++ b/tests/test_dbapi_persistence.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import unittest +from chdb import dbapi + +test_state_dir = ".state_tmp_auxten_dbapi" + +class TestDBAPIPersistence(unittest.TestCase): + def test_persistence(self): + conn = dbapi.connect(path=test_state_dir) + cur = con.cursor() + cur.execute("CREATE DATABASE e ENGINE = Atomic;") + cur.execute("CREATE TABLE e.hi (a String primary key, b Int32) Engine = MergeTree ORDER BY a;") + cur.execute("INSERT INTO e.hi (a, b) VALUES (%s, %s);", ["he", 32]) + + cur.close() + conn.close() + + conn2 = dbapi.connect(path=test_state_dir) + cur2 = conn2.cursor() + cur2.execute('SELECT * FROM e.hi;') + row = cur2.fetchone() + self.assertEqual(('he', 32), row) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From a9483867649c57582a14d9737584eb5769e55b31 Mon Sep 17 00:00:00 2001 From: auxten Date: Wed, 20 Mar 2024 13:05:55 +0800 Subject: [PATCH 4/6] Fix typo --- tests/test_dbapi_persistence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dbapi_persistence.py b/tests/test_dbapi_persistence.py index 5775340e5a5..ebd21dc0c2e 100644 --- a/tests/test_dbapi_persistence.py +++ b/tests/test_dbapi_persistence.py @@ -8,7 +8,7 @@ class TestDBAPIPersistence(unittest.TestCase): def test_persistence(self): conn = dbapi.connect(path=test_state_dir) - cur = con.cursor() + cur = conn.cursor() cur.execute("CREATE DATABASE e ENGINE = Atomic;") cur.execute("CREATE TABLE e.hi (a String primary key, b Int32) Engine = MergeTree ORDER BY a;") cur.execute("INSERT INTO e.hi (a, b) VALUES (%s, %s);", ["he", 32]) @@ -24,4 +24,4 @@ def test_persistence(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From a9ca326b2cfb016d4d519f8904c938d017b26fa8 Mon Sep 17 00:00:00 2001 From: auxten Date: Wed, 20 Mar 2024 13:42:49 +0800 Subject: [PATCH 5/6] Fix data link --- tests/test_issue31.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_issue31.py b/tests/test_issue31.py index b5b5733ea4b..e586a3f9353 100644 --- a/tests/test_issue31.py +++ b/tests/test_issue31.py @@ -10,7 +10,7 @@ from timeout_decorator import timeout -csv_url = "https://media.githubusercontent.com/media/datablist/sample-csv-files/main/files/organizations/organizations-2000000.zip" +csv_url = "https://github.com/chdb-io/chdb/files/14662379/organizations-500000.zip" # download csv file, and unzip it @@ -33,7 +33,7 @@ def download_and_extract(url, save_path): def payload(): now = time.time() res = chdb.query( - 'select Name, count(*) cnt from file("organizations-2000000.csv", CSVWithNames) group by Name order by cnt desc, Name asc limit 10000', + 'select Name, count(*) cnt from file("organizations-500000.csv", CSVWithNames) group by Name order by cnt desc, Name asc limit 10000', "CSV", ) # calculate md5 of the result @@ -71,11 +71,11 @@ def handler(signum, frame): class TestAggOnCSVSpeed(unittest.TestCase): def setUp(self): - download_and_extract(csv_url, "organizations-2000000.zip") + download_and_extract(csv_url, "organizations-500000.zip") def tearDown(self): - os.remove("organizations-2000000.csv") - os.remove("organizations-2000000.zip") + os.remove("organizations-500000.csv") + os.remove("organizations-500000.zip") def _test_agg(self, arg=None): payload() From 7aeb1649f28d244ce2997e0fb6149e7f377a97ee Mon Sep 17 00:00:00 2001 From: auxten Date: Wed, 20 Mar 2024 13:55:32 +0800 Subject: [PATCH 6/6] Update test_issue31.py --- tests/test_issue31.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_issue31.py b/tests/test_issue31.py index e586a3f9353..508b43335f4 100644 --- a/tests/test_issue31.py +++ b/tests/test_issue31.py @@ -39,7 +39,7 @@ def payload(): # calculate md5 of the result hash_out = hashlib.md5(res.bytes()).hexdigest() print("output length: ", len(res)) - if hash_out != "423570bd700ba230ccd2b720b7976626": + if hash_out != "626be2713f7a26b266d7160f7172ab43": print(res.bytes().decode("utf-8")) raise Exception(f"md5 not match {hash_out}") used_time = time.time() - now