Skip to content

Commit

Permalink
Handle OSCProtocol.disconnect() cleanly under asyncio (#336)
Browse files Browse the repository at this point in the history
* Harden async OSC / process protocols

* Bump version to 23.5b5
  • Loading branch information
josiah-wolf-oberholtzer authored May 26, 2023
1 parent b99e240 commit 2b27ff9
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 62 deletions.
2 changes: 1 addition & 1 deletion supriya/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
This follows black's versioning scheme.
"""
__version_info__ = (23, "5b4")
__version_info__ = (23, "5b5")
__version__ = ".".join(str(x) for x in __version_info__)
6 changes: 3 additions & 3 deletions supriya/contexts/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _disconnect(self) -> None:
logger.info("Disconnecting")
self._boot_status = BootStatus.QUITTING
self._teardown_shm()
cast(ThreadedOscProtocol, self._osc_protocol).disconnect()
self._osc_protocol.disconnect()
self._teardown_shm()
self._teardown_state()
if self in self._contexts:
Expand Down Expand Up @@ -833,7 +833,7 @@ async def _connect(self) -> None:
async def _disconnect(self) -> None:
logger.info("Disconnecting")
self._boot_status = BootStatus.QUITTING
await cast(AsyncOscProtocol, self._osc_protocol).disconnect()
self._osc_protocol.disconnect()
self._teardown_shm()
self._teardown_state()
if self in self._contexts:
Expand Down Expand Up @@ -1154,7 +1154,7 @@ async def quit(self, force: bool = False) -> "AsyncServer":
await Quit().communicate_async(server=self, timeout=1)
except (OscProtocolOffline, asyncio.TimeoutError):
pass
self._process_protocol.quit()
await self._process_protocol.quit()
await self._disconnect()
return self

Expand Down
87 changes: 48 additions & 39 deletions supriya/osc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import abc
import asyncio
import atexit
import collections
import contextlib
import dataclasses
Expand Down Expand Up @@ -563,6 +562,9 @@ def _add_callback(self, callback: OscCallback):
callbacks, callback_map = callback_map.setdefault(item, ([], {}))
callbacks.append(callback)

def _disconnect(self) -> None:
raise NotImplementedError

def _match_callbacks(self, message):
items = (message.address,) + message.contents
matching_callbacks = []
Expand Down Expand Up @@ -611,7 +613,7 @@ def _setup(self, ip_address, port, healthcheck):
)

def _teardown(self):
osc_protocol_logger.info("Tearing down...")
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] Tearing down...")
self.is_running = False
if self.healthcheck is not None:
self.unregister(self.healthcheck_osc_callback)
Expand All @@ -633,12 +635,12 @@ def _validate_callback(
)

def _validate_receive(self, datagram):
udp_in_logger.debug(f"{self.ip_address}:{self.port} {datagram}")
udp_in_logger.debug(f"[{self.ip_address}:{self.port}] {datagram}")
try:
message = OscMessage.from_datagram(datagram)
except Exception:
raise
osc_in_logger.debug(f"{self.ip_address}:{self.port} {message!r}")
osc_in_logger.debug(f"[{self.ip_address}:{self.port}] {message!r}")
for capture in self.captures:
capture.messages.append(
CaptureEntry(timestamp=time.time(), label="R", message=message)
Expand All @@ -655,13 +657,13 @@ def _validate_send(self, message):
message = OscMessage(message)
elif isinstance(message, SequenceABC):
message = OscMessage(*message)
osc_out_logger.debug(f"{self.ip_address}:{self.port} {message!r}")
osc_out_logger.debug(f"[{self.ip_address}:{self.port}] {message!r}")
for capture in self.captures:
capture.messages.append(
CaptureEntry(timestamp=time.time(), label="S", message=message)
)
datagram = message.to_datagram()
udp_out_logger.debug(f"{self.ip_address}:{self.port} {datagram}")
udp_out_logger.debug(f"[{self.ip_address}:{self.port}] {datagram}")
return datagram

### PUBLIC METHODS ###
Expand All @@ -670,9 +672,14 @@ def _validate_send(self, message):
def activate_healthcheck(self) -> None:
raise NotImplementedError

def capture(self):
def capture(self) -> "Capture":
return Capture(self)

def disconnect(self) -> None:
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] disconnecting")
self._disconnect()
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] ...disconnected")

@abc.abstractmethod
def register(
self,
Expand Down Expand Up @@ -701,10 +708,20 @@ def __init__(self) -> None:
OscProtocol.__init__(self)
self.background_tasks: Set[asyncio.Task] = set()
self.healthcheck_task: Optional[asyncio.Task] = None
atexit.register(lambda: asyncio.run(self.disconnect()))

### PRIVATE METHODS ###

def _disconnect(self) -> None:
if not self.is_running:
osc_protocol_logger.info(
f"{self.ip_address}:{self.port} already disconnected!"
)
return
self._teardown()
self.transport.close()
if self.healthcheck_task:
self.healthcheck_task.cancel()

async def _run_healthcheck(self):
while self.is_running:
if self.attempts >= self.healthcheck.max_attempts:
Expand Down Expand Up @@ -742,7 +759,11 @@ def activate_healthcheck(self) -> None:
async def connect(
self, ip_address: str, port: int, *, healthcheck: Optional[HealthCheck] = None
):
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connecting...")
if self.is_running:
osc_protocol_logger.info(
f"[{self.ip_address}:{self.port}] already connected!"
)
raise OscProtocolAlreadyConnected
self._setup(ip_address, port, healthcheck)
loop = asyncio.get_running_loop()
Expand All @@ -754,14 +775,16 @@ async def connect(
self.healthcheck_task = asyncio.get_running_loop().create_task(
self._run_healthcheck()
)
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] ...connected")

def connection_made(self, transport):
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connected")
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connection made")
self.transport = transport
self.is_running = True

def connection_lost(self, exc):
pass
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connection lost")
self.exit_future.set_result(True)

def datagram_received(self, data, addr):
loop = asyncio.get_running_loop()
Expand All @@ -773,17 +796,6 @@ def datagram_received(self, data, addr):
else:
callback(message)

async def disconnect(self):
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] disconnecting")
if not self.is_running:
return
self.exit_future.set_result(True)
self._teardown()
if not self.transport.is_closing():
self.transport.close()
if self.healthcheck_task:
await self.healthcheck_task

def error_received(self, exc):
osc_out_logger.warning(f"[{self.ip_address}:{self.port}] errored: {exc}")

Expand All @@ -805,7 +817,7 @@ def register(
return callback

def send(self, message):
osc_protocol_logger.info(
osc_protocol_logger.debug(
f"[{self.ip_address}:{self.port}] sending: {message!r}"
)
datagram = self._validate_send(message)
Expand Down Expand Up @@ -843,10 +855,22 @@ def __init__(self):
self.lock = threading.RLock()
self.osc_server = None
self.osc_server_thread = None
atexit.register(self.disconnect)

### PRIVATE METHODS ###

def _disconnect(self) -> None:
with self.lock:
if not self.is_running:
osc_protocol_logger.info(
f"{self.ip_address}:{self.port} already disconnected!"
)
return
self._teardown()
if not self.osc_server._BaseServer__shutdown_request:
self.osc_server.shutdown()
self.osc_server = None
self.osc_server_thread = None

def _process_command_queue(self):
while self.command_queue.qsize():
try:
Expand Down Expand Up @@ -919,22 +943,7 @@ def connect(
self.osc_server_thread.daemon = True
self.osc_server_thread.start()
self.is_running = True
osc_protocol_logger.info(f"{self.ip_address}:{self.port} ...connected")

def disconnect(self):
osc_protocol_logger.info(f"{self.ip_address}:{self.port} disconnecting...")
with self.lock:
if not self.is_running:
osc_protocol_logger.info(
f"{self.ip_address}:{self.port} already disconnected!"
)
return
self._teardown()
if not self.osc_server._BaseServer__shutdown_request:
self.osc_server.shutdown()
self.osc_server = None
self.osc_server_thread = None
osc_protocol_logger.info(f"{self.ip_address}:{self.port} ...disconnected")
osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] ...connected")

def register(
self,
Expand Down
29 changes: 12 additions & 17 deletions supriya/scsynth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import atexit
import enum
import logging
import os
Expand Down Expand Up @@ -259,7 +258,6 @@ class LineStatus(enum.IntEnum):
class ProcessProtocol:
def __init__(self):
self.is_running = False
atexit.register(self.quit)

def boot(self, options: Options):
raise NotImplementedError
Expand All @@ -282,8 +280,6 @@ def _handle_line(self, line):


class SyncProcessProtocol(ProcessProtocol):
### PUBLIC METHODS ###

def boot(self, options: Options):
if self.is_running:
return
Expand All @@ -293,7 +289,6 @@ def boot(self, options: Options):
list(options),
stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
start_new_session=True,
)
start_time = time.time()
timeout = 10
Expand All @@ -314,7 +309,7 @@ def boot(self, options: Options):
self.process.wait()
raise

def quit(self):
def quit(self) -> None:
if not self.is_running:
return
self.process.terminate()
Expand Down Expand Up @@ -346,7 +341,7 @@ async def boot(self, options: Options):
self.error_text = ""
self.buffer_ = ""
_, _ = await loop.subprocess_exec(
lambda: self, *options, stdin=None, stderr=None, start_new_session=True
lambda: self, *options, stdin=None, stderr=None
)
if not (await self.boot_future):
raise ServerCannotBoot(self.error_text)
Expand Down Expand Up @@ -380,23 +375,23 @@ def pipe_data_received(self, fd, data):
self.buffer_ = text

def process_exited(self):
self.is_running = False
self.exit_future.set_result(None)
if not self.boot_future.done():
self.boot_future.set_result(False)
logger.info(f"Process exited with {self.transport.get_returncode()}.")
self.is_running = False
try:
self.exit_future.set_result(None)
if not self.boot_future.done():
self.boot_future.set_result(False)
except asyncio.exceptions.InvalidStateError:
pass

def quit(self):
async def quit(self):
logger.info("Quitting ...")
if not self.is_running:
logger.info("... already quit!")
return
if not self.boot_future.done():
self.boot_future.set_result(False)
if not self.exit_future.done():
self.exit_future.set_result
self.transport.close()
self.is_running = False
self.transport.close()
await self.exit_future
logger.info("... quit!")


Expand Down
4 changes: 2 additions & 2 deletions tests/test_osc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def on_healthcheck_failed():
assert osc_protocol.is_running
assert not healthcheck_failed
await asyncio.sleep(1)
process_protocol.quit()
await process_protocol.quit()
for _ in range(20):
await asyncio.sleep(1)
if not osc_protocol.is_running:
break
assert healthcheck_failed
assert not osc_protocol.is_running
finally:
process_protocol.quit()
await process_protocol.quit()


def test_ThreadedOscProtocol():
Expand Down

0 comments on commit 2b27ff9

Please sign in to comment.