Python mlflow.set_tracking_uri() Examples

The following are 24 code examples of mlflow.set_tracking_uri(), collected from open-source projects. The source file and project are noted above each example.
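
Before the project examples, here is a minimal, self-contained sketch of the call itself; the file URI, experiment name, and logged values are illustrative only:

import mlflow

# Point MLflow at a local file-based tracking store; an HTTP(S) server URI,
# a databricks:// profile, or a SQLAlchemy database URI would work the same way.
mlflow.set_tracking_uri("file:///tmp/mlruns")

mlflow.set_experiment("demo_experiment")
with mlflow.start_run():
    mlflow.log_param("alpha", 0.5)
    mlflow.log_metric("rmse", 0.42)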
Example #1
Source File: test_mlflow.py    From interpret-community with MIT License
def test_upload_as_model(self, iris, tabular_explainer, tracking_uri):
        mlflow.set_tracking_uri(tracking_uri)
        x_train = iris[DatasetConstants.X_TRAIN]
        x_test = iris[DatasetConstants.X_TEST]
        y_train = iris[DatasetConstants.Y_TRAIN]

        model = create_sklearn_random_forest_classifier(x_train, y_train)

        explainer = tabular_explainer(model, x_train)
        global_explanation = explainer.explain_global(x_test)
        mlflow.set_experiment(TEST_EXPERIMENT)
        with mlflow.start_run() as run:
            log_explanation(TEST_EXPLANATION, global_explanation)
            os.makedirs(TEST_DOWNLOAD, exist_ok=True)
            run_id = run.info.run_id
        downloaded_explanation_mlflow = get_explanation(run_id, TEST_EXPLANATION)
        _assert_explanation_equivalence(global_explanation, downloaded_explanation_mlflow) 
Example #2
Source File: test_mlflow_reporter.py    From gordo with GNU Affero General Public License v3.0
def test_mlflow_context_log_metadata(MockClient, tmpdir, metadata):
    """
    Test that a call to the wrapped function initiates MLflow logging or raises a warning
    """
    metadata = Machine(**metadata)
    mlflow.set_tracking_uri(f"file:{tmpdir}")

    mock_client = MockClient()
    mock_client.log_batch.return_value = "test"

    # Function with a metadata dict returned
    with mlu.mlflow_context("returns metadata", "unique_key", {}, {}) as (
        mlflow_client,
        run_id,
    ):
        mlu.log_machine(mlflow_client, run_id, metadata)

    assert mock_client.log_batch.called 
Example #3
Source File: test_h2o_model_export.py    From mlflow with Apache License 2.0
def test_model_log(h2o_iris_model):
    h2o_model = h2o_iris_model.model
    old_uri = mlflow.get_tracking_uri()
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        with TempDir(chdr=True, remove_on_exit=True):
            try:
                artifact_path = "gbm_model"
                mlflow.set_tracking_uri("test")
                if should_start_run:
                    mlflow.start_run()
                mlflow.h2o.log_model(h2o_model=h2o_model, artifact_path=artifact_path)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path=artifact_path)

                # Load model
                h2o_model_loaded = mlflow.h2o.load_model(model_uri=model_uri)
                assert all(
                    h2o_model_loaded.predict(h2o_iris_model.inference_data).as_data_frame() ==
                    h2o_model.predict(h2o_iris_model.inference_data).as_data_frame())
            finally:
                mlflow.end_run()
                mlflow.set_tracking_uri(old_uri) 
Example #4
Source File: test_docker_projects.py    From mlflow with Apache License 2.0
def test_docker_project_tracking_uri_propagation(
        ProfileConfigProvider, tmpdir, tracking_uri,
        expected_command_segment, docker_example_base_image):  # pylint: disable=unused-argument
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = \
        DatabricksConfig("host", "user", "pass", None, insecure=True)
    ProfileConfigProvider.return_value = mock_provider
    # Create and mock local tracking directory
    local_tracking_dir = os.path.join(tmpdir.strpath, "mlruns")
    if tracking_uri is None:
        tracking_uri = local_tracking_dir
    old_uri = mlflow.get_tracking_uri()
    try:
        mlflow.set_tracking_uri(tracking_uri)
        with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as _get_store_mock:
            _get_store_mock.return_value = file_store.FileStore(local_tracking_dir)
            mlflow.projects.run(
                TEST_DOCKER_PROJECT_DIR, experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID)
    finally:
        mlflow.set_tracking_uri(old_uri) 
Example #5
Source File: test_mlflow.py    From interpret-community with MIT License
def test_upload_two_explanations(self, iris, tabular_explainer, tracking_uri):
        mlflow.set_tracking_uri(tracking_uri)
        x_train = iris[DatasetConstants.X_TRAIN]
        x_test = iris[DatasetConstants.X_TEST]
        y_train = iris[DatasetConstants.Y_TRAIN]

        model = create_sklearn_random_forest_classifier(x_train, y_train)

        explainer = tabular_explainer(model, x_train)
        global_explanation = explainer.explain_global(x_test)
        local_explanation = explainer.explain_local(x_test)
        mlflow.set_experiment(TEST_EXPERIMENT)
        with mlflow.start_run() as run:
            log_explanation('global_explanation', global_explanation)
            log_explanation('local_explanation', local_explanation)
            os.makedirs(TEST_DOWNLOAD, exist_ok=True)
            run_id = run.info.run_id
        downloaded_explanation_mlflow = get_explanation(run_id, 'global_explanation')
        _assert_explanation_equivalence(global_explanation, downloaded_explanation_mlflow) 
Example #6
Source File: utils.py    From FARM with Apache License 2.0
def init_experiment(self, experiment_name, run_name=None, nested=True):
        try:
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(experiment_name)
            mlflow.start_run(run_name=run_name, nested=nested)
        except ConnectionError:
            raise Exception(
                f"MLFlow cannot connect to the remote server at {self.tracking_uri}.\n"
                f"MLFlow also supports logging runs locally to files. Set the MLFlowLogger "
                f"tracking_uri to an empty string to use that."
            ) 
Example #7
Source File: mlflow_logger.py    From ignite with BSD 3-Clause "New" or "Revised" License
def __init__(self, tracking_uri=None):
        try:
            import mlflow
        except ImportError:
            raise RuntimeError(
                "This contrib module requires mlflow to be installed. "
                "Please install it with command: \n pip install mlflow"
            )

        if tracking_uri is not None:
            mlflow.set_tracking_uri(tracking_uri)

        self.active_run = mlflow.active_run()
        if self.active_run is None:
            self.active_run = mlflow.start_run() 
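
A minimal sketch of how a logger like this is then attached to an ignite Engine; the trainer, tag, and tracking URI are illustrative, and the import path assumes ignite's contrib handlers:

from ignite.engine import Events
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler

# Assumes an existing `trainer` Engine producing a loss value per iteration.
mlflow_logger = MLflowLogger(tracking_uri="file:///tmp/mlruns")
mlflow_logger.attach(
    trainer,
    log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}),
    event_name=Events.ITERATION_COMPLETED,
)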
Example #8
Source File: mlflow.py    From optuna with MIT License
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:

        # This sets the tracking_uri for MLflow.
        if self._tracking_uri is not None:
            mlflow.set_tracking_uri(self._tracking_uri)

        # This sets the experiment of MLflow.
        mlflow.set_experiment(study.study_name)

        with mlflow.start_run(run_name=str(trial.number)):

            # This sets the metric for MLflow.
            trial_value = trial.value if trial.value is not None else float("nan")
            mlflow.log_metric(self._metric_name, trial_value)

            # This sets the params for MLflow.
            mlflow.log_params(trial.params)

            # This sets the tags for MLflow.
            tags = {}  # type: Dict[str, str]
            tags["number"] = str(trial.number)
            tags["datetime_start"] = str(trial.datetime_start)
            tags["datetime_complete"] = str(trial.datetime_complete)

            # Set state and convert it to str and remove the common prefix.
            trial_state = trial.state
            if isinstance(trial_state, TrialState):
                tags["state"] = str(trial_state).split(".")[-1]

            # Set direction and convert it to str and remove the common prefix.
            study_direction = study.direction
            if isinstance(study_direction, StudyDirection):
                tags["direction"] = str(study_direction).split(".")[-1]

            tags.update(trial.user_attrs)
            distributions = {
                (k + "_distribution"): str(v) for (k, v) in trial.distributions.items()
            }
            tags.update(distributions)
            mlflow.set_tags(tags) 
Example #9
Source File: test_spark_datasource_autologging_crossframework.py    From mlflow with Apache License 2.0
def http_tracking_uri_mock():
    mlflow.set_tracking_uri("http://some-cool-uri")
    yield
    mlflow.set_tracking_uri(None) 
Example #10
Source File: test_spark_datasource_autologging.py    From mlflow with Apache License 2.0
def http_tracking_uri_mock():
    mlflow.set_tracking_uri("http://some-cool-uri")
    yield
    mlflow.set_tracking_uri(None) 
Example #11
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_sparkml_estimator_model_log(tmpdir, spark_model_estimator):
    # Exercise log_model/load_model round trips for the Spark ML estimator model
    old_tracking_uri = mlflow.get_tracking_uri()
    cnt = 0
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        for dfs_tmp_dir in [None, os.path.join(str(tmpdir), "test")]:
            print("should_start_run =", should_start_run, "dfs_tmp_dir =", dfs_tmp_dir)
            try:
                tracking_dir = os.path.abspath(str(tmpdir.join("mlruns")))
                mlflow.set_tracking_uri("file://%s" % tracking_dir)
                if should_start_run:
                    mlflow.start_run()
                artifact_path = "model%d" % cnt
                cnt += 1
                sparkm.log_model(
                    artifact_path=artifact_path,
                    spark_model=spark_model_estimator.model,
                    dfs_tmpdir=dfs_tmp_dir)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path=artifact_path)

                # test reloaded model
                reloaded_model = sparkm.load_model(model_uri=model_uri, dfs_tmpdir=dfs_tmp_dir)
                preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
                preds = [x.prediction for x in preds_df.select("prediction").collect()]
                assert spark_model_estimator.predictions == preds
            finally:
                mlflow.end_run()
                mlflow.set_tracking_uri(old_tracking_uri)
                x = dfs_tmp_dir or sparkm.DFS_TMP
                shutil.rmtree(x)
                shutil.rmtree(tracking_dir) 
Example #12
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_sparkml_model_log(tmpdir, spark_model_iris):
    # Exercise log_model/load_model round trips for the Spark ML iris model
    old_tracking_uri = mlflow.get_tracking_uri()
    cnt = 0
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        for dfs_tmp_dir in [None, os.path.join(str(tmpdir), "test")]:
            print("should_start_run =", should_start_run, "dfs_tmp_dir =", dfs_tmp_dir)
            try:
                tracking_dir = os.path.abspath(str(tmpdir.join("mlruns")))
                mlflow.set_tracking_uri("file://%s" % tracking_dir)
                if should_start_run:
                    mlflow.start_run()
                artifact_path = "model%d" % cnt
                cnt += 1
                sparkm.log_model(artifact_path=artifact_path, spark_model=spark_model_iris.model,
                                 dfs_tmpdir=dfs_tmp_dir)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path=artifact_path)

                # test reloaded model
                reloaded_model = sparkm.load_model(model_uri=model_uri, dfs_tmpdir=dfs_tmp_dir)
                preds_df = reloaded_model.transform(spark_model_iris.spark_df)
                preds = [x.prediction for x in preds_df.select("prediction").collect()]
                assert spark_model_iris.predictions == preds
            finally:
                mlflow.end_run()
                mlflow.set_tracking_uri(old_tracking_uri)
                x = dfs_tmp_dir or sparkm.DFS_TMP
                shutil.rmtree(x)
                shutil.rmtree(tracking_dir) 
Example #13
Source File: test_databricks.py    From mlflow with Apache License 2.0
def test_get_tracking_uri_for_run():
    mlflow.set_tracking_uri("http://some-uri")
    assert databricks._get_tracking_uri_for_run() == "http://some-uri"
    mlflow.set_tracking_uri("databricks://profile")
    assert databricks._get_tracking_uri_for_run() == "databricks"
    mlflow.set_tracking_uri(None)
    with mock.patch.dict(os.environ, {mlflow.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}):
        assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri" 
Example #14
Source File: conftest.py    From mlflow with Apache License 2.0
def tracking_uri_mock(tmpdir, request):
    try:
        if 'notrackingurimock' not in request.keywords:
            tracking_uri = path_to_local_sqlite_uri(
                os.path.join(tmpdir.strpath, 'mlruns'))
            mlflow.set_tracking_uri(tracking_uri)
            os.environ["MLFLOW_TRACKING_URI"] = tracking_uri
        yield tmpdir
    finally:
        mlflow.set_tracking_uri(None)
        if 'notrackingurimock' not in request.keywords:
            del os.environ["MLFLOW_TRACKING_URI"] 
Example #15
Source File: experiment.py    From LaSO with BSD 3-Clause "New" or "Revised" License
def start(self):
        """Start the whole thing"""

        self._setup_logging()

        if self.generate_config:
            self.write_config()

        #
        # Setup mlflow
        #
        import mlflow
        mlflow.set_tracking_uri(self.mlflow_server)
        experiment_id = mlflow.set_experiment(self.name)

        #
        # Run the script under mlflow
        #
        with mlflow.start_run(experiment_id=experiment_id):
            #
            # Log the run parametres to mlflow.
            #
            mlflow.log_param("results_path", self.results_path)

            cls = self.__class__
            for k, trait in sorted(cls.class_own_traits(config=True).items()):
                mlflow.log_param(trait.name, repr(trait.get(self)))

            self.run() 
Example #16
Source File: mlflow.py    From gordo with GNU Affero General Public License v3.0
def mlflow_context(
    name: str,
    model_key: str = uuid4().hex,
    workspace_kwargs: dict = {},
    service_principal_kwargs: dict = {},
):
    """
    Generate MLflow logger function with either a local or AzureML backend

    Parameters
    ----------
    name: str
        The name of the log group to log to (e.g. a model name).
    model_key: str
        Unique ID of logging run.
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    service_principal_kwargs: dict
        AzureML ServicePrincipalAuthentication keyword arguments. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`

    Example
    -------
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...     mlflow.set_tracking_uri(f"file:{tmp_dir}")
    ...     with mlflow_context("log_group", "unique_key", {}, {}) as (mlflow_client, run_id):
    ...         log_machine(mlflow_client, run_id, machine) # doctest: +SKIP
    """
    mlflow_client = get_mlflow_client(workspace_kwargs, service_principal_kwargs)
    run_id = get_run_id(mlflow_client, experiment_name=name, model_key=model_key)

    logger.info(
        f"MLflow client configured to use {'AzureML' if workspace_kwargs else 'local backend'}"
    )

    yield mlflow_client, run_id

    mlflow_client.set_terminated(run_id) 
Example #17
Source File: test_cli.py    From gordo with GNU Affero General Public License v3.0
def test_mlflow_reporter_set_cli_build(
    MockClient,
    mock_get_workspace_kwargs,
    mock_get_spauth_kwargs,
    monkeypatch,
    runner,
    tmpdir,
    machine,
):
    """
    Tests disabling MLflow logging in the CLI, and the failure caused by missing env vars when it is enabled
    """

    mlflow.set_tracking_uri(f"file:{tmpdir}")
    machine.runtime = dict(
        reporters=[{"gordo.reporters.mlflow.MlFlowReporter": dict()}]
    )

    with temp_env_vars(MACHINE=json.dumps(machine.to_dict()), OUTPUT_DIR=str(tmpdir)):
        # Logging enabled, without env vars set:
        # Raise error
        result = runner.invoke(cli.gordo, ["build"])
        assert result.exit_code != 0

        # Logging enabled, with env vars set:
        # Build success, remote logging executed
        with monkeypatch.context() as m:
            m.setenv("DL_SERVICE_AUTH_STR", "test:test:test")
            m.setenv("AZUREML_WORKSPACE_STR", "test:test:test")

            result = runner.invoke(cli.gordo, ["build"])
            assert result.exit_code == 0
            assert MockClient.called
            assert mock_get_workspace_kwargs.called
            assert mock_get_spauth_kwargs.called

        # Reset call counts
        for m in [MockClient, mock_get_workspace_kwargs, mock_get_spauth_kwargs]:
            m.reset_mock()

    # Logging not enabled:
    # Build success, remote logging not executed
    machine.runtime = dict(builder=dict(remote_logging=dict(enable=False)))
    with temp_env_vars(MACHINE=json.dumps(machine.to_dict()), OUTPUT_DIR=str(tmpdir)):
        result = runner.invoke(cli.gordo, ["build"])
        assert result.exit_code == 0
        assert not MockClient.called
        assert not mock_get_workspace_kwargs.called
        assert not mock_get_spauth_kwargs.called 
Example #18
Source File: test_cli.py    From gordo with GNU Affero General Public License v3.0
def runner(tmpdir):
    mlflow.set_tracking_uri(f"file:{tmpdir}")
    yield CliRunner() 
Example #19
Source File: loggers.py    From OpenKiwi with GNU Affero General Public License v3.0
def configure(
        self,
        run_uuid,
        experiment_name,
        tracking_uri,
        run_name=None,
        always_log_artifacts=False,
        create_run=True,
        create_experiment=True,
        nest_run=True,
    ):
        if mlflow.active_run() and not nest_run:
            logger.info('Ending previous MLFlow run: {}.'.format(self.run_uuid))
            mlflow.end_run()

        self.always_log_artifacts = always_log_artifacts
        self._experiment_name = experiment_name
        self._run_name = run_name

        # MLflow specific
        if tracking_uri:
            mlflow.set_tracking_uri(tracking_uri)

        if run_uuid:
            existing_run = MlflowClient().get_run(run_uuid)
            if not existing_run and not create_run:
                raise FileNotFoundError(
                    'Run ID {} not found under {}'.format(
                        run_uuid, mlflow.get_tracking_uri()
                    )
                )

        experiment_id = self._retrieve_mlflow_experiment_id(
            experiment_name, create=create_experiment
        )
        return mlflow.start_run(
            run_uuid,
            experiment_id=experiment_id,
            run_name=run_name,
            nested=nest_run,
        ) 
Example #20
Source File: mlflow_utils_test.py    From nucleus7 with Mozilla Public License 2.0
def setUp(self):
        TestCaseWithReset.setUp(self)
        TestCaseWithTempDir.setUp(self)
        if "MLFLOW_TRACKING_URI" in os.environ:
            del os.environ["MLFLOW_TRACKING_URI"]
        mlflow.set_tracking_uri(None) 
Example #21
Source File: evaluate.py    From orbyter-cookiecutter with MIT License
def log_experiment(
    params={},
    metrics={},
    artifacts={},
    experiment_name="my_experiment",
    mlflow_tracking_uri="./experiments",
    mlflow_artifact_location=None,
):
    """
    Evaluate the model and log it with mlflow

    Args:
        params (dict): dictionary of parameters to log
        metrics (dict): dictionary of metrics to log
        artifacts (dict): dictionary of artifacts (path) to log
        experiment_name (str): experiment name
        mlflow_tracking_uri (str): path or sql url for mlflow logging
        mlflow_artifact_location (str): path or S3 bucket URL for artifact
            logging. If None, a default location is used.

    Returns:
        None
    """
    # Point MLflow at the tracking store before touching experiments
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    # Try to create an experiment if it doesn't exist
    try:
        exp_0 = mlflow.create_experiment(
            experiment_name, artifact_location=mlflow_artifact_location
        )
        logger.info(f"Created new experiment id: {exp_0}")
    except Exception as E:
        logger.info(f"{E}. Writing to same URI/artifact store")
    # Always set the experiment
    mlflow.set_experiment(experiment_name)
    logger.info(f"Running experiment {experiment_name}")
    with mlflow.start_run():
        # param logging
        for key, val in params.items():
            logger.info(f"Logging param {key}")
            mlflow.log_param(key, val)
        # metric logging
        for key, val in metrics.items():
            logger.info(f"Logging metric {key}")
            mlflow.log_metric(key, val)
        # artifact logging
        for key, val in artifacts.items():
            logger.info(f"Logging artifact {key}")
            mlflow.log_artifact(val) 
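
A minimal sketch of calling log_experiment as defined above; the parameter, metric, and artifact values are illustrative only:

log_experiment(
    params={"n_estimators": 100, "max_depth": 5},
    metrics={"rmse": 0.42, "r2": 0.87},
    artifacts={"confusion_matrix": "./plots/confusion_matrix.png"},
    experiment_name="my_experiment",
    mlflow_tracking_uri="./experiments",
)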
Example #22
Source File: test_sqlalchemy_store.py    From mlflow with Apache License 2.0
def test_metrics_materialization_upgrade_succeeds_and_produces_expected_latest_metric_values(
            self):
        """
        Tests the ``89d4b8295536_create_latest_metrics_table`` migration by migrating and querying
        the MLflow Tracking SQLite database located at
        /mlflow/tests/resources/db/db_version_7ac759974ad8_with_metrics.sql. This database contains
        metric entries populated by the following metrics generation script:
        https://gist.github.com/dbczumar/343173c6b8982a0cc9735ff19b5571d9.

        First, the database is upgraded from its HEAD revision of
        ``7ac759974ad8_update_run_tags_with_larger_limit`` to the latest revision via
        ``mlflow db upgrade``.

        Then, the test confirms that the metric entries returned by calls
        to ``SqlAlchemyStore.get_run()`` are consistent between the latest revision and the
        ``7ac759974ad8_update_run_tags_with_larger_limit`` revision. This is confirmed by
        invoking ``SqlAlchemyStore.get_run()`` for each run id that is present in the upgraded
        database and comparing the resulting runs' metric entries to a JSON dump taken from the
        SQLite database prior to the upgrade (located at
        mlflow/tests/resources/db/db_version_7ac759974ad8_with_metrics_expected_values.json).
        This JSON dump can be replicated by installing MLflow version 1.2.0 and executing the
        following code from the directory containing this test suite:

        >>> import json
        >>> import mlflow
        >>> from mlflow.tracking.client import MlflowClient
        >>> mlflow.set_tracking_uri(
        ...     "sqlite:///../../resources/db/db_version_7ac759974ad8_with_metrics.sql")
        >>> client = MlflowClient()
        >>> summary_metrics = {
        ...     run.info.run_id: run.data.metrics for run
        ...     in client.search_runs(experiment_ids="0")
        ... }
        >>> with open("dump.json", "w") as dump_file:
        ...     json.dump(summary_metrics, dump_file, indent=4)
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        db_resources_path = os.path.normpath(
            os.path.join(current_dir, os.pardir, os.pardir, "resources", "db"))
        expected_metric_values_path = os.path.join(
            db_resources_path, "db_version_7ac759974ad8_with_metrics_expected_values.json")
        with TempDir() as tmp_db_dir:
            db_path = tmp_db_dir.path("tmp_db.sql")
            db_url = "sqlite:///" + db_path
            shutil.copyfile(
                src=os.path.join(db_resources_path, "db_version_7ac759974ad8_with_metrics.sql"),
                dst=db_path)

            invoke_cli_runner(mlflow.db.commands, ['upgrade', db_url])
            store = self._get_store(db_uri=db_url)
            with open(expected_metric_values_path, "r") as f:
                expected_metric_values = json.load(f)

            for run_id, expected_metrics in expected_metric_values.items():
                fetched_run = store.get_run(run_id=run_id)
                assert fetched_run.data.metrics == expected_metrics 
Example #23
Source File: test_mlflow_reporter.py    From gordo with GNU Affero General Public License v3.0
def test_get_run_id_external_calls(
    mock_create_run, mock_create_experiment, mock_get_experiment, tmpdir
):
    """
    Test the logic that creates an experiment when one does not exist before creating new runs
    """

    class MockRunInfo:
        def __init__(self, run_id):
            self.run_id = run_id

    class MockRun:
        def __init__(self, run_id):
            self.info = MockRunInfo(run_id)

    class MockExperiment:
        def __init__(self, experiment_id):
            self.experiment_id = experiment_id

    def _test_calls(test_run_id, n_create_exp, n_create_run):
        """Test that number of calls match those specified"""
        run_id = mlu.get_run_id(client, experiment_name, model_key)
        assert mock_get_experiment.call_count == 1
        assert mock_create_experiment.call_count == n_create_exp
        assert mock_create_run.call_count == n_create_run
        assert run_id == test_run_id

    # Dummy test name/IDs
    experiment_name = "test_experiment"
    test_experiment_id = "dummy_exp_id"
    test_run_id = "dummy_run_id"
    model_key = "dummy_model_key"

    mlflow.set_tracking_uri(f"file:{tmpdir}")
    client = mlu.MlflowClient()

    # Experiment exists
    # Create a run with existing experiment_id
    mock_get_experiment.return_value = MockExperiment(test_experiment_id)
    mock_create_experiment.return_value = MockExperiment(test_experiment_id)
    mock_create_run.return_value = MockRun(test_run_id)
    _test_calls(test_run_id, n_create_exp=0, n_create_run=1)

    # Reset call counts
    for m in [mock_get_experiment, mock_create_experiment, mock_create_run]:
        m.call_count = 0

    # Experiment doesn't exist
    # Create an experiment and use its ID to create a run
    mock_get_experiment.return_value = None
    mock_create_experiment.return_value = MockExperiment(test_experiment_id)
    mock_create_run.return_value = MockRun(test_run_id)
    _test_calls(test_run_id, n_create_exp=1, n_create_run=1) 
Example #24
Source File: mlflow_utils.py    From nucleus7 with Mozilla Public License 2.0
def create_new_or_continue_experiment(project_dir: str):
    """
    Creates a new experiment or continues already existing one.

    Experiment name is the name of the project_dir

    Parameters
    ----------
    project_dir
        project directory
    """
    mlflow.set_tracking_uri(None)
    experiment_name = project_utils.get_project_name_from_directory(project_dir)
    if "MLFLOW_TRACKING_URI" not in os.environ:
        tracking_uri = os.path.join(os.path.split(project_dir)[0], "mlruns")
        tracking_uri = os.path.realpath(tracking_uri)
        mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
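
A minimal usage sketch with an illustrative project path; per the docstring, the experiment takes the directory's name, and runs land in a sibling mlruns directory unless MLFLOW_TRACKING_URI is set:

# Experiment name becomes "my_project"; tracking URI defaults to
# /data/projects/mlruns when MLFLOW_TRACKING_URI is unset.
create_new_or_continue_experiment("/data/projects/my_project")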