Python sklearn.utils.testing.SkipTest() Examples
The following are 30 code examples of sklearn.utils.testing.SkipTest().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module sklearn.utils.testing, or try the search function.
Example #1
Source File: test_multiclass.py From Mastering-Elasticsearch-7.0 with MIT License | 6 votes |
def test_type_of_target():
    # Every sample listed in EXAMPLES must be classified under its group key.
    template = 'type_of_target(%r) should be %r, got %r'
    for expected, samples in EXAMPLES.items():
        for sample in samples:
            failure_msg = template % (sample, expected,
                                      type_of_target(sample))
            assert_equal(type_of_target(sample), expected, msg=failure_msg)

    # Inputs that are not array-like are rejected with a ValueError; the
    # pattern is the same for every sample, so it is built once.
    not_array_like = r'Expected array-like \(array or non-string sequence\).*'
    for sample in NON_ARRAY_LIKE_EXAMPLES:
        assert_raises_regex(ValueError, not_array_like, type_of_target,
                            sample)

    # The legacy "sequence of sequences" multilabel format is rejected too.
    legacy_msg = ('You appear to be using a legacy multi-label data '
                  'representation. Sequence of sequences are no longer supported;'
                  ' use a binary array or sparse matrix instead.')
    for sample in MULTILABEL_SEQUENCES:
        assert_raises_regex(ValueError, legacy_msg, type_of_target, sample)

    # The last check needs pandas' SparseSeries; skip when it is unavailable.
    try:
        from pandas import SparseSeries
    except ImportError:
        raise SkipTest("Pandas not found")

    sparse_y = SparseSeries([1, 0, 0, 1, 0])
    assert_raises_regex(ValueError, "y cannot be class 'SparseSeries'.",
                        type_of_target, sparse_y)
Example #2
Source File: test_kd_tree.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_gaussian_kde(n_samples=1000):
    # KDTree.kernel_density, divided by n_samples, must agree with
    # scipy.stats.gaussian_kde to 3 decimals for several bandwidths.
    from scipy.stats import gaussian_kde

    rng = check_random_state(0)
    train = rng.normal(0, 1, n_samples)
    grid = np.linspace(-5, 5, 30)
    # Column-vector views handed to the tree (it is given 2D input).
    train_2d = train[:, None]
    grid_2d = grid[:, None]

    for bandwidth in (0.01, 0.1, 1):
        tree = KDTree(train_2d)

        # scipy takes the bandwidth as a factor of the sample std.
        try:
            reference = gaussian_kde(train,
                                     bw_method=bandwidth / np.std(train))
        except TypeError:
            raise SkipTest("Old scipy, does not accept explicit bandwidth.")

        tree_density = tree.kernel_density(grid_2d, bandwidth) / n_samples
        scipy_density = reference.evaluate(grid)
        assert_array_almost_equal(tree_density, scipy_density, decimal=3)
Example #3
Source File: test_ball_tree.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_gaussian_kde(n_samples=1000):
    # BallTree.kernel_density, divided by n_samples, must agree with
    # scipy.stats.gaussian_kde to 3 decimals for several bandwidths.
    from scipy.stats import gaussian_kde

    random_state = check_random_state(0)
    samples = random_state.normal(0, 1, n_samples)
    query = np.linspace(-5, 5, 30)

    for width in (0.01, 0.1, 1):
        ball_tree = BallTree(samples[:, None])

        # scipy takes the bandwidth as a factor of the sample std.
        scipy_factor = width / np.std(samples)
        try:
            scipy_kde = gaussian_kde(samples, bw_method=scipy_factor)
        except TypeError:
            raise SkipTest("Old version of scipy, doesn't accept "
                           "explicit bandwidth.")

        density_tree = ball_tree.kernel_density(query[:, None], width)
        density_tree = density_tree / n_samples
        density_scipy = scipy_kde.evaluate(query)
        assert_array_almost_equal(density_tree, density_scipy, decimal=3)
