From 51d83d411249604cec2093d64092e391a4926fa5 Mon Sep 17 00:00:00 2001 From: Pedro Kiefer Date: Tue, 20 Nov 2018 15:41:29 -0200 Subject: [PATCH] refactor: close connection on error --- aiomysql/connection.py | 18 +++++- requirements-dev.txt | 1 + tests/test_connection.py | 115 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 3 deletions(-) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 74bf22a5..c053c07d 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -14,6 +14,7 @@ from pymysql.charset import charset_by_name, charset_by_id from pymysql.constants import SERVER_STATUS from pymysql.constants import CLIENT +from pymysql.constants import CR from pymysql.constants import COMMAND from pymysql.constants import FIELD_TYPE from pymysql.util import byte2int, int2byte @@ -558,6 +559,12 @@ async def _read_packet(self, packet_type=MysqlPacket): # we increment in both write_packet and read_packet. The count # is reset at new COMMAND PHASE. if packet_number != self._next_seq_id: + self._force_close() + if packet_number == 0: + # MariaDB sends error packet with seqno==0 when shutdown + raise OperationalError( + CR.CR_SERVER_LOST, + "Lost connection to MySQL server during query") raise InternalError( "Packet sequence number wrong - got %d expected %d" % (packet_number, self._next_seq_id)) @@ -585,10 +592,12 @@ async def _read_bytes(self, num_bytes): data = await self._reader.readexactly(num_bytes) except asyncio.streams.IncompleteReadError as e: msg = "Lost connection to MySQL server during query" - raise OperationalError(2013, msg) from e + self._force_close() + raise OperationalError(CR.CR_SERVER_LOST, msg) from e except (IOError, OSError) as e: msg = "Lost connection to MySQL server during query (%s)" % (e,) - raise OperationalError(2013, msg) from e + self._force_close() + raise OperationalError(CR.CR_SERVER_LOST, msg) from e return data def _write_bytes(self, data): @@ -1052,11 +1061,14 @@ def _ensure_alive(self): else: raise InterfaceError(self._close_reason) - def __del__(self): + def _force_close(self): if self._writer: warnings.warn("Unclosed connection {!r}".format(self), ResourceWarning) self.close() + + __del__ = _force_close + Warning = Warning Error = Error InterfaceError = InterfaceError diff --git a/requirements-dev.txt b/requirements-dev.txt index f2dbe6a0..68322330 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +asynctest==0.12.2 coverage==4.5.1 flake8==3.5.0 ipdb==0.11 diff --git a/tests/test_connection.py b/tests/test_connection.py index 4329fab2..9cbb92b9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,7 +5,9 @@ import unittest import aiomysql from tests._testutils import run_until_complete +from asynctest import patch, MagicMock, CoroutineMock from tests.base import AIOPyMySQLTestCase +from tests._testutils import BaseTest PY_341 = sys.version_info >= (3, 4, 1) @@ -255,3 +257,116 @@ def test_commit_during_multi_result(self): yield from cur.execute("SELECT 3;") resp = yield from cur.fetchone() self.assertEqual(resp[0], 3) + + +class PacketTestCase(BaseTest): + + @patch('aiomysql.connection.Connection._close_on_cancel') + @patch('aiomysql.connection.Connection._read_bytes') + @run_until_complete + async def test_read_packet_close_on_cancel(self, + read_bytes_mock, + close_on_cancel_mock): + + read_bytes_mock.side_effect = asyncio.CancelledError() + conn = aiomysql.Connection() + + try: + await conn._read_packet() + except asyncio.CancelledError: + pass + + close_on_cancel_mock.assert_called_once() + + @patch('aiomysql.connection.Connection._read_bytes') + @run_until_complete + async def test_read_mariadb_shutdown_packet(self, + read_bytes_mock): + + read_bytes_mock.return_value = b'\x10\x00\x00\x00' + conn = aiomysql.Connection() + writer_mock = MagicMock() + conn._writer = writer_mock # Fake a connection + conn._next_seq_id = 1 + + with self.assertRaises(aiomysql.OperationalError): + await conn._read_packet() + + self.assertTrue(conn.closed) + writer_mock.transport.close.assert_called_once() + + @patch('aiomysql.connection.Connection._read_bytes') + @run_until_complete + async def test_read_packet_wrong_sequence(self, + read_bytes_mock): + + read_bytes_mock.return_value = b'\x10\x00\x00\x01' + conn = aiomysql.Connection() + writer_mock = MagicMock() + conn._writer = writer_mock # Fake a connection + conn._next_seq_id = 2 + + with self.assertRaises(aiomysql.InternalError): + await conn._read_packet() + + self.assertTrue(conn.closed) + writer_mock.transport.close.assert_called_once() + + @run_until_complete + async def test_read_bytes_incomplete(self): + conn = aiomysql.Connection() + + writer_mock = MagicMock() + reader_mock = MagicMock() + reader_mock.readexactly = CoroutineMock( + side_effect=asyncio.streams.IncompleteReadError( + partial=b'\x01', + expected=10) + ) + + conn._writer = writer_mock # Fake a connection + conn._reader = reader_mock + + with self.assertRaises(aiomysql.OperationalError): + await conn._read_bytes(10) + + self.assertTrue(conn.closed) + writer_mock.transport.close.assert_called_once() + + @run_until_complete + async def test_read_bytes_ioerror(self): + conn = aiomysql.Connection() + + writer_mock = MagicMock() + reader_mock = MagicMock() + reader_mock.readexactly = CoroutineMock( + side_effect=IOError() + ) + + conn._writer = writer_mock # Fake a connection + conn._reader = reader_mock + + with self.assertRaises(aiomysql.OperationalError): + await conn._read_bytes(10) + + self.assertTrue(conn.closed) + writer_mock.transport.close.assert_called_once() + + @run_until_complete + async def test_read_bytes_oserror(self): + conn = aiomysql.Connection() + + writer_mock = MagicMock() + reader_mock = MagicMock() + reader_mock.readexactly = CoroutineMock( + side_effect=OSError() + ) + + conn._writer = writer_mock # Fake a connection + conn._reader = reader_mock + + with self.assertRaises(aiomysql.OperationalError): + await conn._read_bytes(10) + + self.assertTrue(conn.closed) + writer_mock.transport.close.assert_called_once()