From 87fd875f64cdbe74fb12066dc1e376a0eb8841c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Wed, 11 Aug 2021 20:30:32 +0200 Subject: [PATCH] Fix numpy_ops gemm output semantics when BLIS is used The `out` keyword argument of the `gemm` op specifies an output array. However, the semantics were different depending on whether BLIS is used: * `use_blis==False`: the values of out array are overwritten. * `use_blis==True`: the values of out are added to the result. With this change, the values of `out` are also overwritten with `use_blis=True`. --- thinc/backends/numpy_ops.pyx | 2 +- thinc/tests/backends/test_ops.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index 60b49696a..6d96359f9 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -88,7 +88,7 @@ class NumpyOps(Ops): y = self.as_contig(y) if out is not None: out = self.as_contig(out) - return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2) + return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2, beta=0.) def relu(self, np.ndarray X, inplace=False): cdef np.ndarray out = X if inplace else X.copy() diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 4e805f509..1970f629f 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -275,6 +275,14 @@ def test_gemm_computes_correctly(cpu_ops): cpu_ops.gemm(X, W, trans1=True, out=Y) +@pytest.mark.parametrize("cpu_ops", [*CPU_OPS, BLIS_OPS]) +def test_gemm_out_used(cpu_ops): + a = b = numpy.zeros((2, 2), dtype="f") + c = numpy.ones((2, 2), dtype="f") + cpu_ops.gemm(a, b, out=c) + assert numpy.array_equal(c, numpy.zeros((2, 2))) + + @pytest.mark.parametrize("cpu_ops", CPU_OPS) @settings(max_examples=MAX_EXAMPLES, deadline=None) @given(X=strategies.arrays_BI())