Example #4
Source File: test_covtype.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_fetch():
    # Two differently-seeded shuffled fetches must have the same shape and
    # the same data sum.  Skipped when the first fetch raises IOError.
    try:
        first = fetch(shuffle=True, random_state=42)
    except IOError:
        raise SkipTest("Covertype dataset can not be loaded.")
    second = fetch(shuffle=True, random_state=37)

    X1 = first['data']
    X2 = second['data']
    assert_equal((581012, 54), X1.shape)
    assert_equal(X1.shape, X2.shape)
    assert_equal(X1.sum(), X2.sum())

    # Both target vectors hold one entry per data row.
    for target in (first['target'], second['target']):
        assert_equal((X1.shape[0],), target.shape)
Example #5
Source File: test_bicluster.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_perfect_checkerboard():
    # This test is unconditionally skipped; see the skip message.
    raise SkipTest("This test is failing on the buildbot, but cannot"
                   " reproduce. Temporarily disabling it until it can be"
                   " reproduced and fixed.")

    # Unreachable while the skip above is in place.  Kept so the test can be
    # re-enabled: noise-free checkerboards of several shapes are expected to
    # be recovered with a consensus score of 1.
    model = SpectralBiclustering(3, svd_method="arpack", random_state=0)
    for shape in ((30, 30), (40, 30), (30, 40)):
        S, rows, cols = make_checkerboard(shape, 3, noise=0, random_state=0)
        model.fit(S)
        assert_equal(consensus_score(model.biclusters_, (rows, cols)), 1)
Example #6
Source File: test_bayes.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_bayesian_on_diabetes():
    # Test BayesianRidge on diabetes.  Currently always skipped.
    raise SkipTest("XFailed Test")

    # Unreachable while the skip above is in place.
    diabetes = datasets.load_diabetes()
    full_X, full_y = diabetes.data, diabetes.target

    clf = BayesianRidge(compute_score=True)

    # First with more samples than features, then (first five rows only)
    # with more features than samples.
    for features, target in ((full_X, full_y), (full_X[:5, :], full_y[:5])):
        clf.fit(features, target)
        # Test that scores are increasing at each iteration
        assert_array_equal(np.diff(clf.scores_) > 0, True)
Example #7
Source File: test_spectral_embedding.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_spectral_embedding_amg_solver(seed=36):
    # Embeddings of S computed with the "amg" and "arpack" eigen solvers are
    # compared with _check_with_col_sign_flipping (tolerance 0.05).
    # Skipped when pyamg cannot be imported.
    try:
        from pyamg import smoothed_aggregation_solver  # noqa
    except ImportError:
        raise SkipTest("pyamg not available.")

    def build(solver):
        # Each estimator gets its own RandomState, seeded identically.
        return SpectralEmbedding(n_components=2,
                                 affinity="nearest_neighbors",
                                 eigen_solver=solver, n_neighbors=5,
                                 random_state=np.random.RandomState(seed))

    amg_model = build("amg")
    arpack_model = build("arpack")
    amg_embedding = amg_model.fit_transform(S)
    arpack_embedding = arpack_model.fit_transform(S)
    assert_true(_check_with_col_sign_flipping(amg_embedding,
                                              arpack_embedding, 0.05))
Example #8
Source File: test_multiclass.py From twitter-stock-recommendation with MIT License | 6 votes |
def test_type_of_target():
    # Every example listed in EXAMPLES must be classified under its group.
    for group, group_examples in iteritems(EXAMPLES):
        for example in group_examples:
            assert_equal(type_of_target(example), group,
                         msg=('type_of_target(%r) should be %r, got %r'
                              % (example, group, type_of_target(example))))

    # Inputs that are not array-like must raise a ValueError.
    for example in NON_ARRAY_LIKE_EXAMPLES:
        # Raw string: "\(" is not a valid escape sequence in a plain string
        # literal and triggers an invalid-escape warning at compile time.
        # The resulting pattern is unchanged.
        msg_regex = r'Expected array-like \(array or non-string sequence\).*'
        assert_raises_regex(ValueError, msg_regex, type_of_target, example)

    # The legacy "sequence of sequences" multilabel format is rejected.
    for example in MULTILABEL_SEQUENCES:
        msg = ('You appear to be using a legacy multi-label data '
               'representation. Sequence of sequences are no longer supported;'
               ' use a binary array or sparse matrix instead.')
        assert_raises_regex(ValueError, msg, type_of_target, example)

    # The last check needs pandas' SparseSeries; skip when it is unavailable.
    try:
        from pandas import SparseSeries
    except ImportError:
        raise SkipTest("Pandas not found")

    y = SparseSeries([1, 0, 0, 1, 0])
    msg = "y cannot be class 'SparseSeries'."
    assert_raises_regex(ValueError, msg, type_of_target, y)
