From 6954b30b82a8538b681f5943973f051a8173e634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 21 Sep 2021 09:26:36 +0200 Subject: [PATCH] Fix replace_node on nodes with indirect node refs Model.replace_node failed if a node n contained a reference to another node m, when m only occurs in one of the (grand)children of n. Since replace_node used the breadth-first Model.walk method, the reference to m would be replaced before m is replaced in the (grand)children of n. However, this fails because set_ref verifies that a new node reference occurs in the graph. This change resolves this issue by replacing nodes using a depth-first search in post-order. This amounts to doing a topological sort of the transposed graph. --- thinc/model.py | 29 ++++++++++++++++-- thinc/tests/model/test_model.py | 54 +++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/thinc/model.py b/thinc/model.py index 988af5be7..f4edcde98 100644 --- a/thinc/model.py +++ b/thinc/model.py @@ -1,5 +1,5 @@ from typing import Dict, List, Callable, Optional, Any, Union, Iterable, Set, cast -from typing import Generic, Sequence, Tuple, TypeVar +from typing import Generic, Sequence, Tuple, TypeVar, Iterator import contextlib from contextvars import ContextVar import srsly @@ -363,6 +363,28 @@ def walk(self) -> Iterable["Model"]: yield node queue.extend(node.layers) + def walk_dfs(self, post_order=False) -> Iterable["Model"]: + """Iterate out layers of the model, depth-first.""" + seen: Dict[int, Iterator["Model"]] = dict() + stack = [self] + seen[id(self)] = iter(self.layers) + if not post_order: + yield self + + while stack: + try: + next_child = next(seen[id(stack[-1])]) + if not id(next_child) in seen: + if not post_order: + yield next_child + + stack.append(next_child) + seen[id(next_child)] = iter(next_child.layers) + except StopIteration: + if post_order: + yield stack[-1] + stack.pop() + def remove_node(self, node: "Model") -> 
None: """Remove a node from all layers lists, and then update references. References that no longer point to a node within the tree will be set @@ -385,7 +407,10 @@ def replace_node(self, old: "Model", new: "Model") -> bool: indicating whether the replacement was made.""" seen = False - for node in list(self.walk()): + # We need to replace nodes in topological order of the transposed graph + # to ensure that a node's dependencies are processed before the node. + # This is equivalent to a post-order traversal of the original graph. + for node in list(self.walk_dfs(post_order=True)): if node is old: seen = True else: diff --git a/thinc/tests/model/test_model.py b/thinc/tests/model/test_model.py index 13c2b818f..090a93e26 100644 --- a/thinc/tests/model/test_model.py +++ b/thinc/tests/model/test_model.py @@ -465,6 +465,36 @@ def test_replace_node(): assert debug.layers[3].layers[1] == relu2 +def test_replace_node_with_indirect_node_ref(): + # a + # / \ + # x b[y=y] + # | | + # y x + # | + # y + + def bogus_model(name, layers): + return Model(name, lambda model, X, is_train: ..., layers=layers) + + y = bogus_model("y", []) + x = bogus_model("x", [y]) + + y_debug = with_debug(y) + + b = bogus_model("b", [x]) + b.set_ref("y", y) + + a = chain(x, b) + a.name = "a" + + a.replace_node(y, y_debug) + + assert a.layers[0].layers[0] == y_debug + assert a.layers[1].layers[0].layers[0] == y_debug + assert a.layers[1].get_ref("y") == y_debug + + def test_recursive_wrap(): # Check: # @@ -518,3 +548,27 @@ def test_recursive_double_wrap(): assert concat_debug.layers[0].layers[1].layers[0].layers[0].name == "debug(relu)" assert concat_debug.layers[0].layers[1].layers[0].layers[1].name == "debug(relu)" assert concat_debug.layers[0].layers[2].name == "debug(relu)" + + +def test_wrap_non_child_references(): + relu = Relu(5) + relu2 = Relu(5) + chained = chain(relu, relu) + chained2 = chain(relu2, chained) + chained2.set_ref("relu", relu) + # Fails in case non-child references cannot be 
set. + wrap_model_recursive(chained2, with_debug) + + +def test_walk_dfs(): + relu = Relu(5) + relu2 = Relu(5) + inner_chain = chain(relu, relu2) + chained = chain(inner_chain, inner_chain) + assert list(chained.walk_dfs()) == [chained, inner_chain, relu, relu2] + assert list(chained.walk_dfs(post_order=True)) == [ + relu, + relu2, + inner_chain, + chained, + ]