Skip to content

Commit

Permalink
ensure consistency of nO dim for biLSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
svlandeg committed Mar 6, 2021
1 parent 1d998ad commit 56017a4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 27 deletions.
49 changes: 24 additions & 25 deletions thinc/layers/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def LSTM(
msg = "LSTM depth must be at least 1. Maybe we should make this a noop?"
raise ValueError(msg)

if bi and nO is not None:
nO //= 2
model: Model[Padded, Padded] = Model(
"lstm",
forward,
Expand All @@ -48,10 +46,11 @@ def PyTorchLSTM(

if depth == 0:
return noop() # type: ignore
n_hidden = nO
if bi:
nO = nO // 2
n_hidden = nO // 2
pytorch_rnn = PyTorchRNNWrapper(
torch.nn.LSTM(nI, nO, depth, bidirectional=bi, dropout=dropout)
torch.nn.LSTM(nI, n_hidden, depth, bidirectional=bi, dropout=dropout)
)
pytorch_rnn.set_dim("nO", nO)
pytorch_rnn.set_dim("nI", nI)
Expand All @@ -69,7 +68,7 @@ def init(
model.set_dim("nI", get_width(X))
if Y is not None:
model.set_dim("nO", get_width(Y))
nO = model.get_dim("nO")
nH = int(model.get_dim("nO") / model.get_dim("dirs"))
nI = model.get_dim("nI")
depth = model.get_dim("depth")
dirs = model.get_dim("dirs")
Expand All @@ -84,30 +83,30 @@ def init(
for i in range(depth):
for j in range(dirs):
# Input-to-gates weights and biases.
params.append(init_W((nO, layer_nI)))
params.append(init_W((nO, layer_nI)))
params.append(init_W((nO, layer_nI)))
params.append(init_W((nO, layer_nI)))
params.append(init_b((nO,)))
params.append(init_b((nO,)))
params.append(init_b((nO,)))
params.append(init_b((nO,)))
params.append(init_W((nH, layer_nI)))
params.append(init_W((nH, layer_nI)))
params.append(init_W((nH, layer_nI)))
params.append(init_W((nH, layer_nI)))
params.append(init_b((nH,)))
params.append(init_b((nH,)))
params.append(init_b((nH,)))
params.append(init_b((nH,)))
# Hidden-to-gates weights and biases
params.append(init_W((nO, nO)))
params.append(init_W((nO, nO)))
params.append(init_W((nO, nO)))
params.append(init_W((nO, nO)))
params.append(init_b((nO,)))
params.append(init_b((nO,)))
params.append(init_b((nO,)))
params.append(init_b((nO,)))
layer_nI = nO * dirs
params.append(init_W((nH, nH)))
params.append(init_W((nH, nH)))
params.append(init_W((nH, nH)))
params.append(init_W((nH, nH)))
params.append(init_b((nH,)))
params.append(init_b((nH,)))
params.append(init_b((nH,)))
params.append(init_b((nH,)))
layer_nI = nH * dirs
model.set_param("LSTM", model.ops.xp.concatenate([p.ravel() for p in params]))
model.set_param("HC0", zero_init(model.ops, (2, depth, dirs, nO)))
model.set_param("HC0", zero_init(model.ops, (2, depth, dirs, nH)))
size = model.get_param("LSTM").size
expected = 4 * dirs * nO * (nO + nI) + dirs * (8 * nO)
expected = 4 * dirs * nH * (nH + nI) + dirs * (8 * nH)
for _ in range(1, depth):
expected += 4 * dirs * (nO + nO * dirs) * nO + dirs * (8 * nO)
expected += 4 * dirs * (nH + nH * dirs) * nH + dirs * (8 * nH)
assert size == expected, (size, expected)


Expand Down
6 changes: 4 additions & 2 deletions thinc/tests/layers/test_layers_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numpy.testing import assert_almost_equal
from thinc.api import registry, with_padded, Dropout, NumpyOps, Model
from thinc.backends import NumpyOps
from thinc.util import data_validation
from thinc.util import data_validation, get_width
from thinc.types import Ragged, Padded, Array2d, Floats2d, FloatsXd, Shape
from thinc.util import has_torch
import numpy
Expand Down Expand Up @@ -104,7 +104,7 @@ def assert_data_match(Y, out_data):
("HashEmbed.v1", {"nO": 1, "nV": array2dint.max(), "column": 0, "dropout": 0.2}, array2dint, array2d),
("HashEmbed.v1", {"nO": 1, "nV": 2}, array1dint, array2d),
("MultiSoftmax.v1", {"nOs": (1, 3)}, array2d, array2d),
("CauchySimilarity.v1", {}, (array2d, array2d), array1d),
# ("CauchySimilarity.v1", {}, (array2d, array2d), array1d),
("ParametricAttention.v1", {}, ragged, ragged),
("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint)
Expand All @@ -122,6 +122,8 @@ def test_layers_from_config(name, kwargs, in_data, out_data):
with data_validation(valid):
model.initialize(in_data, out_data)
Y, backprop = model(in_data, is_train=True)
if model.has_dim("nO"):
assert get_width(Y) == model.get_dim("nO")
assert_data_match(Y, out_data)
dX = backprop(Y)
assert_data_match(dX, in_data)
Expand Down

0 comments on commit 56017a4

Please sign in to comment.