Example #9
Source File: estimator_checks.py From twitter-stock-recommendation with MIT License | 6 votes |
def check_estimators_data_not_an_array(name, estimator_orig, X, y):
    """Check predictions are the same for plain and NotAnArray-wrapped data.

    Two clones of ``estimator_orig`` are fitted, one on ``X``/``y`` wrapped
    in ``NotAnArray`` and one on ``X``/``y`` as given; their predictions must
    agree within ``atol=1e-2``.

    Raises
    ------
    SkipTest
        If ``name`` is listed in ``CROSS_DECOMPOSITION``.
    """
    if name in CROSS_DECOMPOSITION:
        raise SkipTest

    # Separate clones so that each one gets its own random state set.
    wrapped_est = clone(estimator_orig)
    plain_est = clone(estimator_orig)
    for est in (wrapped_est, plain_est):
        set_random_state(est)

    y_wrapped = NotAnArray(np.asarray(y))
    X_wrapped = NotAnArray(np.asarray(X))

    wrapped_est.fit(X_wrapped, y_wrapped)
    wrapped_pred = wrapped_est.predict(X_wrapped)

    plain_est.fit(X, y)
    plain_pred = plain_est.predict(X)

    assert_allclose(wrapped_pred, plain_pred, atol=1e-2, err_msg=name)
Example #10
Source File: estimator_checks.py From twitter-stock-recommendation with MIT License | 6 votes |
def check_sample_weights_pandas_series(name, estimator_orig):
    """Check that ``fit`` accepts ``sample_weight`` given as a pandas Series.

    Nothing is checked when the cloned estimator has no ``sample_weight``
    fit parameter.

    Raises
    ------
    ValueError
        If ``fit`` raises ``ValueError`` for the pandas inputs.
    SkipTest
        If an ``ImportError`` occurs (pandas is not installed).
    """
    estimator = clone(estimator_orig)
    if not has_fit_parameter(estimator, "sample_weight"):
        return

    try:
        import pandas as pd
        rows = [[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]]
        X = pd.DataFrame(rows)
        y = pd.Series([1, 1, 1, 2, 2, 2])
        weights = pd.Series([1] * 6)
        try:
            estimator.fit(X, y, sample_weight=weights)
        except ValueError:
            # Re-raise with a message that names the offending estimator.
            raise ValueError("Estimator {0} raises error if "
                             "'sample_weight' parameter is of "
                             "type pandas.Series".format(name))
    except ImportError:
        raise SkipTest("pandas is not installed: not testing for "
                       "input of type pandas.Series to class weight.")
Example #11
Source File: classifier.py From autogbt-alt with MIT License | 6 votes |
def fit(self, X, y):
    """Build a trainer and an optimizer, then optimize on ``X`` and ``y``.

    The optimizer is stored on ``self._optimizer`` once optimization has
    finished.

    Raises
    ------
    SkipTest
        If the validated ``y`` does not have exactly two unique values.
    """
    logger = logging.get_logger(__name__)

    # NOTE(review): the trainer objective is hard-coded to 'binary' while the
    # optimizer receives self.objective - confirm this is intended.
    trainer_options = dict(
        objective='binary',
        metric='auc',
        sampler=self.sampler,
        n_jobs=self.n_jobs,
        create_validation=self.create_validation,
        cv=self.cv,
        random_state=self.random_state,
    )
    trainer = create_trainer(**trainer_options)

    optimizer_options = dict(
        objective=self.objective,
        trainer=trainer,
        n_trials=self.n_trials,
        random_state=self.random_state,
    )
    optimizer = create_optimizer(**optimizer_options)

    X, y = validate_dataset(optimizer, X, y)
    n_classes = len(y.unique())
    if n_classes != 2:
        raise SkipTest('binary classification is only supported')

    logger.info('start optimization')
    optimizer.optimize(X, y)
    self._optimizer = optimizer
