Skip to content

Commit

Permalink
add support for pymysql 1.0.0+ (#643)
Browse files Browse the repository at this point in the history
* add support to pymysql 1.0.2

* vendor pymysql byte/int utils

* Update setup.py

* Update requirements-dev.txt

* tests fix
  • Loading branch information
ghostebony authored Jan 30, 2022
1 parent 624de5f commit 2b790fa
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 13 deletions.
14 changes: 5 additions & 9 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pymysql.constants import COMMAND
from pymysql.constants import CR
from pymysql.constants import FIELD_TYPE
from pymysql.util import byte2int, int2byte
from pymysql.converters import (escape_item, encoders, decoders,
escape_string, escape_bytes_prefixed, through)
from pymysql.err import (Warning, Error,
Expand All @@ -29,18 +28,15 @@
from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET
from pymysql.connections import _auth

from pymysql.connections import pack_int24

from pymysql.connections import MysqlPacket
from pymysql.connections import FieldDescriptorPacket
from pymysql.connections import EOFPacketWrapper
from pymysql.connections import OKPacketWrapper
from pymysql.connections import LoadLocalPacketWrapper
from pymysql.connections import lenenc_int

# from aiomysql.utils import _convert_to_str
from .cursors import Cursor
from .utils import _ConnectionContextManager, _ContextManager
from .utils import _pack_int24, _lenenc_int, _ConnectionContextManager, _ContextManager
from .log import logger

try:
Expand Down Expand Up @@ -349,7 +345,7 @@ async def ensure_closed(self):
if self._writer is None:
# connection has been closed
return
send_data = struct.pack('<i', 1) + int2byte(COMMAND.COM_QUIT)
send_data = struct.pack('<i', 1) + bytes([COMMAND.COM_QUIT])
self._writer.write(send_data)
await self._writer.drain()
self.close()
Expand Down Expand Up @@ -588,7 +584,7 @@ def write_packet(self, payload):
"""
# Internal note: when you build packet manually and calls
# _write_bytes() directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload
self._write_bytes(data)
self._next_seq_id = (self._next_seq_id + 1) % 256

Expand Down Expand Up @@ -801,7 +797,7 @@ async def _request_authentication(self):
authresp = self._password.encode('latin1') + b'\0'

if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp)) + authresp
data += _lenenc_int(len(authresp)) + authresp
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
data += struct.pack('B', len(authresp)) + authresp
else: # pragma: no cover
Expand Down Expand Up @@ -1041,7 +1037,7 @@ async def _get_server_information(self):
packet = await self._read_packet()
data = packet.get_all_data()
# logger.debug(dump_packet(data))
self.protocol_version = byte2int(data[i:i + 1])
self.protocol_version = data[i]
i += 1

server_end = data.find(b'\0', i)
Expand Down
26 changes: 26 additions & 0 deletions aiomysql/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
from collections.abc import Coroutine

import struct


def _pack_int24(n):
return struct.pack("<I", n)[:3]


def _lenenc_int(i):
if i < 0:
raise ValueError(
"Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i
)
elif i < 0xFB:
return bytes([i])
elif i < (1 << 16):
return b"\xfc" + struct.pack("<H", i)
elif i < (1 << 24):
return b"\xfd" + struct.pack("<I", i)[:3]
elif i < (1 << 64):
return b"\xfe" + struct.pack("<Q", i)
else:
raise ValueError(
"Encoding %x is larger than %x - no representation in LengthEncodedInteger"
% (i, (1 << 64))
)


class _ContextManager(Coroutine):

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ipdb==0.13.9
pytest==6.2.5
pytest-cov==3.0.0
pytest-sugar==0.9.4
PyMySQL>=0.9,<=0.9.3
PyMySQL>=0.9,<=1.0.2
sphinx>=1.8.1, <4.4.1
sphinxcontrib-asyncio==0.3.0
sqlalchemy>1.2.12,<=1.3.16
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup, find_packages


install_requires = ['PyMySQL>=0.9,<=0.9.3']
install_requires = ['PyMySQL>=0.9,<=1.0.2']

PY_VER = sys.version_info

Expand Down
3 changes: 1 addition & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import time

import pytest
from pymysql import util
from pymysql.err import ProgrammingError


Expand Down Expand Up @@ -42,7 +41,7 @@ async def test_datatypes(connection, cursor, datatype_table):
await cursor.execute(
"select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
r = await cursor.fetchone()
assert util.int2byte(1) == r[0]
assert bytes([1]) == r[0]
# assert v[1:8] == r[1:8])
assert v[1:9] == r[1:9]
# mysql throws away microseconds so we need to check datetimes
Expand Down

0 comments on commit 2b790fa

Please sign in to comment.