Skip to content

Commit

Permalink
Fix numpy_ops gemm output semantics when BLIS is used
Browse files Browse the repository at this point in the history
The `out` keyword argument of the `gemm` op specifies an output array,
However, the semantics were different depending on whether BLIS is used:

* `use_blis==False`: the values of out array are overwritten.
* `use_blis==True`: the values of out are added to the result.

With this change, the values of `out` are also overwritten with
`use_blis=True`.
  • Loading branch information
danieldk committed Aug 11, 2021
1 parent a125b8c commit fb84cc5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class NumpyOps(Ops):
y = self.as_contig(y)
if out is not None:
out = self.as_contig(out)
return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2)
return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2, beta=0.)

def relu(self, np.ndarray X, inplace=False):
cdef np.ndarray out = X if inplace else X.copy()
Expand Down
8 changes: 8 additions & 0 deletions thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,14 @@ def test_gemm_computes_correctly(cpu_ops):
cpu_ops.gemm(X, W, trans1=True, out=Y)


@pytest.mark.parametrize("cpu_ops", [*CPU_OPS, BLIS_OPS])
def test_gemm_out_used(cpu_ops):
a = b = numpy.zeros((2, 2), dtype="f")
c = numpy.ones((2, 2), dtype="f")
cpu_ops.gemm(a, b, out=c)
assert numpy.array_equal(c, numpy.zeros((2, 2)))


@pytest.mark.parametrize("cpu_ops", CPU_OPS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(X=strategies.arrays_BI())
Expand Down

0 comments on commit fb84cc5

Please sign in to comment.