Example #12
Source File: estimator_checks.py From Splunking-Crime with GNU Affero General Public License v3.0 | 6 votes |
def check_estimators_data_not_an_array(name, estimator_orig, X, y):
    """Compare predictions on raw inputs and on ``NotAnArray`` wrappers.

    One clone of ``estimator_orig`` is fitted on ``NotAnArray``-wrapped
    copies of ``X`` and ``y``, another on ``X`` and ``y`` directly.  The two
    sets of predictions must match within ``atol=1e-2``.

    Raises
    ------
    SkipTest
        If ``name`` is listed in ``CROSS_DECOMPOSITION``.
    """
    if name in CROSS_DECOMPOSITION:
        raise SkipTest

    # Independent clones, so the random seed of each is controlled
    # separately.
    est_on_wrapped = clone(estimator_orig)
    est_on_raw = clone(estimator_orig)
    set_random_state(est_on_wrapped)
    set_random_state(est_on_raw)

    wrapped_y = NotAnArray(np.asarray(y))
    wrapped_X = NotAnArray(np.asarray(X))

    est_on_wrapped.fit(wrapped_X, wrapped_y)
    predictions_wrapped = est_on_wrapped.predict(wrapped_X)

    est_on_raw.fit(X, y)
    predictions_raw = est_on_raw.predict(X)

    assert_allclose(predictions_wrapped, predictions_raw, atol=1e-2,
                    err_msg=name)
Example #13
Source File: estimator_checks.py From Splunking-Crime with GNU Affero General Public License v3.0 | 6 votes |
def check_sample_weights_pandas_series(name, estimator_orig):
    """Verify ``fit`` accepts a pandas Series as ``sample_weight``.

    Estimators whose ``fit`` has no ``sample_weight`` parameter are not
    exercised.

    Raises
    ------
    ValueError
        If ``fit`` raises ``ValueError`` for the pandas inputs.
    SkipTest
        If an ``ImportError`` occurs (pandas is not installed).
    """
    estimator = clone(estimator_orig)
    supports_weights = has_fit_parameter(estimator, "sample_weight")
    if not supports_weights:
        return

    try:
        import pandas as pd
        frame = pd.DataFrame([[1, 1], [1, 2], [1, 3],
                              [2, 1], [2, 2], [2, 3]])
        target = pd.Series([1, 1, 1, 2, 2, 2])
        weights = pd.Series([1] * 6)
        try:
            estimator.fit(frame, target, sample_weight=weights)
        except ValueError:
            # Report which estimator rejected the Series weights.
            raise ValueError("Estimator {0} raises error if "
                             "'sample_weight' parameter is of "
                             "type pandas.Series".format(name))
    except ImportError:
        raise SkipTest("pandas is not installed: not testing for "
                       "input of type pandas.Series to class weight.")
Example #14
Source File: test_spectral_embedding.py From Mastering-Elasticsearch-7.0 with MIT License | 6 votes |
def test_spectral_embedding_amg_solver(seed=36):
    # Embeddings of S computed with the "amg" and "arpack" eigen solvers are
    # compared with _check_with_col_sign_flipping (tolerance 0.05).
    # Skipped when pyamg cannot be imported.
    try:
        from pyamg import smoothed_aggregation_solver  # noqa
    except ImportError:
        raise SkipTest("pyamg not available.")

    def make_embedder(eigen_solver):
        # A fresh, identically seeded RandomState for every estimator.
        return SpectralEmbedding(n_components=2,
                                 affinity="nearest_neighbors",
                                 eigen_solver=eigen_solver, n_neighbors=5,
                                 random_state=np.random.RandomState(seed))

    with_amg = make_embedder("amg")
    with_arpack = make_embedder("arpack")
    result_amg = with_amg.fit_transform(S)
    result_arpack = with_arpack.fit_transform(S)
    assert _check_with_col_sign_flipping(result_amg, result_arpack, 0.05)
