Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: close connection on error #358

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pymysql.charset import charset_by_name, charset_by_id
from pymysql.constants import SERVER_STATUS
from pymysql.constants import CLIENT
from pymysql.constants import CR
from pymysql.constants import COMMAND
from pymysql.constants import FIELD_TYPE
from pymysql.util import byte2int, int2byte
Expand Down Expand Up @@ -558,6 +559,12 @@ async def _read_packet(self, packet_type=MysqlPacket):
# we increment in both write_packet and read_packet. The count
# is reset at new COMMAND PHASE.
if packet_number != self._next_seq_id:
self._force_close()
if packet_number == 0:
# MariaDB sends error packet with seqno==0 when shutdown
raise OperationalError(
CR.CR_SERVER_LOST,
"Lost connection to MySQL server during query")
raise InternalError(
"Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
Expand Down Expand Up @@ -585,10 +592,12 @@ async def _read_bytes(self, num_bytes):
data = await self._reader.readexactly(num_bytes)
except asyncio.streams.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
raise OperationalError(2013, msg) from e
self._force_close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except (IOError, OSError) as e:
msg = "Lost connection to MySQL server during query (%s)" % (e,)
raise OperationalError(2013, msg) from e
self._force_close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
return data

def _write_bytes(self, data):
Expand Down Expand Up @@ -1052,11 +1061,14 @@ def _ensure_alive(self):
else:
raise InterfaceError(self._close_reason)

def __del__(self):
def _force_close(self):
if self._writer:
warnings.warn("Unclosed connection {!r}".format(self),
ResourceWarning)
self.close()

__del__ = _force_close

Warning = Warning
Error = Error
InterfaceError = InterfaceError
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
asynctest==0.12.2
coverage==4.5.1
flake8==3.5.0
ipdb==0.11
Expand Down
115 changes: 115 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import unittest
import aiomysql
from tests._testutils import run_until_complete
from asynctest import patch, MagicMock, CoroutineMock
from tests.base import AIOPyMySQLTestCase
from tests._testutils import BaseTest


PY_341 = sys.version_info >= (3, 4, 1)
Expand Down Expand Up @@ -255,3 +257,116 @@ def test_commit_during_multi_result(self):
yield from cur.execute("SELECT 3;")
resp = yield from cur.fetchone()
self.assertEqual(resp[0], 3)


class PacketTestCase(BaseTest):

@patch('aiomysql.connection.Connection._close_on_cancel')
@patch('aiomysql.connection.Connection._read_bytes')
@run_until_complete
async def test_read_packet_close_on_cancel(self,
read_bytes_mock,
close_on_cancel_mock):

read_bytes_mock.side_effect = asyncio.CancelledError()
conn = aiomysql.Connection()

try:
await conn._read_packet()
except asyncio.CancelledError:
pass

close_on_cancel_mock.assert_called_once()

@patch('aiomysql.connection.Connection._read_bytes')
@run_until_complete
async def test_read_mariadb_shutdown_packet(self,
read_bytes_mock):

read_bytes_mock.return_value = b'\x10\x00\x00\x00'
conn = aiomysql.Connection()
writer_mock = MagicMock()
conn._writer = writer_mock # Fake a connection
conn._next_seq_id = 1

with self.assertRaises(aiomysql.OperationalError):
await conn._read_packet()

self.assertTrue(conn.closed)
writer_mock.transport.close.assert_called_once()

@patch('aiomysql.connection.Connection._read_bytes')
@run_until_complete
async def test_read_packet_wrong_sequence(self,
read_bytes_mock):

read_bytes_mock.return_value = b'\x10\x00\x00\x01'
conn = aiomysql.Connection()
writer_mock = MagicMock()
conn._writer = writer_mock # Fake a connection
conn._next_seq_id = 2

with self.assertRaises(aiomysql.InternalError):
await conn._read_packet()

self.assertTrue(conn.closed)
writer_mock.transport.close.assert_called_once()

@run_until_complete
async def test_read_bytes_incomplete(self):
conn = aiomysql.Connection()

writer_mock = MagicMock()
reader_mock = MagicMock()
reader_mock.readexactly = CoroutineMock(
side_effect=asyncio.streams.IncompleteReadError(
partial=b'\x01',
expected=10)
)

conn._writer = writer_mock # Fake a connection
conn._reader = reader_mock

with self.assertRaises(aiomysql.OperationalError):
await conn._read_bytes(10)

self.assertTrue(conn.closed)
writer_mock.transport.close.assert_called_once()

@run_until_complete
async def test_read_bytes_ioerror(self):
conn = aiomysql.Connection()

writer_mock = MagicMock()
reader_mock = MagicMock()
reader_mock.readexactly = CoroutineMock(
side_effect=IOError()
)

conn._writer = writer_mock # Fake a connection
conn._reader = reader_mock

with self.assertRaises(aiomysql.OperationalError):
await conn._read_bytes(10)

self.assertTrue(conn.closed)
writer_mock.transport.close.assert_called_once()

@run_until_complete
async def test_read_bytes_oserror(self):
conn = aiomysql.Connection()

writer_mock = MagicMock()
reader_mock = MagicMock()
reader_mock.readexactly = CoroutineMock(
side_effect=OSError()
)

conn._writer = writer_mock # Fake a connection
conn._reader = reader_mock

with self.assertRaises(aiomysql.OperationalError):
await conn._read_bytes(10)

self.assertTrue(conn.closed)
writer_mock.transport.close.assert_called_once()