From 6887375d0c991b125801423ba1de6a9bca1ee24b Mon Sep 17 00:00:00 2001 From: Richard Schwab Date: Wed, 26 Jan 2022 18:41:39 +0100 Subject: [PATCH] Fix MySQL 8.0 tests, properly close timed out connections (#660) * fix closed MySQL 8.0 connections not always being detected properly ensure connections are closed properly when the server connection was lost * implement a custom StreamReader to avoid accessing the internal attribute `_eof` on asyncio.StreamReader * ensure connections are closed when raising an InternalError --- CHANGES.txt | 4 +++ aiomysql/connection.py | 73 ++++++++++++++++++++++++++++++++++++++---- aiomysql/pool.py | 10 +++++- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 4f3c82ef..b5b276b1 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -6,6 +6,10 @@ To be included in 1.0.0 (unreleased) * Don't send sys.argv[0] as program_name to MySQL server by default #620 * Allow running process as anonymous uid #587 +* Fix timed out MySQL 8.0 connections raising InternalError rather than OperationalError #660 +* Fix timed out MySQL 8.0 connections being returned from Pool #660 +* Ensure connections are properly closed before raising an OperationalError when the server connection is lost #660 +* Ensure connections are properly closed before raising an InternalError when packet sequence numbers are out of sync #660 0.0.22 (2021-11-14) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 01496d1c..f2b139b1 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -15,6 +15,7 @@ from pymysql.constants import SERVER_STATUS from pymysql.constants import CLIENT from pymysql.constants import COMMAND +from pymysql.constants import CR from pymysql.constants import FIELD_TYPE from pymysql.util import byte2int, int2byte from pymysql.converters import (escape_item, encoders, decoders, @@ -79,6 +80,57 @@ async def _connect(*args, **kwargs): return conn +async def _open_connection(host=None, port=None, **kwds): + """This is based on asyncio.open_connection, allowing us to use a custom + StreamReader. + + `limit` arg has been removed as we don't currently use it. + """ + loop = asyncio.events.get_running_loop() + reader = _StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = await loop.create_connection( + lambda: protocol, host, port, **kwds) + writer = asyncio.StreamWriter(transport, protocol, reader, loop) + return reader, writer + + +async def _open_unix_connection(path=None, **kwds): + """This is based on asyncio.open_unix_connection, allowing us to use a custom + StreamReader. + + `limit` arg has been removed as we don't currently use it. + """ + loop = asyncio.events.get_running_loop() + + reader = _StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = await loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = asyncio.StreamWriter(transport, protocol, reader, loop) + return reader, writer + + +class _StreamReader(asyncio.StreamReader): + """This StreamReader exposes whether EOF was received, allowing us to + discard the associated connection instead of returning it from the pool + when checking free connections in Pool._fill_free_pool(). + + `limit` arg has been removed as we don't currently use it. + """ + def __init__(self, loop=None): + self._eof_received = False + super().__init__(loop=loop) + + def feed_eof(self) -> None: + self._eof_received = True + super().feed_eof() + + @property + def eof_received(self): + return self._eof_received + + class Connection: """Representation of a socket with a mysql server. @@ -471,13 +523,13 @@ async def set_charset(self, charset): async def _connect(self): # TODO: Set close callback - # raise OperationalError(2006, + # raise OperationalError(CR.CR_SERVER_GONE_ERROR, # "MySQL server has gone away (%r)" % (e,)) try: if self._unix_socket and self._host in ('localhost', '127.0.0.1'): self._reader, self._writer = await \ asyncio.wait_for( - asyncio.open_unix_connection( + _open_unix_connection( self._unix_socket), timeout=self.connect_timeout) self.host_info = "Localhost via UNIX socket: " + \ @@ -485,7 +537,7 @@ async def _connect(self): else: self._reader, self._writer = await \ asyncio.wait_for( - asyncio.open_connection( + _open_connection( self._host, self._port), timeout=self.connect_timeout) @@ -570,6 +622,13 @@ async def _read_packet(self, packet_type=MysqlPacket): # we increment in both write_packet and read_packet. The count # is reset at new COMMAND PHASE. if packet_number != self._next_seq_id: + self.close() + if packet_number == 0: + # MySQL 8.0 sends error packet with seqno==0 when shutdown + raise OperationalError( + CR.CR_SERVER_LOST, + "Lost connection to MySQL server during query") + raise InternalError( "Packet sequence number wrong - got %d expected %d" % (packet_number, self._next_seq_id)) @@ -597,10 +656,12 @@ async def _read_bytes(self, num_bytes): data = await self._reader.readexactly(num_bytes) except asyncio.IncompleteReadError as e: msg = "Lost connection to MySQL server during query" - raise OperationalError(2013, msg) from e + self.close() + raise OperationalError(CR.CR_SERVER_LOST, msg) from e except (IOError, OSError) as e: msg = "Lost connection to MySQL server during query (%s)" % (e,) - raise OperationalError(2013, msg) from e + self.close() + raise OperationalError(CR.CR_SERVER_LOST, msg) from e return data def _write_bytes(self, data): @@ -704,7 +765,7 @@ async def _request_authentication(self): # TCP connection not at start. Passing in a socket to # open_connection will cause it to negotiate TLS on an existing # connection not initiate a new one. - self._reader, self._writer = await asyncio.open_connection( + self._reader, self._writer = await _open_connection( sock=raw_sock, ssl=self._ssl_context, server_hostname=self._host ) diff --git a/aiomysql/pool.py b/aiomysql/pool.py index a17e3fca..3eacb47d 100644 --- a/aiomysql/pool.py +++ b/aiomysql/pool.py @@ -143,7 +143,7 @@ async def _acquire(self): await self._cond.wait() async def _fill_free_pool(self, override_min): - # iterate over free connections and remove timeouted ones + # iterate over free connections and remove timed out ones free_size = len(self._free) n = 0 while n < free_size: @@ -152,6 +152,14 @@ async def _fill_free_pool(self, override_min): self._free.pop() conn.close() + # On MySQL 8.0 a timed out connection sends an error packet before + # closing the connection, preventing us from relying on at_eof(). + # This relies on our custom StreamReader, as eof_received is not + # present in asyncio.StreamReader. + elif conn._reader.eof_received: + self._free.pop() + conn.close() + elif (self._recycle > -1 and self._loop.time() - conn.last_usage > self._recycle): self._free.pop()