Example #15
Source File: test_covtype.py From Mastering-Elasticsearch-7.0 with MIT License | 6 votes |
def test_fetch():
    # Two differently-seeded shuffled fetches must have the same shape and
    # the same data sum.  Skipped when the first fetch raises IOError.
    try:
        bunch_a = fetch(shuffle=True, random_state=42)
    except IOError:
        raise SkipTest("Covertype dataset can not be loaded.")
    bunch_b = fetch(shuffle=True, random_state=37)

    X1 = bunch_a['data']
    X2 = bunch_b['data']
    assert_equal((581012, 54), X1.shape)
    assert_equal(X1.shape, X2.shape)
    assert_equal(X1.sum(), X2.sum())

    # Both target vectors hold one entry per data row.
    for target in (bunch_a['target'], bunch_b['target']):
        assert_equal((X1.shape[0],), target.shape)

    # test return_X_y option
    check_return_X_y(bunch_a, partial(fetch))
Example #16
Source File: test_bicluster.py From Mastering-Elasticsearch-7.0 with MIT License | 6 votes |
def test_perfect_checkerboard():
    # XXX test always skipped
    raise SkipTest("This test is failing on the buildbot, but cannot"
                   " reproduce. Temporarily disabling it until it can be"
                   " reproduced and fixed.")

    # Unreachable while the skip above is in place.  Kept so the test can be
    # re-enabled: noise-free checkerboards of several shapes are expected to
    # be recovered with a consensus score of 1.
    model = SpectralBiclustering(3, svd_method="arpack", random_state=0)
    board_shapes = [(30, 30), (40, 30), (30, 40)]
    for board_shape in board_shapes:
        S, rows, cols = make_checkerboard(board_shape, 3, noise=0,
                                          random_state=0)
        model.fit(S)
        score = consensus_score(model.biclusters_, (rows, cols))
        assert_equal(score, 1)
Example #17
Source File: test_extmath.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_stable_cumsum():
    # stable_cumsum must match np.cumsum, and must warn when called with
    # zero tolerances on a large random vector.
    if np_version < (1, 9):
        raise SkipTest("Sum is as unstable as cumsum for numpy < 1.9")

    small = [1, 2, 3]
    assert_array_equal(stable_cumsum(small), np.cumsum(small))

    noisy = np.random.RandomState(0).rand(100000)
    assert_warns(RuntimeWarning, stable_cumsum, noisy, rtol=0, atol=0)

    # test axis parameter
    cube = np.random.RandomState(36).randint(1000, size=(5, 5, 5))
    for axis in range(3):
        assert_array_equal(stable_cumsum(cube, axis=axis),
                           np.cumsum(cube, axis=axis))
Example #18
Source File: test_gradient_boosting.py From pygbm with MIT License | 5 votes |
def custom_check_estimator(Estimator):
    """Run scikit-learn's estimator checks, leaving out a few of them.

    Same as sklearn.check_estimator, skipping tests that can't succeed.
    ``Estimator`` is used as an instance here: its type name is the name
    passed to every check.  A ``SkipTest`` raised by a check is turned into
    a ``SkipTestWarning``.
    """
    from sklearn.utils.estimator_checks import _yield_all_checks
    from sklearn.utils.testing import SkipTest
    from sklearn.exceptions import SkipTestWarning
    from sklearn.utils import estimator_checks

    estimator = Estimator
    name = type(estimator).__name__

    # X is both Fortran and C aligned and numba can't compile.
    # Opened numba issue 3569
    numba_blocked = (estimator_checks.check_fit2d_1feature,
                     estimator_checks.check_fit2d_1sample)
    # probas don't exactly sum to 1 (very close though)
    train_check = estimator_checks.check_classifiers_train

    for check in _yield_all_checks(name, estimator):
        if any(check is blocked for blocked in numba_blocked):
            continue
        if check is train_check:
            continue
        # same check, wrapped in a functools.partial object.
        if getattr(check, 'func', None) is train_check:
            continue

        try:
            check(name, estimator)
        except SkipTest as skipped:
            # the only SkipTest thrown currently results from not
            # being able to import pandas.
            warnings.warn(str(skipped), SkipTestWarning)
