diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index 60b49696a..6d96359f9 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -88,7 +88,7 @@ class NumpyOps(Ops): y = self.as_contig(y) if out is not None: out = self.as_contig(out) - return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2) + return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2, beta=0.) def relu(self, np.ndarray X, inplace=False): cdef np.ndarray out = X if inplace else X.copy() diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 4e805f509..1970f629f 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -275,6 +275,14 @@ def test_gemm_computes_correctly(cpu_ops): cpu_ops.gemm(X, W, trans1=True, out=Y) +@pytest.mark.parametrize("cpu_ops", [*CPU_OPS, BLIS_OPS]) +def test_gemm_out_used(cpu_ops): + a = b = numpy.zeros((2, 2), dtype="f") + c = numpy.ones((2, 2), dtype="f") + cpu_ops.gemm(a, b, out=c) + assert numpy.array_equal(c, numpy.zeros((2, 2))) + + @pytest.mark.parametrize("cpu_ops", CPU_OPS) @settings(max_examples=MAX_EXAMPLES, deadline=None) @given(X=strategies.arrays_BI())