Skip to content

Commit

Permalink
track prepared statements in DB object
Browse files Browse the repository at this point in the history
- fixes JuliaDatabases#211
- closes all prepared statements upon close!(db)
- adds SQLite3.finalize_statements!(db) call
- execute(db, sql): close the internal prepared statement immediately,
  don't wait for GC
  • Loading branch information
alyst committed Jan 13, 2021
1 parent dbdd71c commit b5b8656
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 49 deletions.
189 changes: 148 additions & 41 deletions src/SQLite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,38 @@ sqliteexception(db) = SQLiteException(unsafe_string(sqlite3_errmsg(db.handle)))
const DBHandle = Ptr{Cvoid} # SQLite3 DB connection handle
const StmtHandle = Ptr{Cvoid} # SQLite3 prepared statement handle

"""
The wrapper that holds the handle to SQLite3 prepared statement.
It is managed by [`SQLite.DB`](@ref) and referenced by "public" [`SQLite.Stmt`](@ref) objects.
When no `SQLite.Stmt` instances reference the given `SQlite._Stmt` object
(its `refcount` goes to 0), it is closed automatically.
When `SQLite.DB` is closed or [`SQLite.finalize_statements!`](@ref) is called,
all its `SQLite._Stmt` objects are closed.
"""
mutable struct _Stmt
handle::StmtHandle
params::Dict{Int, Any}
refcount::Int # by how many Stmt objects referenced

function _Stmt(handle::StmtHandle)
stmt = new(handle, Dict{Int, Any}(), 0)
finalizer(_close!, stmt)
return stmt
end
end

# close statement
function _close!(stmt::_Stmt)
stmt.handle == C_NULL || sqlite3_finalize(stmt.handle)
stmt.handle = C_NULL
return
end

# _Stmt unique identifier in DB
const _StmtId = Int

"""
`SQLite.DB()` => in-memory SQLite database
`SQLite.DB(file)` => file-based SQLite database
Expand All @@ -42,16 +74,19 @@ The `SQLite.DB` will be automatically closed/shutdown when it goes out of scope
mutable struct DB <: DBInterface.Connection
file::String
handle::DBHandle
stmts::Dict{_StmtId, _Stmt} # opened prepared statements

lastStmtId::_StmtId

function DB(f::AbstractString)
handle = Ref{DBHandle}()
f = isempty(f) ? f : expanduser(f)
if @OK sqliteopen(f, handle)
db = new(f, handle[])
db = new(f, handle[], Dict{StmtHandle, _Stmt}())
finalizer(_close, db)
return db
else # error
db = new(f, handle[])
db = new(f, handle[], Dict{StmtHandle, _Stmt}())
finalizer(_close, db)
sqliteerror(db)
end
Expand All @@ -64,51 +99,106 @@ DBInterface.close!(db::DB) = _close(db)
Base.close(db::DB) = _close(db)
Base.isopen(db::DB) = db.handle != C_NULL

# close all prepared statements of db connection
function finalize_statements!(db::DB)
for stmt in values(db.stmts)
_close!(stmt)
end
empty!(db.stmts)
end

function _close(db::DB)
finalize_statements!(db)
# disconnect from DB
db.handle == C_NULL || sqlite3_close_v2(db.handle)
db.handle = C_NULL
return
end

Base.show(io::IO, db::SQLite.DB) = print(io, string("SQLite.DB(", "\"$(db.file)\"", ")"))

# prepare given sql statement
function _Stmt(db::DB, sql::AbstractString)
handle = Ref{StmtHandle}()
sqliteprepare(db, sql, handle, Ref{StmtHandle}())
return _Stmt(handle[])
end

"""
SQLite.Stmt(db, sql) => SQL.Stmt
Constructs and prepares (compiled by the SQLite library)
an SQL statement in the context of the provided `db`.
Note the SQL statement is not actually executed,
but only compiled
(mainly for usage where the same statement
is repeated with different parameters bound as values.
Prepares an optimized internal representation of SQL statement in
the context of the provided SQLite3 `db` and constructs the `SQLite.Stmt`
Julia object that holds a reference to the prepared statement.
*Note*: the `sql` statement is not actually executed, but only compiled
(mainly for usage where the same statement is executed multiple times
with different parameters bound as values).
Internally `SQLite.Stmt` constructor creates the [`SQLite._Stmt`](@ref) object that is managed by `db`.
`SQLite.Stmt` references the `SQLite._Stmt` by its unique id.
The `SQLite.Stmt` will be automatically closed/shutdown when it goes out of scope
(i.e. the end of the Julia session, end of a function call wherein it was created, etc.),
but you can close `DBInterface.close!(stmt)` to explicitly and immediately close the statement.
(i.e. the end of the Julia session, end of a function call wherein it was created, etc.).
One can also call `DBInterface.close!(stmt)` to immediately close it.
All prepared statements of a given DB connection are also automatically closed when the
DB is disconnected or when [`SQLite.finalize_statements!`](@ref) is explicitly called.
"""
mutable struct Stmt <: DBInterface.Statement
db::DB
handle::StmtHandle
params::Dict{Int, Any}
status::Int
id::_StmtId # id of _Stmt inside db (may refer to already closed connection)

function Stmt(db::DB, sql::AbstractString)
handle = Ref{StmtHandle}()
sqliteprepare(db, sql, handle, Ref{StmtHandle}())
stmt = new(db, handle[], Dict{Int, Any}(), 0)
finalizer(_close, stmt)
_stmt = _Stmt(db, sql)
_stmt.refcount += 1
id = (db.lastStmtId += 1)
stmt = new(db, id)
db.stmts[id] = _stmt # FIXME check for duplicate handle?
finalizer(_finalize, stmt)
return stmt
end
end

DBInterface.close!(stmt::Stmt) = _close(stmt)
# check if the statement is ready (not finalized due to
# _close(_Stmt) called and the statment handle removed from DB)
isready(stmt::Stmt) = haskey(stmt.db.stmts, stmt.id)

function _close(stmt::Stmt)
stmt.handle == C_NULL || sqlite3_finalize(stmt.handle)
stmt.handle = C_NULL
# get underlying _Stmt or nothing if not found
_stmt_safe(stmt::Stmt) = get(stmt.db.stmts, stmt.id, nothing)

# get underlying _Stmt or throw if not found
@inline function _stmt(stmt::Stmt)
_st = _stmt_safe(stmt)
(_st === nothing) && throw(SQLiteException("Statement $(stmt.id) not found"))
return _st
end

# automatically finalizes prepared statement (_Stmt)
# when no Stmt objects refer to it and removes
# it from the db.stmts collection
function _finalize(stmt::Stmt)
_st = _stmt_safe(stmt)
(_st === nothing) && return # silently do nothing if _Stmt is already unregistered
_st.refcount -= 1
@assert _st.refcount >= 0
if _st.refcount == 0 # close and delete unreferenced statement
_close!(_st)
delete!(stmt.db.stmts, stmt.id)
end
return
end

# explicitly close prepared statement (ref count might be > 0)
function DBInterface.close!(stmt::Stmt)
_st = _stmt_safe(stmt)
if _st !== nothing
_close!(_st)
delete!(stmt.db.stmts, stmt.id) # remove the _Stmt
end
return stmt
end

sqliteprepare(db, sql, stmt, null) = @CHECK db sqlite3_prepare_v2(db.handle, sql, stmt, null)

include("UDF.jl")
Expand All @@ -120,8 +210,9 @@ export @sr_str
Clears any bound values to a prepared SQL statement
"""
function clear!(stmt::Stmt)
sqlite3_clear_bindings(stmt.handle)
empty!(stmt.params)
_st = _stmt(stmt)
sqlite3_clear_bindings(_st.handle)
empty!(_st.params)
return
end

Expand Down Expand Up @@ -173,7 +264,7 @@ From the [SQLite documentation](https://www3.sqlite.org/cintro.html):
"""
function bind! end

function bind!(stmt::Stmt, params::DBInterface.NamedStatementParams)
function bind!(stmt::_Stmt, params::DBInterface.NamedStatementParams)
nparams = sqlite3_bind_parameter_count(stmt.handle)
(nparams == length(params)) || throw(SQLiteException("values should be provided for all query placeholders"))
for i in 1:nparams
Expand All @@ -186,36 +277,41 @@ function bind!(stmt::Stmt, params::DBInterface.NamedStatementParams)
end
end

function bind!(stmt::Stmt, values::DBInterface.PositionalStatementParams)
function bind!(stmt::_Stmt, values::DBInterface.PositionalStatementParams)
nparams = sqlite3_bind_parameter_count(stmt.handle)
(nparams == length(values)) || throw(SQLiteException("values should be provided for all query placeholders"))
for i in 1:nparams
@inbounds bind!(stmt, i, values[i])
end
end

bind!(stmt::Stmt; kwargs...) = bind!(stmt, kwargs.data)
bind!(stmt::Stmt, values::DBInterface.StatementParams) = bind!(_stmt(stmt), values)

bind!(stmt::Union{_Stmt, Stmt}; kwargs...) = bind!(stmt, kwargs.data)

# Binding parameters to SQL statements
function bind!(stmt::Stmt, name::AbstractString, val::Any)
function bind!(stmt::_Stmt, name::AbstractString, val::Any)
i::Int = sqlite3_bind_parameter_index(stmt.handle, name)
if i == 0
throw(SQLiteException("SQL parameter $name not found in $stmt"))
end
return bind!(stmt, i, val)
end

bind!(stmt::Stmt, i::Integer, val::AbstractFloat) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_double(stmt.handle, i, Float64(val)); return nothing)
bind!(stmt::Stmt, i::Integer, val::Int32) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_int(stmt.handle, i, val); return nothing)
bind!(stmt::Stmt, i::Integer, val::Int64) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_int64(stmt.handle, i, val); return nothing)
bind!(stmt::Stmt, i::Integer, val::Missing) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_null(stmt.handle, i); return nothing)
bind!(stmt::Stmt, i::Integer, val::AbstractString) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_text(stmt.handle, i, val); return nothing)
bind!(stmt::Stmt, i::Integer, val::WeakRefString{UInt8}) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_text(stmt.handle, i, val.ptr, val.len); return nothing)
bind!(stmt::Stmt, i::Integer, val::WeakRefString{UInt16}) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_text16(stmt.handle, i, val.ptr, val.len*2); return nothing)
bind!(stmt::Stmt, i::Integer, val::Bool) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_int(stmt.handle, i, Int32(val)); return nothing)
bind!(stmt::Stmt, i::Integer, val::Vector{UInt8}) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_blob(stmt.handle, i, val); return nothing)
# binding method for internal _Stmt class
bind!(stmt::_Stmt, i::Integer, val::AbstractFloat) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_double(stmt.handle, i, Float64(val)); return nothing)
bind!(stmt::_Stmt, i::Integer, val::Int32) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_int(stmt.handle, i, val); return nothing)
bind!(stmt::_Stmt, i::Integer, val::Int64) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_int64(stmt.handle, i, val); return nothing)
bind!(stmt::_Stmt, i::Integer, val::Missing) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_null(stmt.handle, i); return nothing)
bind!(stmt::_Stmt, i::Integer, val::AbstractString) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_text(stmt.handle, i, val); return nothing)
bind!(stmt::_Stmt, i::Integer, val::WeakRefString{UInt8}) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_text(stmt.handle, i, val.ptr, val.len); return nothing)
bind!(stmt::_Stmt, i::Integer, val::WeakRefString{UInt16}) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_text16(stmt.handle, i, val.ptr, val.len*2); return nothing)
bind!(stmt::_Stmt, i::Integer, val::Bool) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_int(stmt.handle, i, Int32(val)); return nothing)
bind!(stmt::_Stmt, i::Integer, val::Vector{UInt8}) = (stmt.params[i] = val; @CHECK stmt.db sqlite3_bind_blob(stmt.handle, i, val); return nothing)
# Fallback is BLOB and defaults to serializing the julia value

bind!(stmt::Stmt, param::Union{Integer, AbstractString}, val::Any) = bind!(_stmt(stmt), param, val)

# internal wrapper mutable struct to, in-effect, mark something which has been serialized
struct Serialized
object
Expand All @@ -232,7 +328,7 @@ function sqlserialize(x)
return take!(GLOBAL_BUF)
end
# fallback method to bind arbitrary julia `val` to the parameter at index `i` (object is serialized)
bind!(stmt::Stmt, i::Integer, val::Any) = bind!(stmt, i, sqlserialize(val))
bind!(stmt::_Stmt, i::Integer, val::Any) = bind!(stmt, i, sqlserialize(val))

struct SerializeError <: Exception
msg::String
Expand Down Expand Up @@ -319,24 +415,35 @@ To get the results of a SQL query, it is recommended to use [`DBInterface.execut
"""
function execute end

function execute(stmt::Stmt, params::DBInterface.StatementParams)
function execute(db::DB, stmt::_Stmt, params::DBInterface.StatementParams=())
sqlite3_reset(stmt.handle)
bind!(stmt, params)
r = sqlite3_step(stmt.handle)
stmt.status = r
if r == SQLITE_DONE
sqlite3_reset(stmt.handle)
elseif r != SQLITE_ROW
e = sqliteexception(stmt.db)
e = sqliteexception(db)
sqlite3_reset(stmt.handle)
throw(e)
end
return r
end

execute(stmt::Stmt, params::DBInterface.StatementParams) =
execute(stmt.db, _stmt(stmt), params)

execute(stmt::Stmt; kwargs...) = execute(stmt, kwargs.data)

execute(db::DB, sql::AbstractString, params::DBInterface.StatementParams) = execute(Stmt(db, sql), params)
function execute(db::DB, sql::AbstractString, params::DBInterface.StatementParams)
# prepare without registering _Stmt in DB
_stmt = _Stmt(db, sql)
try
return execute(db, _stmt, params)
finally
_close!(_stmt) # immediately close, don't wait for GC
end
end

execute(db::DB, sql::AbstractString; kwargs...) = execute(db, sql, kwargs.data)

"""
Expand Down
18 changes: 10 additions & 8 deletions src/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,23 @@ Base.IteratorSize(::Type{Query}) = Base.SizeUnknown()
Base.eltype(q::Query) = Row

function reset!(q::Query)
sqlite3_reset(q.stmt.handle)
sqlite3_reset(_stmt(q.stmt).handle)
q.status[] = execute(q.stmt)
return
end

function done(q::Query)
st = q.status[]
if st == SQLITE_DONE
sqlite3_reset(q.stmt.handle)
sqlite3_reset(_stmt(q.stmt).handle)
return true
end
st == SQLITE_ROW || sqliteerror(q.stmt.db)
return false
end

function getvalue(q::Query, col::Int, ::Type{T}) where {T}
handle = q.stmt.handle
handle = _stmt(q.stmt).handle
t = sqlite3_column_type(handle, col)
if t == SQLITE_NULL
return missing
Expand All @@ -60,7 +60,7 @@ function Base.iterate(q::Query)
end

function Base.iterate(q::Query, ::Nothing)
q.status[] = sqlite3_step(q.stmt.handle)
q.status[] = sqlite3_step(_stmt(q.stmt).handle)
done(q) && return nothing
return Row(q), nothing
end
Expand Down Expand Up @@ -92,12 +92,13 @@ like `DataFrame(results)`, `CSV.write("results.csv", results)`, etc.
"""
function DBInterface.execute(stmt::Stmt, params::DBInterface.StatementParams)
status = execute(stmt, params)
cols = sqlite3_column_count(stmt.handle)
_st = _stmt(stmt)
cols = sqlite3_column_count(_st.handle)
header = Vector{Symbol}(undef, cols)
types = Vector{Type}(undef, cols)
for i = 1:cols
header[i] = sym(sqlite3_column_name(stmt.handle, i))
types[i] = Union{juliatype(stmt.handle, i), Missing}
header[i] = sym(sqlite3_column_name(_st.handle, i))
types[i] = Union{juliatype(_st.handle, i), Missing}
end
return Query(stmt, Ref(status), header, types, Dict(x=>i for (i, x) in enumerate(header)))
end
Expand Down Expand Up @@ -195,7 +196,7 @@ function load!(sch::Tables.Schema, rows, db::DB, name::AbstractString, db_tablei
# build insert statement
columns = join(esc_id.(string.(sch.names)), ",")
params = chop(repeat("?,", length(sch.names)))
stmt = Stmt(db, "INSERT INTO $(esc_id(string(name))) ($columns) VALUES ($params)")
stmt = _Stmt(db, "INSERT INTO $(esc_id(string(name))) ($columns) VALUES ($params)")
# start a transaction for inserting rows
transaction(db) do
for row in rows
Expand All @@ -206,6 +207,7 @@ function load!(sch::Tables.Schema, rows, db::DB, name::AbstractString, db_tablei
sqlite3_reset(stmt.handle)
end
end
_close!(stmt)
analyze && execute(db, "ANALYZE $nm")
return name
end
Expand Down
Loading

0 comments on commit b5b8656

Please sign in to comment.