Example #19
Source File: test_validation.py From Mastering-Elasticsearch-7.0 with MIT License | 5 votes |
def test_check_dataframe_fit_attribute():
    # A pandas DataFrame with a column named 'fit' must pass
    # check_consistent_length without raising.
    # https://github.com/scikit-learn/scikit-learn/issues/8415
    try:
        import pandas as pd
        values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        column_names = ['a', 'b', 'fit']
        frame = pd.DataFrame(values, columns=column_names)
        check_consistent_length(frame)
    except ImportError:
        raise SkipTest("Pandas not found")
Example #20
Source File: test_20news.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_20news_vectorized():
    # Each subset of the vectorized 20 newsgroups data must be a CSR matrix
    # of float64 with the expected shape.  Skipped when the raw dataset
    # cannot be loaded without downloading.
    try:
        datasets.fetch_20newsgroups(subset='all', download_if_missing=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    n_train, n_test, n_features = 11314, 7532, 130107
    expected_rows = (("train", n_train),
                     ("test", n_test),
                     ('all', n_train + n_test))

    for subset, n_rows in expected_rows:
        bunch = datasets.fetch_20newsgroups_vectorized(subset=subset)
        assert_true(sp.isspmatrix_csr(bunch.data))
        assert_equal(bunch.data.shape, (n_rows, n_features))
        assert_equal(bunch.target.shape[0], n_rows)
        assert_equal(bunch.data.dtype, np.float64)
Example #21
Source File: test_20news.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_20news_length_consistency():
    """Checks the length consistencies within the bunch

    This is a non-regression test for a bug present in 0.16.1.
    """
    # Only probes whether the dataset is available locally; the result of
    # this call is not used.
    try:
        datasets.fetch_20newsgroups(
            subset='all', download_if_missing=False, shuffle=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    # Extract the full dataset
    bunch = datasets.fetch_20newsgroups(subset='all')

    # Item access and attribute access must report the same length.
    for field in ('data', 'target', 'filenames'):
        assert_equal(len(bunch[field]), len(getattr(bunch, field)))
Example #22
Source File: test_kddcup99.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_percent10():
    # Shape checks for the default kddcup99 fetch, its shuffled variant and
    # the 'SA', 'SF', 'http' and 'smtp' subsets.  Skipped when the dataset
    # cannot be loaded without downloading.
    try:
        full = fetch_kddcup99(download_if_missing=False)
    except IOError:
        raise SkipTest("kddcup99 dataset can not be loaded.")

    assert_equal(full.data.shape, (494021, 41))
    assert_equal(full.target.shape, (494021,))

    # Shuffling must not change the shapes.
    shuffled = fetch_kddcup99(shuffle=True, random_state=0)
    assert_equal(full.data.shape, shuffled.data.shape)
    assert_equal(full.target.shape, shuffled.target.shape)

    # (subset name, expected number of rows, expected number of columns)
    subset_shapes = (('SA', 100655, 41),
                     ('SF', 73237, 4),
                     ('http', 58725, 3),
                     ('smtp', 9571, 3))
    for subset, n_rows, n_cols in subset_shapes:
        part = fetch_kddcup99(subset)
        assert_equal(part.data.shape, (n_rows, n_cols))
        assert_equal(part.target.shape, (n_rows,))
Example #23
Source File: test_k_means.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_k_means_plus_plus_init_2_jobs():
    # Fit KMeans with k-means++ init and n_jobs=2, then run the shared
    # fitted-model checks.  Skipped on Python older than 3.4.
    interpreter = sys.version_info[:2]
    if interpreter < (3, 4):
        raise SkipTest(
            "Possible multi-process bug with some BLAS under Python < 3.4")

    model = KMeans(init="k-means++", n_clusters=n_clusters, n_jobs=2,
                   random_state=42)
    fitted = model.fit(X)
    _check_fitted_model(fitted)
Example #24
Source File: test_logistic.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_dtype_match():
    # Disabled to unblock the 0.19.2 release. See:
    # https://github.com/scikit-learn/scikit-learn/issues/11438

    # Test that np.float32 input data is not cast to np.float64 when possible
    raise SkipTest()

    # Everything below is unreachable while the skip above is in place.
    dense_32 = np.array(X).astype(np.float32)
    target_32 = np.array(Y1).astype(np.float32)
    dense_64 = np.array(X).astype(np.float64)
    target_64 = np.array(Y1).astype(np.float64)
    sparse_32 = sp.csr_matrix(X, dtype=np.float32)

    for solver in ['newton-cg']:
        for multi_class in ['ovr', 'multinomial']:

            def fit_and_check(features, target):
                # Fit a fresh model; coef_ must keep the dtype of the input.
                model = LogisticRegression(solver=solver,
                                           multi_class=multi_class)
                model.fit(features, target)
                assert_equal(model.coef_.dtype, features.dtype)
                return model

            # Check type consistency
            lr_32 = fit_and_check(dense_32, target_32)
            # check consistency with sparsity
            fit_and_check(sparse_32, target_32)
            # Check accuracy consistency
            lr_64 = fit_and_check(dense_64, target_64)
            assert_almost_equal(lr_32.coef_,
                                lr_64.coef_.astype(np.float32))
Example #25
Source File: test_kddcup99.py From Mastering-Elasticsearch-7.0 with MIT License | 5 votes |
def test_shuffle():
    # After shuffling the 'SA' subset, at least one of the last 100 targets
    # must be b'normal.'.  Skipped when the dataset cannot be loaded without
    # downloading.
    fetch_options = dict(random_state=0, subset='SA', shuffle=True,
                         percent10=True, download_if_missing=False)
    try:
        dataset = fetch_kddcup99(**fetch_options)
    except IOError:
        raise SkipTest("kddcup99 dataset can not be loaded.")

    tail = dataset.target[-100:]
    assert any(tail == b'normal.')
Example #26
Source File: test_validation.py From twitter-stock-recommendation with MIT License | 5 votes |
def test_check_dataframe_fit_attribute():
    # check_consistent_length must accept a pandas DataFrame that has a
    # column called 'fit'.
    # https://github.com/scikit-learn/scikit-learn/issues/8415
    try:
        import pandas as pd
        rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        data = np.array(rows)
        data_frame = pd.DataFrame(data, columns=['a', 'b', 'fit'])
        check_consistent_length(data_frame)
    except ImportError:
        raise SkipTest("Pandas not found")
Example #27
Source File: test_20news.py From Mastering-Elasticsearch-7.0 with MIT License | 5 votes |
def test_20news():
    # Fetching two categories must give a consistent, correctly ordered
    # subset of the full dataset.  Skipped when the dataset cannot be loaded
    # without downloading.
    try:
        full = datasets.fetch_20newsgroups(
            subset='all', download_if_missing=False, shuffle=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    # Extract a reduced dataset: the last two categories, in reverse order.
    wanted = full.target_names[-1:-3:-1]
    reduced = datasets.fetch_20newsgroups(
        subset='all', categories=wanted, shuffle=False)

    # Check that the ordering of the target_names is the same
    # as the ordering in the full dataset
    assert_equal(reduced.target_names, full.target_names[-2:])

    # Assert that we have only 0 and 1 as labels
    assert_equal(np.unique(reduced.target).tolist(), [0, 1])

    # Check that the number of filenames is consistent with data/target
    n_files = len(reduced.filenames)
    assert_equal(n_files, len(reduced.target))
    assert_equal(n_files, len(reduced.data))

    # Check that the first entry of the reduced dataset corresponds to
    # the first entry of the corresponding category in the full dataset
    first_reduced = reduced.data[0]
    category = reduced.target_names[reduced.target[0]]
    full_label = full.target_names.index(category)
    first_match = np.where(full.target == full_label)[0][0]
    assert_equal(first_reduced, full.data[first_match])
Example #28
Source File: estimator_checks.py From Splunking-Crime with GNU Affero General Public License v3.0 | 5 votes |
def check_class_weight_classifiers(name, classifier_orig):
    """Check that a heavily weighted class dominates the predictions.

    For 2 and 3 blob centers, a clone of ``classifier_orig`` is trained on
    noisy data with class 0 weighted 1000 and every other class 0.0001.
    More than 87% of the test predictions must then be class 0.

    Raises
    ------
    SkipTest
        If ``name`` is "NuSVC" or ends with "NB".
    """
    # NuSVC: the sparse version has a parameter that doesn't do anything.
    # *NB: NaiveBayes classifiers have a somewhat different interface.
    # FIXME SOON!
    if name == "NuSVC" or name.endswith("NB"):
        raise SkipTest

    # Applied only when the classifier exposes the attribute.
    optional_params = (("n_iter", 100),
                       ("max_iter", 1000),
                       ("min_weight_fraction_leaf", 0.01))

    for n_centers in [2, 3]:
        # create a very noisy dataset
        X, y = make_blobs(centers=n_centers, random_state=0, cluster_std=20)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=.5, random_state=0)

        # Count the classes actually present in the training split.
        n_centers = len(np.unique(y_train))
        class_weight = {0: 1000, 1: 0.0001}
        if n_centers != 2:
            class_weight[2] = 0.0001

        classifier = clone(classifier_orig).set_params(
            class_weight=class_weight)
        for param, value in optional_params:
            if hasattr(classifier, param):
                classifier.set_params(**{param: value})

        set_random_state(classifier)
        classifier.fit(X_train, y_train)
        y_pred = classifier.predict(X_test)
        # XXX: Generally can use 0.89 here. On Windows, LinearSVC gets
        #      0.88 (Issue #9111)
        share_of_class_0 = np.mean(y_pred == 0)
        assert_greater(share_of_class_0, 0.87)
Example #29
Source File: sklearn_patches.py From tslearn with BSD 2-Clause "Simplified" License | 5 votes |
def check_pipeline_consistency(name, estimator_orig):
    """Check that ``make_pipeline(est)`` gives the same results as ``est``.

    A clone of ``estimator_orig`` and a pipeline wrapping that clone are
    both fitted on the same blob data.  For each of ``score`` and
    ``fit_transform`` that the estimator provides, the estimator's result
    must match the pipeline's.

    Raises
    ------
    SkipTest
        If the estimator's ``non_deterministic`` tag is set.
    """
    if estimator_orig._get_tags()['non_deterministic']:
        raise SkipTest(name + ' is non deterministic')

    X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                      random_state=0, n_features=2, cluster_std=0.1)
    # Shift in place so that the smallest entry becomes zero.
    X -= X.min()
    X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)

    estimator = clone(estimator_orig)
    y = multioutput_estimator_convert_y_2d(estimator, y)
    set_random_state(estimator)
    pipeline = make_pipeline(estimator)
    estimator.fit(X, y)
    pipeline.fit(X, y)

    for method_name in ("score", "fit_transform"):
        estimator_method = getattr(estimator, method_name, None)
        if estimator_method is None:
            # The estimator does not provide this method; nothing to compare.
            continue
        pipeline_method = getattr(pipeline, method_name)
        direct_result = estimator_method(X, y)
        piped_result = pipeline_method(X, y)
        assert_allclose_dense_sparse(direct_result, piped_result)
