Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add signal aliases on SignalGroup #299

Merged
merged 17 commits into from
Mar 22, 2024
6 changes: 1 addition & 5 deletions src/psygnal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,7 @@
from ._evented_decorator import evented
from ._exceptions import EmitLoopError
from ._group import EmissionInfo, SignalGroup
from ._group_descriptor import (
SignalGroupDescriptor,
get_evented_namespace,
is_evented,
)
from ._group_descriptor import SignalGroupDescriptor, get_evented_namespace, is_evented
from ._queue import emit_queued
from ._signal import Signal, SignalInstance, _compiled
from ._throttler import debounced, throttled
Expand Down
25 changes: 16 additions & 9 deletions src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from __future__ import annotations

from typing import (
Any,
Callable,
Literal,
TypeVar,
overload,
)
from typing import TYPE_CHECKING, Callable, Literal, Mapping, TypeVar, overload

from psygnal._group_descriptor import SignalGroupDescriptor

if TYPE_CHECKING:
from psygnal._group_descriptor import EqOperator, FieldAliasFunc

__all__ = ["evented"]

T = TypeVar("T", bound=type)

EqOperator = Callable[[Any, Any], bool]


@overload
def evented(
Expand All @@ -25,6 +20,7 @@ def evented(
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = ...,
cache_on_instance: bool = ...,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = ...,
) -> T: ...


Expand All @@ -36,6 +32,7 @@ def evented(
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = ...,
cache_on_instance: bool = ...,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = ...,
) -> Callable[[T], T]: ...


Expand All @@ -46,6 +43,7 @@ def evented(
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = True,
cache_on_instance: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
) -> Callable[[T], T] | T:
"""A decorator to add events to a dataclass.

Expand Down Expand Up @@ -85,6 +83,14 @@ def evented(
access, but means that the owner instance will no longer be pickleable. If
`False`, the SignalGroup instance will *still* be cached, but not on the
instance itself.
signal_aliases: Mapping[str, str | None] | Callable[[str], str | None] | None
If defined, a mapping between field name and signal name. Field names that are
not `signal_aliases` keys are not aliased (the signal name is the field name).
If the dict value is None, do not create a signal associated with this field.
If a callable, the signal name is the output of the function applied to the
field name. If the output is None, no signal is created for this field.
If None, defaults to an empty dict, no aliases.
Default to None

Returns
-------
Expand Down Expand Up @@ -122,6 +128,7 @@ def _decorate(cls: T) -> T:
equality_operators=equality_operators,
warn_on_no_fields=warn_on_no_fields,
cache_on_instance=cache_on_instance,
signal_aliases=signal_aliases,
)
# as a decorator, this will have already been called
descriptor.__set_name__(cls, events_namespace)
Expand Down
9 changes: 8 additions & 1 deletion src/psygnal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class MySignals(SignalGroup):
_psygnal_signals: ClassVar[Mapping[str, Signal]]
_psygnal_uniform: ClassVar[bool] = False
_psygnal_name_conflicts: ClassVar[set[str]]
_psygnal_aliases: ClassVar[dict[str, str | None]]

_psygnal_instances: dict[str, SignalInstance]

Expand All @@ -280,7 +281,11 @@ def __init__(self, instance: Any = None) -> None:
}
self._psygnal_relay = SignalRelay(self._psygnal_instances, instance)

def __init_subclass__(cls, strict: bool = False) -> None:
def __init_subclass__(
cls,
strict: bool = False,
signal_aliases: Mapping[str, str | None] = {},
) -> None:
"""Collects all Signal instances on the class under `cls._psygnal_signals`."""
# Collect Signals and remove from class attributes
# Use dir(cls) instead of cls.__dict__ to get attributes from super()
Expand Down Expand Up @@ -328,6 +333,8 @@ def __init_subclass__(cls, strict: bool = False) -> None:
stacklevel=2,
)

aliases = getattr(cls, "_psygnal_aliases", {})
cls._psygnal_aliases = {**aliases, **signal_aliases}
cls._psygnal_uniform = _is_uniform(cls._psygnal_signals.values())
if strict and not cls._psygnal_uniform:
raise TypeError(
Expand Down
Loading
Loading