-
Notifications
You must be signed in to change notification settings - Fork 112
/
mysql.py
200 lines (152 loc) · 6.71 KB
/
mysql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""This module defines specific functions for MySQL dialect."""
from sqlalchemy import text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.sqltypes import NullType
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
_POSSIBLE_TYPES = [
"geometry",
"point",
"linestring",
"polygon",
"multipoint",
"multilinestring",
"multipolygon",
"geometrycollection",
]
def reflect_geometry_column(inspector, table, column_info):
    """Reflect a column of type Geometry with the MySQL / MariaDB dialects.

    Queries ``INFORMATION_SCHEMA`` for the column's data type, SRID,
    nullability and index type. It then replaces ``column_info["type"]`` in
    place with a configured :class:`Geometry` instance.

    Args:
        inspector: The inspector performing the reflection. Its ``bind`` runs
            the metadata queries and ``default_schema_name`` is used when the
            table has no explicit schema.
        table: The table being reflected (``name`` and ``schema`` are read).
        column_info: The reflected column dictionary, modified in place.

    Returns:
        None. The column is left untouched when its current type is neither
        ``Geometry`` nor ``NullType``, or when its DATA_TYPE is not one of
        ``_POSSIBLE_TYPES``.
    """
    if not isinstance(column_info.get("type"), (Geometry, NullType)):
        return

    column_name = column_info.get("name")
    schema = table.schema or inspector.default_schema_name

    # Pass names as bound parameters instead of formatting them into the SQL
    # text. A quote character in a table/column/schema name can then no
    # longer break (or inject into) the metadata queries.
    params = {"table_name": table.name, "column_name": column_name}
    schema_filter = ""
    if schema is not None:
        schema_filter = " AND TABLE_SCHEMA = :table_schema"
        params["table_schema"] = schema

    # Check geometry type, SRID and if the column is nullable
    geometry_type_query = text(
        "SELECT DATA_TYPE, SRS_ID, IS_NULLABLE "
        "FROM INFORMATION_SCHEMA.COLUMNS "
        "WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" + schema_filter
    ).bindparams(**params)
    geometry_type, srid, nullable_str = inspector.bind.execute(geometry_type_query).one()
    is_nullable = str(nullable_str).lower() == "yes"

    if geometry_type not in _POSSIBLE_TYPES:
        return

    # Check if the column has spatial index
    has_index_query = text(
        "SELECT DISTINCT INDEX_TYPE "
        "FROM INFORMATION_SCHEMA.STATISTICS "
        "WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" + schema_filter
    ).bindparams(**params)
    index_types = inspector.bind.execute(has_index_query).scalars().all()
    # The column may belong to several indexes of different types. It is
    # spatially indexed if any of them is SPATIAL; looking only at the first
    # row (``scalar()``) could miss it.
    spatial_index = any(str(index_type).lower() == "spatial" for index_type in index_types)

    # Set attributes
    column_info["type"] = Geometry(
        geometry_type=geometry_type.upper(),
        srid=srid,
        spatial_index=spatial_index,
        nullable=is_nullable,
        _spatial_index_reflected=True,
    )
def before_create(table, bind, **kw):
    """Handle spatial indexes during the before_create event.

    Removes every index that covers a Geometry column from ``table.indexes``,
    so it is not created together with the table. Indexes that should be put
    back afterwards are stored in ``table.info["_after_create_indexes"]``.
    Those are the ones that are not the default-named spatial index of a
    column flagged ``spatial_index``. Finally, the column list saved by
    ``setup_create_drop`` is restored.

    Args:
        table: The table about to be created.
        bind: The connection used for the DDL (passed to ``setup_create_drop``).
        **kw: Extra event keyword arguments (unused).
    """
    dialect, _gis_cols, _regular_cols = setup_create_drop(table, bind)

    table.info["_after_create_indexes"] = []
    # Iterate over a snapshot because indexes are removed from the live set.
    current_indexes = set(table.indexes)
    for idx in current_indexes:
        for col in table.info["_saved_columns"]:
            if _check_spatial_type(col.type, Geometry, dialect) and col in idx.columns.values():
                table.indexes.remove(idx)
                if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
                    col.type, "spatial_index", False
                ):
                    table.info["_after_create_indexes"].append(idx)
                # This index is now handled. Stop here so that an index spanning
                # several Geometry columns is not removed a second time, which
                # would raise KeyError.
                break
    table.columns = table.info.pop("_saved_columns")
def after_create(table, bind, **kw):
    """Handle spatial indexes during the after_create event.

    For every Geometry/Geography column with ``spatial_index=True`` that is
    not covered by an index in ``table.indexes``, emits
    ``ALTER TABLE ... ADD SPATIAL INDEX(...)``. It then adds back to
    ``table.indexes`` the indexes that ``before_create`` stashed in
    ``table.info["_after_create_indexes"]``.

    Args:
        table: The table that has just been created.
        bind: The connection used to execute the DDL.
        **kw: Extra event keyword arguments (unused).
    """
    dialect = bind.dialect
    # Render identifiers through the dialect so that reserved words and
    # special characters are quoted. Including the schema makes tables
    # outside the current database get altered correctly.
    preparer = dialect.identifier_preparer
    table_ident = preparer.format_table(table)

    for col in table.columns:
        # Add spatial indices for the Geometry and Geography columns
        if (
            _check_spatial_type(col.type, (Geometry, Geography), dialect)
            and col.type.spatial_index is True
        ):
            # If the index does not exist, define it and create it
            if not any(col in idx.columns.values() for idx in table.indexes):
                sql = "ALTER TABLE {} ADD SPATIAL INDEX({});".format(
                    table_ident, preparer.quote(col.name)
                )
                bind.execute(text(sql))

    for idx in table.info.pop("_after_create_indexes", []):
        table.indexes.add(idx)
def before_drop(table, bind, **kw):
    """Handle the before_drop event; no action is needed for this dialect."""
    return None
def after_drop(table, bind, **kw):
    """Handle the after_drop event; no action is needed for this dialect."""
    return None
_MYSQL_FUNCTIONS = {
"ST_AsEWKB": "ST_AsBinary",
}
def _compiles_mysql(cls, fn):
    """Render ``functions.<cls>`` as ``fn(...)`` on the MySQL and MariaDB dialects.

    Args:
        cls: Name of a function class in ``geoalchemy2.functions``.
        fn: SQL function name to emit instead. The function's arguments are
            compiled and passed through unchanged.
    """
    func_class = getattr(functions, cls)

    def _compile_mysql(element, compiler, **kw):
        arguments = compiler.process(element.clauses, **kw)
        return "{}({})".format(fn, arguments)

    for dialect_name in ("mysql", "mariadb"):
        compiles(func_class, dialect_name)(_compile_mysql)
def register_mysql_mapping(mapping):
    """Register compilation mappings for the given functions.

    Each entry makes the named geoalchemy2 function compile to the given SQL
    function name on MySQL and MariaDB.

    Args:
        mapping: Should have the following form::

            {
                "function_name_1": "mysql_function_name_1",
                "function_name_2": "mysql_function_name_2",
                ...
            }
    """
    for name_pair in mapping.items():
        _compiles_mysql(*name_pair)
# Install the simple name-substitution compilers when this module is imported.
register_mysql_mapping(mapping=_MYSQL_FUNCTIONS)
def _compile_GeomFromText_MySql(element, compiler, **kw):
element.identifier = "ST_GeomFromText"
compiled = compiler.process(element.clauses, **kw)
srid = element.type.srid
if srid > 0:
return "{}({}, {})".format(element.identifier, compiled, srid)
else:
return "{}({})".format(element.identifier, compiled)
def _compile_GeomFromWKB_MySql(element, compiler, **kw):
element.identifier = "ST_GeomFromWKB"
wkb_data = list(element.clauses)[0].value
if isinstance(wkb_data, memoryview):
list(element.clauses)[0].value = wkb_data.tobytes()
compiled = compiler.process(element.clauses, **kw)
srid = element.type.srid
if srid > 0:
return "{}({}, {})".format(element.identifier, compiled, srid)
else:
return "{}({})".format(element.identifier, compiled)
@compiles(functions.ST_GeomFromText, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromText, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromText(element, compiler, **kw):
    """Compile ``ST_GeomFromText`` for MySQL and MariaDB."""
    rendered = _compile_GeomFromText_MySql(element, compiler, **kw)
    return rendered
@compiles(functions.ST_GeomFromEWKT, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromEWKT, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromEWKT(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKT`` for MySQL and MariaDB via ``ST_GeomFromText``."""
    rendered = _compile_GeomFromText_MySql(element, compiler, **kw)
    return rendered
@compiles(functions.ST_GeomFromWKB, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromWKB, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromWKB`` for MySQL and MariaDB."""
    rendered = _compile_GeomFromWKB_MySql(element, compiler, **kw)
    return rendered
@compiles(functions.ST_GeomFromEWKB, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromEWKB, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromEWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKB`` for MySQL and MariaDB via ``ST_GeomFromWKB``."""
    rendered = _compile_GeomFromWKB_MySql(element, compiler, **kw)
    return rendered