Example #30
Source File: estimator_checks.py From Splunking-Crime with GNU Affero General Public License v3.0 | 5 votes |
def check_estimator(Estimator):
    """Check if estimator adheres to scikit-learn conventions.

    This function will run an extensive test-suite for input validation,
    shapes, etc.
    Additional tests for classifiers, regressors, clustering or transformers
    will be run if the Estimator class inherits from the corresponding mixin
    from sklearn.base.

    This test can be applied to classes or instances.
    Classes currently have some additional tests that relate to construction,
    while passing instances allows the testing of multiple options.

    A ``SkipTest`` raised by an individual check is not propagated; it is
    reported as a ``SkipTestWarning`` and the remaining checks still run.

    Parameters
    ----------
    Estimator : estimator object or class
        Estimator to check. Estimator is a class object or instance.

    """
    if isinstance(Estimator, type):
        # got a class
        name = Estimator.__name__
        check_parameters_default_constructible(name, Estimator)
        check_no_fit_attributes_set_in_init(name, Estimator)
        estimator = Estimator()
    else:
        # got an instance
        estimator = Estimator
        name = type(estimator).__name__

    for check in _yield_all_checks(name, estimator):
        try:
            check(name, estimator)
        except SkipTest as message:
            # the only SkipTest thrown currently results from not
            # being able to import pandas.
            # warnings.warn expects a string (or a Warning instance) as the
            # message; passing the SkipTest exception object itself fails
            # with a TypeError once a message-regex warning filter is active.
            warnings.warn(str(message), SkipTestWarning)