Skip to content

Commit

Permalink
feat: Bump psycopg2 to psycopg3 for all Postgres components (feast-de…
Browse files Browse the repository at this point in the history
…v#4303)

* Makefile: Formatting

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Makefile: Exclude Snowflake tests for postgres offline store tests

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Bootstrap: Use conninfo

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Tests: Make connection string compatible with psycopg3

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Tests: Test connection type pool and singleton

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Global: Replace conn.set_session() calls to be psycopg3 compatible

Set connection read only

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Offline: Use psycopg3

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Use psycopg3

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Restructure online_write_batch

Addition

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Use correct placeholder

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Handle bytes properly in online_read()

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Whitespace

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Open ConnectionPool

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Online: Add typehint

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Utils: Use psycopg3

Use new ConnectionPool

Pass kwargs as named argument

Use executemany over execute_values

Remove not-required open argument in psycopg.connect

Improve

Use SpooledTemporaryFile

Use max_size and add docstring

Properly write with StringIO

Utils: Use SpooledTemporaryFile over StringIO object

Add replace

Fix df_to_postgres_table

Remove import

Utils

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Lint: Raise exceptions if cursor returned no columns or rows

Add log statement

Lint: Fix _to_arrow_internal

Lint: Fix _get_entity_df_event_timestamp_range

Update exception

Use ZeroColumnQueryResult

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Add comment on +psycopg string

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Docs: Remove mention of psycopg2

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Lint: Fix

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Default to postgresql+psycopg and log warning

Update warning

Fix

Format warning

Add typehints

Use better variable name

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

* Solve merge conflicts

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>

---------

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>
  • Loading branch information
job-almekinders authored Jul 1, 2024
1 parent 43e198f commit 9451d9c
Show file tree
Hide file tree
Showing 18 changed files with 925 additions and 408 deletions.
7 changes: 4 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ install-python:
python setup.py develop

lock-python-dependencies:
uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py$(PYTHON)-requirements.txt
uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py$(PYTHON)-requirements.txt

lock-python-dependencies-all:
pixi run --environment py39 --manifest-path infra/scripts/pixi/pixi.toml "uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py3.9-requirements.txt"
Expand Down Expand Up @@ -164,7 +164,7 @@ test-python-universal-mssql:
sdk/python/tests


# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS.
# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS.
# https://docs.aws.amazon.com/athena/latest/ug/getting-started.html
# Modify environment variables ATHENA_REGION, ATHENA_DATA_SOURCE, ATHENA_DATABASE, ATHENA_WORKGROUP or
# ATHENA_S3_BUCKET_NAME according to your needs. If tests fail with the pytest -n 8 option, change the number to 1.
Expand All @@ -191,7 +191,7 @@ test-python-universal-athena:
not s3_registry and \
not test_snowflake" \
sdk/python/tests

test-python-universal-postgres-offline:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
Expand All @@ -209,6 +209,7 @@ test-python-universal-postgres-offline:
not test_push_features_to_offline_store and \
not gcs_registry and \
not s3_registry and \
not test_snowflake and \
not test_universal_types" \
sdk/python/tests

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/using-scalable-registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ When this happens, your database is likely using what is referred to as an
in `SQLAlchemy` terminology. See your database's documentation for examples on
how to set its scheme in the Database URL.

`Psycopg2`, which is the database library leveraged by the online and offline
`Psycopg`, which is the database library leveraged by the online and offline
stores, is not impacted by the need to speak a particular dialect, and so the
following only applies to the registry.

Expand Down
10 changes: 10 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,13 @@ def __init__(self, input_dict: dict):
super().__init__(
f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
)


class ZeroRowsQueryResult(Exception):
def __init__(self, query: str):
super().__init__(f"This query returned zero rows:\n{query}")


class ZeroColumnQueryResult(Exception):
def __init__(self, query: str):
super().__init__(f"This query returned zero columns:\n{query}")
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
from psycopg2 import sql
from psycopg import sql
from pytz import utc

from feast.data_source import DataSource
from feast.errors import InvalidEntityType
from feast.errors import InvalidEntityType, ZeroColumnQueryResult, ZeroRowsQueryResult
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
Expand Down Expand Up @@ -274,8 +274,10 @@ def to_sql(self) -> str:
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
with self._query_generator() as query:
with _get_conn(self.config.offline_store) as conn, conn.cursor() as cur:
conn.set_session(readonly=True)
conn.read_only = True
cur.execute(query)
if not cur.description:
raise ZeroColumnQueryResult(query)
fields = [
(c.name, pg_type_code_to_arrow(c.type_code))
for c in cur.description
Expand Down Expand Up @@ -331,16 +333,19 @@ def _get_entity_df_event_timestamp_range(
entity_df_event_timestamp.max().to_pydatetime(),
)
elif isinstance(entity_df, str):
# If the entity_df is a string (SQL query), determine range
# from table
# If the entity_df is a string (SQL query), determine range from table
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
(
cur.execute(
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
),
)
query = f"""
SELECT
MIN({entity_df_event_timestamp_col}) AS min,
MAX({entity_df_event_timestamp_col}) AS max
FROM ({entity_df}) AS tmp_alias
"""
cur.execute(query)
res = cur.fetchone()
entity_df_event_timestamp_range = (res[0], res[1])
if not res:
raise ZeroRowsQueryResult(query)
entity_df_event_timestamp_range = (res[0], res[1])
else:
raise InvalidEntityType(type(entity_df))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typeguard import typechecked

from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException
from feast.errors import DataSourceNoNameException, ZeroColumnQueryResult
from feast.infra.utils.postgres.connection_utils import _get_conn
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.SavedDataset_pb2 import (
Expand Down Expand Up @@ -111,7 +111,11 @@ def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
cur.execute(f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0")
query = f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0"
cur.execute(query)
if not cur.description:
raise ZeroColumnQueryResult(query)

return (
(c.name, pg_type_code_to_pg_type(c.type_code)) for c in cur.description
)
Expand Down
120 changes: 66 additions & 54 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Literal,
Optional,
Sequence,
Tuple,
)

import psycopg2
import pytz
from psycopg2 import sql
from psycopg2.extras import execute_values
from psycopg2.pool import SimpleConnectionPool
from psycopg import sql
from psycopg.connection import Connection
from psycopg_pool import ConnectionPool

from feast import Entity
from feast.feature_view import FeatureView
Expand Down Expand Up @@ -39,15 +48,17 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):


class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[psycopg2._psycopg.connection] = None
_conn_pool: Optional[SimpleConnectionPool] = None
_conn: Optional[Connection] = None
_conn_pool: Optional[ConnectionPool] = None

@contextlib.contextmanager
def _get_conn(self, config: RepoConfig):
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
assert config.online_store.type == "postgres"

if config.online_store.conn_type == ConnectionType.pool:
if not self._conn_pool:
self._conn_pool = _get_connection_pool(config.online_store)
self._conn_pool.open()
connection = self._conn_pool.getconn()
yield connection
self._conn_pool.putconn(connection)
Expand All @@ -64,57 +75,56 @@ def online_write_batch(
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
batch_size: int = 5000,
) -> None:
project = config.project
# Format insert values
insert_values = []
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)

with self._get_conn(config) as conn, conn.cursor() as cur:
insert_values = []
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
vector_val = None
if config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
for feature_name, val in values.items():
vector_val = None
if config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
# Control the batch so that we can update the progress
batch_size = 5000
)

# Create insert query
sql_query = sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES (%s, %s, %s, %s, %s, %s)
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
"""
).format(sql.Identifier(_table_id(config.project, table)))

# Push data in batches to online store
with self._get_conn(config) as conn, conn.cursor() as cur:
for i in range(0, len(insert_values), batch_size):
cur_batch = insert_values[i : i + batch_size]
execute_values(
cur,
sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES %s
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
""",
).format(sql.Identifier(_table_id(project, table))),
cur_batch,
page_size=batch_size,
)
cur.executemany(sql_query, cur_batch)
conn.commit()

if progress:
progress(len(cur_batch))

Expand Down Expand Up @@ -172,7 +182,9 @@ def online_read(
# when we iterate through the keys since they are in the correct order
values_dict = defaultdict(list)
for row in rows if rows is not None else []:
values_dict[row[0].tobytes()].append(row[1:])
values_dict[
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
].append(row[1:])

for key in keys:
if key in values_dict:
Expand Down
Loading

0 comments on commit 9451d9c

Please sign in to comment.