From aeafccfbfb5c223d33b61ebe0f1e8b5592249151 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 7 Nov 2023 15:01:52 -0600 Subject: [PATCH] [python-package] fix access to Dataset metadata in scikit-learn custom metrics and objectives (#6108) --- python-package/lightgbm/basic.py | 68 +++++++++++++------ python-package/lightgbm/sklearn.py | 69 ++++++++++++++----- tests/python_package_test/test_basic.py | 90 ++++++++++++++++++++++++- tests/python_package_test/utils.py | 20 ++++++ 4 files changed, 209 insertions(+), 38 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 702c4682ea8d..e8d8bd84cbe7 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -434,7 +434,7 @@ def _data_to_2d_numpy( "It should be list of lists, numpy 2-D array or pandas DataFrame") -def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes float pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_float)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -442,7 +442,7 @@ def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray raise RuntimeError('Expected float pointer') -def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes double pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_double)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -450,7 +450,7 @@ def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray raise RuntimeError('Expected double pointer') -def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes int pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -458,7 +458,7 @@ def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: raise RuntimeError('Expected int32 pointer') -def _cint64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes int pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -1295,18 +1295,18 @@ def __create_sparse_native( data_indices_len = out_shape[0] indptr_len = out_shape[1] if indptr_type == _C_API_DTYPE_INT32: - out_indptr = _cint32_array_to_numpy(out_ptr_indptr, indptr_len) + out_indptr = _cint32_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len) elif indptr_type == _C_API_DTYPE_INT64: - out_indptr = _cint64_array_to_numpy(out_ptr_indptr, indptr_len) + out_indptr = _cint64_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len) else: raise TypeError("Expected int32 or int64 type for indptr") if data_type == _C_API_DTYPE_FLOAT32: - out_data = _cfloat32_array_to_numpy(out_ptr_data, data_indices_len) + out_data = _cfloat32_array_to_numpy(cptr=out_ptr_data, length=data_indices_len) elif data_type == _C_API_DTYPE_FLOAT64: - out_data = _cfloat64_array_to_numpy(out_ptr_data, data_indices_len) + out_data = _cfloat64_array_to_numpy(cptr=out_ptr_data, length=data_indices_len) else: raise TypeError("Expected float32 or float64 type for data") - out_indices = _cint32_array_to_numpy(out_ptr_indices, data_indices_len) + out_indices = _cint32_array_to_numpy(cptr=out_ptr_indices, length=data_indices_len) # break up indptr based on number of rows (note more than one matrix in multiclass case) per_class_indptr_shape = cs.indptr.shape[0] # for CSC there is extra column added @@ -2609,6 +2609,12 @@ def set_field( def get_field(self, field_name: str) -> Optional[np.ndarray]: """Get property from the Dataset. + Can only be run on a constructed Dataset. + + Unlike ``get_group()``, ``get_init_score()``, ``get_label()``, ``get_position()``, and ``get_weight()``, + this method ignores any raw data passed into ``lgb.Dataset()`` on the Python side, and will only read + data from the constructed C++ ``Dataset`` object. + Parameters ---------- field_name : str @@ -2635,11 +2641,20 @@ def get_field(self, field_name: str) -> Optional[np.ndarray]: if tmp_out_len.value == 0: return None if out_type.value == _C_API_DTYPE_INT32: - arr = _cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value) + arr = _cint32_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), + length=tmp_out_len.value + ) elif out_type.value == _C_API_DTYPE_FLOAT32: - arr = _cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value) + arr = _cfloat32_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), + length=tmp_out_len.value + ) elif out_type.value == _C_API_DTYPE_FLOAT64: - arr = _cfloat64_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), tmp_out_len.value) + arr = _cfloat64_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), + length=tmp_out_len.value + ) else: raise TypeError("Unknown type") if field_name == 'init_score': @@ -2878,6 +2893,10 @@ def set_group( if self._handle is not None and group is not None: group = _list_to_1d_numpy(group, dtype=np.int32, name='group') self.set_field('group', group) + # original values can be modified at cpp side + constructed_group = self.get_field('group') + if constructed_group is not None: + self.group = np.diff(constructed_group) return self def set_position( @@ -2941,37 +2960,40 @@ def get_feature_name(self) -> List[str]: ptr_string_buffers)) return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)] - def get_label(self) -> Optional[np.ndarray]: + def get_label(self) -> Optional[_LGBM_LabelType]: """Get the label of the Dataset. Returns ------- - label : numpy array or None + label : list, numpy 1-D array, pandas Series / one-column DataFrame or None The label information from the Dataset. + For a constructed ``Dataset``, this will only return a numpy array. """ if self.label is None: self.label = self.get_field('label') return self.label - def get_weight(self) -> Optional[np.ndarray]: + def get_weight(self) -> Optional[_LGBM_WeightType]: """Get the weight of the Dataset. Returns ------- - weight : numpy array or None + weight : list, numpy 1-D array, pandas Series or None Weight for each data point from the Dataset. Weights should be non-negative. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.weight is None: self.weight = self.get_field('weight') return self.weight - def get_init_score(self) -> Optional[np.ndarray]: + def get_init_score(self) -> Optional[_LGBM_InitScoreType]: """Get the initial score of the Dataset. Returns ------- - init_score : numpy array or None + init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None Init score of Booster. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.init_score is None: self.init_score = self.get_field('init_score') @@ -3009,17 +3031,18 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]: "set free_raw_data=False when construct Dataset to avoid this.") return self.data - def get_group(self) -> Optional[np.ndarray]: + def get_group(self) -> Optional[_LGBM_GroupType]: """Get the group of the Dataset. Returns ------- - group : numpy array or None + group : list, numpy 1-D array, pandas Series or None Group/query data. Only used in the learning-to-rank task. sum(group) = n_samples. For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.group is None: self.group = self.get_field('group') @@ -3028,13 +3051,14 @@ def get_group(self) -> Optional[np.ndarray]: self.group = np.diff(self.group) return self.group - def get_position(self) -> Optional[np.ndarray]: + def get_position(self) -> Optional[_LGBM_PositionType]: """Get the position of the Dataset. Returns ------- - position : numpy 1-D array or None + position : numpy 1-D array, pandas Series or None Position of items used in unbiased learning-to-rank task. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.position is None: self.position = self.get_field('position') diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index c71c233df908..310d5d2ca6ea 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -86,6 +86,36 @@ _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType] +def _get_group_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]: + group = dataset.get_group() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve query groups from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert (group is None or isinstance(group, np.ndarray)), error_msg + return group + + +def _get_label_from_constructed_dataset(dataset: Dataset) -> np.ndarray: + label = dataset.get_label() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve labels from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert isinstance(label, np.ndarray), error_msg + return label + + +def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]: + weight = dataset.get_weight() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert (weight is None or isinstance(weight, np.ndarray)), error_msg + return weight + + class _ObjectiveFunctionWrapper: """Proxy class for objective function.""" @@ -151,17 +181,22 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. The value of the second order derivative (Hessian) of the loss with respect to the elements of preds for each sample point. """ - labels = dataset.get_label() + labels = _get_label_from_constructed_dataset(dataset) argc = len(signature(self.func).parameters) if argc == 2: grad, hess = self.func(labels, preds) # type: ignore[call-arg] - elif argc == 3: - grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] - elif argc == 4: - grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg] - else: - raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") - return grad, hess + return grad, hess + + weight = _get_weight_from_constructed_dataset(dataset) + if argc == 3: + grad, hess = self.func(labels, preds, weight) # type: ignore[call-arg] + return grad, hess + + if argc == 4: + group = _get_group_from_constructed_dataset(dataset) + return self.func(labels, preds, weight, group) # type: ignore[call-arg] + + raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") class _EvalFunctionWrapper: @@ -229,16 +264,20 @@ def __call__( is_higher_better : bool Is eval result higher better, e.g. AUC is ``is_higher_better``. """ - labels = dataset.get_label() + labels = _get_label_from_constructed_dataset(dataset) argc = len(signature(self.func).parameters) if argc == 2: return self.func(labels, preds) # type: ignore[call-arg] - elif argc == 3: - return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] - elif argc == 4: - return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg] - else: - raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") + + weight = _get_weight_from_constructed_dataset(dataset) + if argc == 3: + return self.func(labels, preds, weight) # type: ignore[call-arg] + + if argc == 4: + group = _get_group_from_constructed_dataset(dataset) + return self.func(labels, preds, weight, group) # type: ignore[call-arg] + + raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") # documentation templates for LGBMModel methods are shared between the classes in diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 7f8980c271f7..2f6b07e7a77f 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -15,7 +15,7 @@ import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series -from .utils import dummy_obj, load_breast_cancer, mse_obj +from .utils import dummy_obj, load_breast_cancer, mse_obj, np_assert_array_equal def test_basic(tmp_path): @@ -499,6 +499,94 @@ def check_asserts(data): check_asserts(lgb_data) +def test_dataset_construction_overwrites_user_provided_metadata_fields(): + + X = np.array([[1.0, 2.0], [3.0, 4.0]]) + + position = np.array([0.0, 1.0], dtype=np.float32) + if getenv('TASK', '') == 'cuda': + position = None + + dtrain = lgb.Dataset( + X, + params={ + "min_data_in_bin": 1, + "min_data_in_leaf": 1, + "verbosity": -1 + }, + group=[1, 1], + init_score=[0.312, 0.708], + label=[1, 2], + position=position, + weight=[0.5, 1.5], + ) + + # unconstructed, get_* methods should return whatever was provided + assert dtrain.group == [1, 1] + assert dtrain.get_group() == [1, 1] + assert dtrain.init_score == [0.312, 0.708] + assert dtrain.get_init_score() == [0.312, 0.708] + assert dtrain.label == [1, 2] + assert dtrain.get_label() == [1, 2] + if getenv('TASK', '') != 'cuda': + np_assert_array_equal( + dtrain.position, + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) + np_assert_array_equal( + dtrain.get_position(), + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) + assert dtrain.weight == [0.5, 1.5] + assert dtrain.get_weight() == [0.5, 1.5] + + # before construction, get_field() should raise an exception + for field_name in ["group", "init_score", "label", "position", "weight"]: + with pytest.raises(Exception, match=f"Cannot get {field_name} before construct Dataset"): + dtrain.get_field(field_name) + + # constructed, get_* methods should return numpy arrays, even when the provided + # input was a list of floats or ints + dtrain.construct() + expected_group = np.array([1, 1], dtype=np.int32) + np_assert_array_equal(dtrain.group, expected_group, strict=True) + np_assert_array_equal(dtrain.get_group(), expected_group, strict=True) + # get_field("group") returns a numpy array with boundaries, instead of size + np_assert_array_equal( + dtrain.get_field("group"), + np.array([0, 1, 2], dtype=np.int32), + strict=True + ) + + expected_init_score = np.array([0.312, 0.708],) + np_assert_array_equal(dtrain.init_score, expected_init_score, strict=True) + np_assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) + np_assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) + + expected_label = np.array([1, 2], dtype=np.float32) + np_assert_array_equal(dtrain.label, expected_label, strict=True) + np_assert_array_equal(dtrain.get_label(), expected_label, strict=True) + np_assert_array_equal(dtrain.get_field("label"), expected_label, strict=True) + + if getenv('TASK', '') != 'cuda': + expected_position = np.array([0.0, 1.0], dtype=np.float32) + np_assert_array_equal(dtrain.position, expected_position, strict=True) + np_assert_array_equal(dtrain.get_position(), expected_position, strict=True) + # NOTE: "position" is converted to int32 on the C++ side + np_assert_array_equal( + dtrain.get_field("position"), + np.array([0.0, 1.0], dtype=np.int32), + strict=True + ) + + expected_weight = np.array([0.5, 1.5], dtype=np.float32) + np_assert_array_equal(dtrain.weight, expected_weight, strict=True) + np_assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) + np_assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) + + def test_choose_param_value(): original_params = { diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index df01e29852e7..7eae62b14369 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -1,6 +1,7 @@ # coding: utf-8 import pickle from functools import lru_cache +from inspect import getfullargspec import cloudpickle import joblib @@ -193,3 +194,22 @@ def pickle_and_unpickle_object(obj, serializer): serializer=serializer ) return obj_from_disk # noqa: RET504 + + +# doing this here, at import time, to ensure it only runs once_per import +# instead of once per assertion +_numpy_testing_supports_strict_kwarg = ( + "strict" in getfullargspec(np.testing.assert_array_equal).kwonlyargs +) + + +def np_assert_array_equal(*args, **kwargs): + """ + np.testing.assert_array_equal() only got the kwarg ``strict`` in June 2022: + https://github.com/numpy/numpy/pull/21595 + + This function is here for testing on older Python (and therefore ``numpy``) + """ + if not _numpy_testing_supports_strict_kwarg: + kwargs.pop("strict") + np.testing.assert_array_equal(*args, **kwargs)