Skip to content

Commit

Permalink
Fix MySQL 8.0 tests, properly close timed out connections (#660)
Browse files Browse the repository at this point in the history
* fix closed MySQL 8.0 connections not always being detected properly
ensure connections are closed properly when the server connection was lost
* implement a custom StreamReader to avoid accessing the internal attribute `_eof` on asyncio.StreamReader
* ensure connections are closed when raising an InternalError
  • Loading branch information
Nothing4You authored Jan 26, 2022
1 parent fce0355 commit 6887375
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ To be included in 1.0.0 (unreleased)

* Don't send sys.argv[0] as program_name to MySQL server by default #620
* Allow running process as anonymous uid #587
* Fix timed out MySQL 8.0 connections raising InternalError rather than OperationalError #660
* Fix timed out MySQL 8.0 connections being returned from Pool #660
* Ensure connections are properly closed before raising an OperationalError when the server connection is lost #660
* Ensure connections are properly closed before raising an InternalError when packet sequence numbers are out of sync #660


0.0.22 (2021-11-14)
Expand Down
73 changes: 67 additions & 6 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pymysql.constants import SERVER_STATUS
from pymysql.constants import CLIENT
from pymysql.constants import COMMAND
from pymysql.constants import CR
from pymysql.constants import FIELD_TYPE
from pymysql.util import byte2int, int2byte
from pymysql.converters import (escape_item, encoders, decoders,
Expand Down Expand Up @@ -79,6 +80,57 @@ async def _connect(*args, **kwargs):
return conn


async def _open_connection(host=None, port=None, **kwds):
"""This is based on asyncio.open_connection, allowing us to use a custom
StreamReader.
`limit` arg has been removed as we don't currently use it.
"""
loop = asyncio.events.get_running_loop()
reader = _StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer


async def _open_unix_connection(path=None, **kwds):
"""This is based on asyncio.open_unix_connection, allowing us to use a custom
StreamReader.
`limit` arg has been removed as we don't currently use it.
"""
loop = asyncio.events.get_running_loop()

reader = _StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer


class _StreamReader(asyncio.StreamReader):
"""This StreamReader exposes whether EOF was received, allowing us to
discard the associated connection instead of returning it from the pool
when checking free connections in Pool._fill_free_pool().
`limit` arg has been removed as we don't currently use it.
"""
def __init__(self, loop=None):
self._eof_received = False
super().__init__(loop=loop)

def feed_eof(self) -> None:
self._eof_received = True
super().feed_eof()

@property
def eof_received(self):
return self._eof_received


class Connection:
"""Representation of a socket with a mysql server.
Expand Down Expand Up @@ -471,21 +523,21 @@ async def set_charset(self, charset):

async def _connect(self):
# TODO: Set close callback
# raise OperationalError(2006,
# raise OperationalError(CR.CR_SERVER_GONE_ERROR,
# "MySQL server has gone away (%r)" % (e,))
try:
if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
self._reader, self._writer = await \
asyncio.wait_for(
asyncio.open_unix_connection(
_open_unix_connection(
self._unix_socket),
timeout=self.connect_timeout)
self.host_info = "Localhost via UNIX socket: " + \
self._unix_socket
else:
self._reader, self._writer = await \
asyncio.wait_for(
asyncio.open_connection(
_open_connection(
self._host,
self._port),
timeout=self.connect_timeout)
Expand Down Expand Up @@ -570,6 +622,13 @@ async def _read_packet(self, packet_type=MysqlPacket):
# we increment in both write_packet and read_packet. The count
# is reset at new COMMAND PHASE.
if packet_number != self._next_seq_id:
self.close()
if packet_number == 0:
# MySQL 8.0 sends error packet with seqno==0 when shutdown
raise OperationalError(
CR.CR_SERVER_LOST,
"Lost connection to MySQL server during query")

raise InternalError(
"Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
Expand Down Expand Up @@ -597,10 +656,12 @@ async def _read_bytes(self, num_bytes):
data = await self._reader.readexactly(num_bytes)
except asyncio.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
raise OperationalError(2013, msg) from e
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except (IOError, OSError) as e:
msg = "Lost connection to MySQL server during query (%s)" % (e,)
raise OperationalError(2013, msg) from e
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
return data

def _write_bytes(self, data):
Expand Down Expand Up @@ -704,7 +765,7 @@ async def _request_authentication(self):
# TCP connection not at start. Passing in a socket to
# open_connection will cause it to negotiate TLS on an existing
# connection not initiate a new one.
self._reader, self._writer = await asyncio.open_connection(
self._reader, self._writer = await _open_connection(
sock=raw_sock, ssl=self._ssl_context,
server_hostname=self._host
)
Expand Down
10 changes: 9 additions & 1 deletion aiomysql/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def _acquire(self):
await self._cond.wait()

async def _fill_free_pool(self, override_min):
# iterate over free connections and remove timeouted ones
# iterate over free connections and remove timed out ones
free_size = len(self._free)
n = 0
while n < free_size:
Expand All @@ -152,6 +152,14 @@ async def _fill_free_pool(self, override_min):
self._free.pop()
conn.close()

# On MySQL 8.0 a timed out connection sends an error packet before
# closing the connection, preventing us from relying on at_eof().
# This relies on our custom StreamReader, as eof_received is not
# present in asyncio.StreamReader.
elif conn._reader.eof_received:
self._free.pop()
conn.close()

elif (self._recycle > -1 and
self._loop.time() - conn.last_usage > self._recycle):
self._free.pop()
Expand Down

0 comments on commit 6887375

Please sign in to comment.