Skip to content

Commit

Permalink
refactor: Add back SignalGroup methods (#286)
Browse files Browse the repository at this point in the history
* Credit getzze

* add back group methods

* test: add test_monitor_group

* wip

* tests passing

* fix typing and add comment

* use annotations rather than __dict__

* minor name change

* fix test

* test slotted classes

* fix pyproject

* fix cov

* test: update deepcopy test

* fix py38

* add test for set_name

* more tests

---------

Co-authored-by: getzze <getzze@gmail.com>
  • Loading branch information
tlambert03 and getzze authored Mar 4, 2024
1 parent 53a5fb7 commit 023722c
Show file tree
Hide file tree
Showing 12 changed files with 451 additions and 134 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ default_install_hook_types: [pre-commit, commit-msg]
exclude: .asv

repos:
- repo: https://github.com/compilerla/conventional-pre-commit
rev: v3.1.0
hooks:
- id: conventional-pre-commit
stages: [commit-msg]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ ignore_errors = true
module = ["wrapt"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["tests.*"]
disallow_untyped_defs = false

# https://coverage.readthedocs.io/en/6.4/config.html
[tool.coverage.report]
exclude_lines = [
Expand Down
2 changes: 2 additions & 0 deletions src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class Person:
def _decorate(cls: T) -> T:
if not isinstance(cls, type): # pragma: no cover
raise TypeError("evented can only be used on classes")
if any(k.startswith("_psygnal") for k in getattr(cls, "__annotations__", {})):
raise TypeError("Fields on an evented class cannot start with '_psygnal'")

descriptor = SignalGroupDescriptor(
equality_operators=equality_operators,
Expand Down
220 changes: 168 additions & 52 deletions src/psygnal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from __future__ import annotations

import warnings
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -22,14 +23,18 @@
Literal,
Mapping,
NamedTuple,
overload,
)

from psygnal._signal import Signal, SignalInstance, _SignalBlocker
from psygnal._signal import _NULL, Signal, SignalInstance, _SignalBlocker

from ._mypyc import mypyc_attr

if TYPE_CHECKING:
from psygnal._weak_callback import WeakCallback
import threading

from psygnal._signal import F, ReducerFunc
from psygnal._weak_callback import RefErrorChoice, WeakCallback

__all__ = ["EmissionInfo", "SignalGroup"]

Expand Down Expand Up @@ -65,7 +70,7 @@ def __init__(
self, signals: Mapping[str, SignalInstance], instance: Any = None
) -> None:
super().__init__(signature=(EmissionInfo,), instance=instance)
self._signals = signals
self._signals = MappingProxyType(signals)
self._sig_was_blocked: dict[str, bool] = {}

def _append_slot(self, slot: WeakCallback) -> None:
Expand Down Expand Up @@ -251,53 +256,75 @@ class MySignals(SignalGroup):
"""

_psygnal_signals: ClassVar[Mapping[str, Signal]]
_psygnal_instances: dict[str, SignalInstance]
_psygnal_uniform: ClassVar[bool] = False
_psygnal_name_conflicts: ClassVar[set[str]]

_psygnal_instances: dict[str, SignalInstance]

def __init__(self, instance: Any = None) -> None:
cls = type(self)
if not hasattr(cls, "_psygnal_signals"): # pragma: no cover
if not hasattr(cls, "_psygnal_signals"):
raise TypeError(
"Cannot instantiate `SignalGroup` directly. Use a subclass instead."
)

self._psygnal_instances = {
name: sig.__get__(self, cls) for name, sig in cls._psygnal_signals.items()
name: (
sig._create_signal_instance(self)
if name in cls._psygnal_name_conflicts
else sig.__get__(self, cls)
)
for name, sig in cls._psygnal_signals.items()
}
self._psygnal_relay = SignalRelay(self._psygnal_instances, instance)

def __init_subclass__(cls, strict: bool = False) -> None:
"""Collects all Signal instances on the class under `cls._psygnal_signals`."""
cls._psygnal_signals = {
k: val
for k, val in getattr(cls, "__dict__", {}).items()
if isinstance(val, Signal)
# Collect Signals and remove from class attributes
# Use dir(cls) instead of cls.__dict__ to get attributes from super()
forbidden = {
k for k in getattr(cls, "__dict__", ()) if k.startswith("_psygnal")
}

if conflicts := {k for k in cls._psygnal_signals if k.startswith("_psygnal")}:
warnings.warn(
"Signal names may not begin with '_psygnal'. "
f"Skipping signals: {conflicts}",
stacklevel=2,
if forbidden:
raise TypeError(
f"SignalGroup subclass cannot have attributes starting with '_psygnal'."
f" Found: {forbidden}"
)
for key in conflicts:
del cls._psygnal_signals[key]

if "all" in cls._psygnal_signals:
_psygnal_signals = {}
for k in dir(cls):
val = getattr(cls, k, None)
if isinstance(val, Signal):
_psygnal_signals[k] = val

# Collect the Signals also from super-class
# When subclassing, the Signals have been removed from the attributes,
# look for cls._psygnal_signals also
cls._psygnal_signals = {
**getattr(cls, "_psygnal_signals", {}),
**_psygnal_signals,
}

# Emit warning for signal names conflicting with SignalGroup attributes
reserved = set(dir(SignalGroup))
cls._psygnal_name_conflicts = conflicts = {
k
for k in cls._psygnal_signals
if k in reserved or k.startswith(("_psygnal", "psygnal"))
}
if conflicts:
for name in conflicts:
if isinstance(getattr(cls, name), Signal):
delattr(cls, name)
Names = "Names" if len(conflicts) > 1 else "Name"
Are = "are" if len(conflicts) > 1 else "is"
warnings.warn(
"Name 'all' is reserved for the SignalRelay. You cannot use this "
"name on to access a SignalInstance on a SignalGroup. (You may still "
"access it at `group['all']`).",
f"{Names} {sorted(conflicts)!r} {Are} reserved. You cannot use these "
"names to access SignalInstances as attributes on a SignalGroup. (You "
"may still access them as keys to __getitem__: `group['name']`).",
UserWarning,
stacklevel=2,
)
delattr(cls, "all")

if "psygnals_uniform" in cls._psygnal_signals:
raise NameError(
"Name 'psygnals_uniform' is reserved. You cannot use this "
"name as a signal on a SignalGroup"
)

cls._psygnal_uniform = _is_uniform(cls._psygnal_signals.values())
if strict and not cls._psygnal_uniform:
Expand Down Expand Up @@ -325,29 +352,6 @@ class MySignals(SignalGroup):
"""
return self._psygnal_relay

# TODO: change type hint to -> SignalInstance after completing deprecation of
# direct access to names on SignalRelay object
def __getattr__(self, name: str) -> Any:
# Note, technically these lines aren't actually needed because of Signal's
# descriptor protocol: Accessing a name on a group instance will first look
# the instance's __dict__, and then in the class's __dict__, which
# will call Signal.__get__ and return the SignalInstance.
# these lines are here as a reminder to developers (and safeguard?).
if name != "_psygnal_instances" and name in self._psygnal_instances:
return self._psygnal_instances[name] # pragma: no cover

if name != "_psygnal_relay" and hasattr(self._psygnal_relay, name):
warnings.warn(
f"Accessing SignalInstance attribute {name!r} on a SignalGroup is "
f"deprecated. Access it on the `group.all` attribute instead. e.g. "
f"`group.all.{name}`. This will be an error in v0.11.",
FutureWarning,
stacklevel=2,
)
return getattr(self._psygnal_relay, name)

raise AttributeError(f"{type(self).__name__!r} has no signal named {name!r}")

@property
def signals(self) -> Mapping[str, SignalInstance]:
"""DEPRECATED: A mapping of signal names to SignalInstance instances."""
Expand All @@ -369,6 +373,16 @@ def __getitem__(self, item: str) -> SignalInstance:
"""Get a signal instance by name."""
return self._psygnal_instances[item]

# this is just here for type checking, particularly on cases
# where the SignalGroup comes from the SignalGroupDescriptor
# (such as in evented dataclasses). In those cases, it's hard to indicate
# to mypy that all remaining attributes are SignalInstances.
def __getattr__(self, __name: str) -> SignalInstance:
"""Get a signal instance by name."""
raise AttributeError( # pragma: no cover
f"{type(self).__name__!r} object has no attribute {__name!r}"
)

def __iter__(self) -> Iterator[str]:
"""Yield the names of all signals in the group."""
return iter(self._psygnal_signals)
Expand Down Expand Up @@ -407,6 +421,108 @@ def __deepcopy__(self, memo: dict[int, Any]) -> SignalGroup:
# it will be a group without any signals connected
return type(self)(instance=self._psygnal_relay.instance)

# The rest are passthrough methods to the SignalRelay.
# The full signatures are here to make mypy and IDEs happy.
# parity with SignalInstance methods is tested in test_group.py

@overload
def connect(
self,
*,
thread: threading.Thread | Literal["main", "current"] | None = ...,
check_nargs: bool | None = ...,
check_types: bool | None = ...,
unique: bool | str = ...,
max_args: int | None = None,
on_ref_error: RefErrorChoice = ...,
) -> Callable[[F], F]: ...

@overload
def connect(
self,
slot: F,
*,
thread: threading.Thread | Literal["main", "current"] | None = ...,
check_nargs: bool | None = ...,
check_types: bool | None = ...,
unique: bool | str = ...,
max_args: int | None = None,
on_ref_error: RefErrorChoice = ...,
) -> F: ...

def connect(
self,
slot: F | None = None,
*,
thread: threading.Thread | Literal["main", "current"] | None = None,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
max_args: int | None = None,
on_ref_error: RefErrorChoice = "warn",
) -> Callable[[F], F] | F:
if slot is None:
return self._psygnal_relay.connect(
thread=thread,
check_nargs=check_nargs,
check_types=check_types,
unique=unique,
max_args=max_args,
on_ref_error=on_ref_error,
)
else:
return self._psygnal_relay.connect(
slot,
thread=thread,
check_nargs=check_nargs,
check_types=check_types,
unique=unique,
max_args=max_args,
on_ref_error=on_ref_error,
)

def connect_direct(
self,
slot: Callable | None = None,
*,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
max_args: int | None = None,
) -> Callable[[Callable], Callable] | Callable:
return self._psygnal_relay.connect_direct(
slot,
check_nargs=check_nargs,
check_types=check_types,
unique=unique,
max_args=max_args,
)

def disconnect(self, slot: Callable | None = None, missing_ok: bool = True) -> None:
return self._psygnal_relay.disconnect(slot=slot, missing_ok=missing_ok)

def block(self, exclude: Iterable[str | SignalInstance] = ()) -> None:
return self._psygnal_relay.block(exclude=exclude)

def unblock(self) -> None:
return self._psygnal_relay.unblock()

def blocked(
self, exclude: Iterable[str | SignalInstance] = ()
) -> ContextManager[None]:
return self._psygnal_relay.blocked(exclude=exclude)

def pause(self) -> None:
return self._psygnal_relay.pause()

def resume(self, reducer: ReducerFunc | None = None, initial: Any = _NULL) -> None:
return self._psygnal_relay.resume(reducer=reducer, initial=initial)

def paused(
self, reducer: ReducerFunc | None = None, initial: Any = _NULL
) -> ContextManager[None]:
return self._psygnal_relay.paused(reducer=reducer, initial=initial)


def _is_uniform(signals: Iterable[Signal]) -> bool:
"""Return True if all signals have the same signature."""
Expand Down
Loading

0 comments on commit 023722c

Please sign in to comment.