Skip to content

Commit

Permalink
Fix replace_node on nodes with indirect node refs
Browse files Browse the repository at this point in the history
Model.replace_node failed if a node n contained a reference to another
node m, when m only occurs in one of the (grand)children of n.  Since
replace_node used the breadth-first Model.walk method, the reference to
m would be replaced before m is replaced in the (grand)children of n.
However, this fails because set_ref verifies that a new node
reference occurs in the graph.

This change resolves this issue by replacing nodes using a depth-first
search in post-order. This amounts to doing a topological sort of the
transposed graph.
  • Loading branch information
danieldk committed Sep 21, 2021
1 parent f8ddf47 commit 6954b30
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
29 changes: 27 additions & 2 deletions thinc/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, List, Callable, Optional, Any, Union, Iterable, Set, cast
from typing import Generic, Sequence, Tuple, TypeVar
from typing import Generic, Sequence, Tuple, TypeVar, Iterator
import contextlib
from contextvars import ContextVar
import srsly
Expand Down Expand Up @@ -363,6 +363,28 @@ def walk(self) -> Iterable["Model"]:
yield node
queue.extend(node.layers)

def walk_dfs(self, post_order: bool = False) -> Iterable["Model"]:
    """Iterate over this node and the nodes reachable through `layers`,
    depth-first, yielding each distinct node object exactly once.

    The traversal is iterative (explicit stack) and identifies nodes by
    `id()`, so a node that is reachable along several paths is only
    visited the first time it is discovered.

    post_order (bool): If False (the default), a node is yielded when it is
        first discovered, before any of its children (pre-order). If True,
        a node is yielded only after all of its children have been
        processed (post-order).
    YIELDS (Model): The nodes of the graph in the requested order.
    """
    # Maps id(node) -> iterator over that node's children that still have
    # to be examined. Presence of a key also marks the node as discovered.
    pending: Dict[int, Iterator["Model"]] = {id(self): iter(self.layers)}
    stack = [self]
    if not post_order:
        yield self

    while stack:
        current = stack[-1]
        # Keep the try body limited to the next() call: only exhaustion of
        # the child iterator may be interpreted as "node is finished".
        try:
            child = next(pending[id(current)])
        except StopIteration:
            # Every child of `current` has been handled.
            if post_order:
                yield current
            stack.pop()
            continue

        if id(child) in pending:
            # Already discovered via another path; do not visit again.
            continue

        if not post_order:
            yield child
        stack.append(child)
        pending[id(child)] = iter(child.layers)

def remove_node(self, node: "Model") -> None:
"""Remove a node from all layers lists, and then update references.
References that no longer point to a node within the tree will be set
Expand All @@ -385,7 +407,10 @@ def replace_node(self, old: "Model", new: "Model") -> bool:
indicating whether the replacement was made."""
seen = False

for node in list(self.walk()):
# We need to replace nodes in topological order of the transposed graph
# to ensure that a node's dependencies are processed before the node.
# This is equivalent to a post-order traversal of the original graph.
for node in list(self.walk_dfs(post_order=True)):
if node is old:
seen = True
else:
Expand Down
54 changes: 54 additions & 0 deletions thinc/tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,36 @@ def test_replace_node():
assert debug.layers[3].layers[1] == relu2


def test_replace_node_with_indirect_node_ref():
    """replace_node must cope with a ref to a node that is not a direct child.

    Graph under test (b holds a ref named "y" pointing at y, but y is only
    reachable from b through x):

            a
           / \
          x   b[y=y]
          |   |
          y   x
              |
              y
    """

    def make_stub(name, layers):
        # The forward function is never called; only the graph shape matters.
        return Model(name, lambda model, X, is_train: ..., layers=layers)

    leaf = make_stub("y", [])
    middle = make_stub("x", [leaf])

    leaf_debug = with_debug(leaf)

    ref_holder = make_stub("b", [middle])
    ref_holder.set_ref("y", leaf)

    root = chain(middle, ref_holder)
    root.name = "a"

    root.replace_node(leaf, leaf_debug)

    # y is replaced below x, both where x is a direct child of a ...
    assert root.layers[0].layers[0] == leaf_debug
    # ... and where x is reached through b.
    assert root.layers[1].layers[0].layers[0] == leaf_debug
    # The ref stored on b follows the replacement as well.
    assert root.layers[1].get_ref("y") == leaf_debug


def test_recursive_wrap():
# Check:
#
Expand Down Expand Up @@ -518,3 +548,27 @@ def test_recursive_double_wrap():
assert concat_debug.layers[0].layers[1].layers[0].layers[0].name == "debug(relu)"
assert concat_debug.layers[0].layers[1].layers[0].layers[1].name == "debug(relu)"
assert concat_debug.layers[0].layers[2].name == "debug(relu)"


def test_wrap_non_child_references():
    """Recursive wrapping must accept a ref to a node that is not a direct child."""
    shared_relu = Relu(5)
    outer_relu = Relu(5)
    inner = chain(shared_relu, shared_relu)
    outer = chain(outer_relu, inner)
    # shared_relu is only reachable from `outer` through `inner`.
    outer.set_ref("relu", shared_relu)
    # Fails in case non-child references cannot be set.
    wrap_model_recursive(outer, with_debug)


def test_walk_dfs():
    """walk_dfs yields each node once, in pre-order or post-order."""
    first_relu = Relu(5)
    second_relu = Relu(5)
    shared = chain(first_relu, second_relu)
    # The same sub-chain is passed twice; it is expected only once in the
    # traversal output.
    root = chain(shared, shared)

    pre_order = list(root.walk_dfs())
    assert pre_order == [root, shared, first_relu, second_relu]

    post_order = list(root.walk_dfs(post_order=True))
    assert post_order == [first_relu, second_relu, shared, root]

0 comments on commit 6954b30

Please sign in to comment.