diff --git a/requirements.txt b/requirements.txt index 69099ffa2..a87b70e65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ wasabi>=0.8.1,<1.1.0 catalogue>=2.0.3,<2.1.0 ml_datasets>=0.2.0,<0.3.0 # Third-party dependencies -pydantic>=1.7.1,<1.8.0 +pydantic>=1.7.1,<1.8.2 numpy>=1.15.0 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" diff --git a/setup.cfg b/setup.cfg index ea964d135..d5ab2802b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = # Third-party dependencies setuptools numpy>=1.15.0 - pydantic>=1.7.1,<1.8.0 + pydantic>=1.7.1,<1.8.2 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8" diff --git a/thinc/util.py b/thinc/util.py index a6a9bdecb..e29c3fa83 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -48,8 +48,8 @@ except ImportError: # pragma: no cover has_mxnet = False -from .types import ArrayXd, ArgsKwargs, Ragged, Padded, FloatsXd, IntsXd - +from .types import ArrayXd, ArgsKwargs, Ragged, Padded, FloatsXd, IntsXd # noqa: E402 +from . import types # noqa: E402 def get_array_module(arr): # pragma: no cover if is_cupy_array(arr): @@ -82,7 +82,7 @@ def fix_random_seed(seed: int = 0) -> None: # pragma: no cover def is_xp_array(obj: Any) -> bool: """Check whether an object is a numpy or cupy array.""" - return is_numpy_array(obj) or is_cupy_array(obj) + return is_numpy_array(obj) or is_cupy_array(obj) def is_cupy_array(obj: Any) -> bool: # pragma: no cover @@ -207,7 +207,7 @@ def to_categorical(Y: IntsXd, n_classes: Optional[int] = None) -> FloatsXd: if xp is cupy: # pragma: no cover Y = Y.get() keep_shapes: List[int] = list(Y.shape) - Y = numpy.array(Y, dtype="int").ravel() # type: ignore + Y = numpy.array(Y, dtype="int").ravel() # type: ignore if n_classes is None: n_classes = int(numpy.max(Y) + 1) keep_shapes.append(n_classes) @@ -320,7 +320,7 @@ def xp2tensorflow( """Convert a numpy or cupy tensor to a TensorFlow Tensor or Variable""" assert_tensorflow_installed() if hasattr(xp_tensor, "toDlpack"): - dlpack_tensor = xp_tensor.toDlpack() # type: ignore + dlpack_tensor = xp_tensor.toDlpack() # type: ignore tf_tensor = tensorflow.experimental.dlpack.from_dlpack(dlpack_tensor) else: tf_tensor = tf.convert_to_tensor(xp_tensor) @@ -437,6 +437,9 @@ def validate_fwd_input_output( sig_args["Y"] = (annot_y, ...) args["Y"] = (Y, lambda x: x) ArgModel = create_model("ArgModel", **sig_args) + # Make sure the forward refs are resolved and the types used by them are + # available in the correct scope. See #494 for details. + ArgModel.update_forward_refs(**types.__dict__) try: ArgModel.parse_obj(args) except ValidationError as e: