Python sklearn.model_selection._search.BaseSearchCV() Examples

The following are 3 code examples of sklearn.model_selection._search.BaseSearchCV(). Each example is labelled with the project and source file it was taken from. You may also want to check out all available functions and classes of the module sklearn.model_selection._search.
Example #1
Source file: test_search.py, from Mastering-Elasticsearch-7.0 (MIT License) — 6 votes
def test__custom_fit_no_run_search():
    """Check that ``_run_search`` is only required when ``fit`` is inherited.

    A subclass that overrides ``fit`` never reaches ``_run_search`` and must
    fit without error. A subclass that keeps ``BaseSearchCV.fit`` but does not
    provide ``_run_search`` must raise ``NotImplementedError``.
    """
    # NOTE(review): BaseSearchCV.__init__ is declared abstract upstream, so
    # each subclass re-declares a pass-through __init__ to be instantiable.
    # Confirm this against the pinned scikit-learn version.
    class FitOverridingSearch(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

        def fit(self, X, y=None, groups=None, **fit_params):
            # Skip the search entirely, so _run_search is never consulted.
            return self

    # Overriding fit must not raise.
    FitOverridingSearch(SVC(), cv=5).fit(X, y)

    class MissingRunSearch(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

    # The inherited fit reaches the unimplemented _run_search hook.
    expected_message = "_run_search not implemented."
    with pytest.raises(NotImplementedError, match=expected_message):
        MissingRunSearch(SVC(), cv=5).fit(X, y)
Example #2
Source file: __main__.py, from fake-news-detection-pipeline (Apache License 2.0) — 5 votes
def print_cv_result(result, n):
    """Print the ``n`` best cross-validation candidates, highest score first.

    Parameters
    ----------
    result : BaseSearchCV or mapping
        Either a fitted search object, whose ``cv_results_`` is used, or a
        ``cv_results_``-style mapping with ``'mean_test_score'`` and
        ``'params'`` entries.
    n : int
        Number of top candidates to print. A negative value prints every
        candidate. Values larger than the number of candidates are clamped
        to that number.
    """
    if isinstance(result, BaseSearchCV):
        result = result.cv_results_

    scores = result['mean_test_score']
    params = result['params']

    # Clamp n so the count in the header matches what is actually printed.
    # heapq.nlargest quietly returns fewer items when n is too large.
    total = len(scores)
    n = total if n < 0 else min(n, total)

    print("Cross Validation result in descending order: (totalling {} trials)".format(n))
    # Rank by score alone: the params dicts are not orderable, so they must
    # never take part in the comparison.
    top = heapq.nlargest(n, zip(scores, params), key=lambda pair: pair[0])
    for rank, (score, hyperparams) in enumerate(top, start=1):
        print("rank {}, score = {}\n hyperparams = {}".format(rank, score, hyperparams))
Example #3
Source file: test_search.py, from Mastering-Elasticsearch-7.0 (MIT License) — 4 votes
def test_custom_run_search():
    """Check that a custom ``_run_search`` matches an equivalent GridSearchCV.

    ``CustomSearchCV`` evaluates two batches of candidates through the
    ``evaluate`` callback. After each batch it compares the returned results
    with a GridSearchCV fitted on the same candidates. Afterwards the final
    ``cv_results_`` and the public fitted attributes of both searches are
    compared.
    """
    def check_results(results, gscv):
        """Assert that ``results`` matches ``gscv.cv_results_``.

        Keys ending in ``'_time'`` are skipped, since timings are not
        reproducible. Object-dtype columns are compared exactly; numeric
        columns are compared with a tolerance. Note that this mutates
        ``results`` in place, replacing each compared value with an array.
        """
        exp_results = gscv.cv_results_
        assert sorted(results.keys()) == sorted(exp_results)
        for k in results:
            if not k.endswith('_time'):
                # XXX: results['params'] is a list :|
                # (convert it so it can be compared like the other columns)
                results[k] = np.asanyarray(results[k])
                if results[k].dtype.kind == 'O':
                    assert_array_equal(exp_results[k], results[k],
                                       err_msg='Checking ' + k)
                else:
                    assert_allclose(exp_results[k], results[k],
                                    err_msg='Checking ' + k)

    def fit_grid(param_grid):
        # Reference search. `clf`, `X` and `y` are looked up when this is
        # called, not when it is defined; they are assigned further below,
        # before any call to fit_grid.
        return GridSearchCV(clf, param_grid, cv=5,
                            return_train_score=True).fit(X, y)

    class CustomSearchCV(BaseSearchCV):
        # Pass-through __init__ so the subclass can be instantiated.
        # NOTE(review): BaseSearchCV.__init__ appears to be abstract
        # upstream; confirm this.
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

        def _run_search(self, evaluate):
            results = evaluate([{'max_depth': 1}, {'max_depth': 2}])
            check_results(results, fit_grid({'max_depth': [1, 2]}))
            results = evaluate([{'min_samples_split': 5},
                                {'min_samples_split': 10}])
            # The second check compares against the union of both grids.
            # This expects `evaluate` to return results accumulated over
            # all calls so far, not just the latest batch.
            check_results(results, fit_grid([{'max_depth': [1, 2]},
                                             {'min_samples_split': [5, 10]}]))

    # Using regressor to make sure each score differs
    clf = DecisionTreeRegressor(random_state=0)
    X, y = make_classification(n_samples=100, n_informative=4,
                               random_state=0)
    mycv = CustomSearchCV(clf, cv=5, return_train_score=True).fit(X, y)
    gscv = fit_grid([{'max_depth': [1, 2]},
                     {'min_samples_split': [5, 10]}])

    results = mycv.cv_results_
    check_results(results, gscv)
    # Compare every public fitted attribute (lowercase first letter, trailing
    # underscore). The exclusions are:
    # - cv_results_ was already checked above;
    # - refit_time_ is a timing;
    # - best_estimator_ is presumably excluded because estimator objects do
    #   not compare equal by value (TODO confirm).
    for attr in dir(gscv):
        if attr[0].islower() and attr[-1:] == '_' and \
           attr not in {'cv_results_', 'best_estimator_',
                        'refit_time_'}:
            assert getattr(gscv, attr) == getattr(mycv, attr), \
                   "Attribute %s not equal" % attr