Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix replace_node on nodes with indirect node refs #537

Merged
merged 3 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions thinc/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, List, Callable, Optional, Any, Union, Iterable, Set, cast
from typing import Generic, Sequence, Tuple, TypeVar
from typing import Generic, Sequence, Tuple, TypeVar, Iterator
import contextlib
from contextvars import ContextVar
import srsly
Expand Down Expand Up @@ -352,7 +352,22 @@ def use_params(self, params: Dict[Tuple[int, str], FloatsXd]):
for name, param in backup.items():
self.set_param(name, param)

def walk(self) -> Iterable["Model"]:
def walk(self, *, order: str = "bfs") -> Iterable["Model"]:
"""Iterate out layers of the model.

Nodes are returned in breadth-first order by default. Other possible
orders are "dfs_pre" (depth-first search in preorder) and "dfs_post"
(depth-first search in postorder)."""
if order == "bfs":
return self._walk_bfs()
elif order == "dfs_pre":
return self._walk_dfs(post_order=False)
elif order == "dfs_post":
return self._walk_dfs(post_order=True)
else:
raise ValueError("Invalid order, must be one of: bfs, dfs_pre, dfs_post")

def _walk_bfs(self) -> Iterable["Model"]:
"""Iterate out layers of the model, breadth-first."""
queue = [self]
seen: Set[int] = set()
Expand All @@ -363,6 +378,28 @@ def walk(self) -> Iterable["Model"]:
yield node
queue.extend(node.layers)

def _walk_dfs(self, post_order: bool = False) -> Iterable["Model"]:
"""Iterate out layers of the model, depth-first."""
seen: Dict[int, Iterator["Model"]] = dict()
stack = [self]
seen[id(self)] = iter(self.layers)
if not post_order:
yield self

while stack:
try:
next_child = next(seen[id(stack[-1])])
if not id(next_child) in seen:
if not post_order:
yield next_child

stack.append(next_child)
seen[id(next_child)] = iter(next_child.layers)
except StopIteration:
if post_order:
yield stack[-1]
stack.pop()

def remove_node(self, node: "Model") -> None:
"""Remove a node from all layers lists, and then update references.
References that no longer point to a node within the tree will be set
Expand All @@ -385,7 +422,10 @@ def replace_node(self, old: "Model", new: "Model") -> bool:
indicating whether the replacement was made."""
seen = False

for node in list(self.walk()):
# We need to replace nodes in topological order of the transposed graph
# to ensure that a node's dependencies are processed before the node.
# This is equivalent to a post-order traversal of the original graph.
for node in list(self.walk(order="dfs_post")):
if node is old:
seen = True
else:
Expand Down
60 changes: 60 additions & 0 deletions thinc/tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,36 @@ def test_replace_node():
assert debug.layers[3].layers[1] == relu2


def test_replace_node_with_indirect_node_ref():
# a
# / \
# x b[y=y]
# | |
# y x
# |
# y

def dummy_model(name, layers):
return Model(name, lambda model, X, is_train: ..., layers=layers)

y = dummy_model("y", [])
x = dummy_model("x", [y])

y_debug = with_debug(y)

b = dummy_model("b", [x])
b.set_ref("y", y)

a = chain(x, b)
a.name = "a"

a.replace_node(y, y_debug)

assert a.layers[0].layers[0] == y_debug
assert a.layers[1].layers[0].layers[0] == y_debug
assert a.layers[1].get_ref("y") == y_debug


def test_recursive_wrap():
# Check:
#
Expand Down Expand Up @@ -518,3 +548,33 @@ def test_recursive_double_wrap():
assert concat_debug.layers[0].layers[1].layers[0].layers[0].name == "debug(relu)"
assert concat_debug.layers[0].layers[1].layers[0].layers[1].name == "debug(relu)"
assert concat_debug.layers[0].layers[2].name == "debug(relu)"


def test_wrap_non_child_references():
relu = Relu(5)
relu2 = Relu(5)
chained = chain(relu, relu)
chained2 = chain(relu2, chained)
chained2.set_ref("relu", relu)
# Fails in case non-child references cannot be set.
wrap_model_recursive(chained2, with_debug)


def test_walk_dfs():
relu = Relu(5)
relu2 = Relu(5)
inner_chain = chain(relu, relu2)
chained = chain(inner_chain, inner_chain)
assert list(chained.walk(order="dfs_pre")) == [chained, inner_chain, relu, relu2]
assert list(chained.walk(order="dfs_post")) == [
relu,
relu2,
inner_chain,
chained,
]


def test_walk_bfs_post_order_fails():
relu = Relu(5)
with pytest.raises(ValueError, match="Invalid order"):
relu.walk(order="dfs_post_order")