Python mlflow.tracking.MlflowClient() Examples

The following are 30 code examples of mlflow.tracking.MlflowClient(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mlflow.tracking, or try the search function.
Example #1
Source File: loggers.py    From OpenKiwi with GNU Affero General Public License v3.0 6 votes vote down vote up
def _retrieve_mlflow_experiment_id(name, create=False):
        experiment_id = None
        if name:
            existing_experiment = MlflowClient().get_experiment_by_name(name)
            if existing_experiment:
                experiment_id = existing_experiment.experiment_id
            else:
                if create:
                    experiment_id = mlflow.create_experiment(name)
                else:
                    raise Exception(
                        'Experiment "{}" not found in {}'.format(
                            name, mlflow.get_tracking_uri()
                        )
                    )
        return experiment_id 
Example #2
Source File: test_mlflow.py    From optuna with MIT License 6 votes vote down vote up
def test_metric_name(tmpdir: py.path.local) -> None:
    """The callback logs each trial's objective under the configured ``metric_name``."""
    tracking_file_name = "file:{}".format(tmpdir)
    metric_name = "my_metric_name"
    study_name = "my_study"

    mlflc = MLflowCallback(tracking_uri=tracking_file_name, metric_name=metric_name)
    study = optuna.create_study(study_name=study_name)
    study.optimize(_objective_func, n_trials=3, callbacks=[mlflc])

    mlfl_client = MlflowClient(tracking_file_name)
    # Look the experiment up by name instead of taking the first entry of the
    # experiment list: the file store also holds the "Default" experiment, and the
    # list order is not part of MLflow's contract. (Assumes the callback names the
    # experiment after the study, as optuna's MLflowCallback does.)
    experiment = mlfl_client.get_experiment_by_name(study_name)
    assert experiment is not None

    # ``search_runs`` replaces ``list_run_infos`` (removed in MLflow 2.0) and returns
    # full ``Run`` objects, so no extra ``get_run`` round trip is needed.
    runs = mlfl_client.search_runs([experiment.experiment_id])
    assert len(runs) > 0

    first_run_dict = runs[0].to_dictionary()
    assert metric_name in first_run_dict["data"]["metrics"]
Example #3
Source File: mlflow.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def __init__(self,
             experiment_name: str = 'default',
             tracking_uri: Optional[str] = None,
             tags: Optional[Dict[str, Any]] = None,
             save_dir: Optional[str] = None):
    """Create an MLflow-backed logger.

    Args:
        experiment_name: Name of the MLflow experiment to log to.
        tracking_uri: MLflow tracking URI. When omitted and ``save_dir`` is
            given, a local ``file:`` URI pointing at ``save_dir`` is used.
        tags: Optional tags stored on the logger.
        save_dir: Local directory used to build a file-store tracking URI
            when ``tracking_uri`` is not provided.

    Raises:
        ImportError: If the ``mlflow`` package is not available.
    """
    if not _MLFLOW_AVAILABLE:
        raise ImportError('You want to use `mlflow` logger which is not installed yet,'
                          ' install it with `pip install mlflow`.')
    super().__init__()
    if not tracking_uri and save_dir:
        # Build a well-formed file URI. Concatenating 'file:' + os.sep * 2 + save_dir
        # yields 'file://relative/dir' for relative paths (the first segment is then
        # parsed as a host) and backslashes instead of '//' on Windows.
        from pathlib import Path
        tracking_uri = Path(os.path.abspath(save_dir)).as_uri()
    self._mlflow_client = MlflowClient(tracking_uri)
    self.experiment_name = experiment_name
    self._run_id = None
    self.tags = tags
Example #4
Source File: test_models_artifact_repo.py    From mlflow with Apache License 2.0 6 votes vote down vote up
def test_models_artifact_repo_init_with_stage_uri(
        host_creds_mock):  # pylint: disable=unused-argument
    """A ``models:/<name>/<stage>`` URI resolves via the latest version in that stage."""
    stage_uri = "models:/MyModel/Production"
    download_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"
    production_version = ModelVersion(
        "MyModel", "10", "2345671890", "234567890", "some description",
        "UserID", "Production", "source", "run12345")

    with mock.patch.object(MlflowClient, "get_latest_versions",
                           return_value=[production_version]), \
            mock.patch.object(MlflowClient, "get_model_version_download_uri",
                              return_value=download_location):
        repo = ModelsArtifactRepository(stage_uri)
        # The wrapper keeps the models:/ URI; the inner repo points at the download location.
        assert repo.artifact_uri == stage_uri
        assert isinstance(repo.repo, DbfsRestArtifactRepository)
        assert repo.repo.artifact_uri == download_location
Example #5
Source File: __init__.py    From mlflow with Apache License 2.0 6 votes vote down vote up
def _wait_for(submitted_run_obj):
    """Block until ``submitted_run_obj`` completes and report its outcome.

    On success, ``_maybe_set_run_terminated`` is called with "FINISHED".
    On failure or Ctrl-C, it is called with "FAILED". On Ctrl-C the submitted
    run is also cancelled.

    Raises:
        ExecutionException: If the submitted run reports failure.
        KeyboardInterrupt: Re-raised after the run has been cancelled.
    """
    run_id = submitted_run_obj.run_id
    tracked_run = None
    # Note: there's a small chance we fail to report the run's status to the tracking server if
    # we're interrupted before we reach the try block below
    try:
        if run_id is not None:
            tracked_run = tracking.MlflowClient().get_run(run_id)
        if not submitted_run_obj.wait():
            _maybe_set_run_terminated(tracked_run, "FAILED")
            raise ExecutionException("Run (ID '%s') failed" % run_id)
        _logger.info("=== Run (ID '%s') succeeded ===", run_id)
        _maybe_set_run_terminated(tracked_run, "FINISHED")
    except KeyboardInterrupt:
        _logger.error("=== Run (ID '%s') interrupted, cancelling run ===", run_id)
        submitted_run_obj.cancel()
        _maybe_set_run_terminated(tracked_run, "FAILED")
        raise
Example #6
Source File: test_tracking.py    From mlflow with Apache License 2.0 6 votes vote down vote up
def test_parent_create_run():
    """Runs started inside an active run need ``nested=True`` and are tagged with their parent."""
    with mlflow.start_run() as first_run:
        parent_run_id = first_run.info.run_id
    # Expose the finished run's ID via the env var so the next start_run picks it up.
    os.environ[_RUN_ID_ENV_VAR] = parent_run_id

    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        # A plain start_run while a run is active is rejected.
        with pytest.raises(Exception, match='To start a nested run'):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

    # Each nested run must record the ID of the run it was started inside.
    for run, expected_parent in ((child_run, parent_run), (grand_child_run, child_run)):
        run_tags = tracking.MlflowClient().get_run(run.info.run_id).data.tags
        assert run_tags[MLFLOW_PARENT_RUN_ID] == expected_parent.info.run_id
    assert mlflow.active_run() is None
Example #7
Source File: test_client.py    From mlflow with Apache License 2.0 6 votes vote down vote up
def test_client_registry_operations_raise_exception_with_unsupported_registry_store():
    """
    Model Registry calls on an ``MlflowClient`` whose registry URI points at a store
    without registry support (e.g. a local FileStore directory) must fail with an
    MlflowException carrying the FEATURE_DISABLED error code.
    """
    with TempDir() as tmp:
        client = MlflowClient(registry_uri=tmp.path())
        registry_calls = (
            client._get_registry_client,
            lambda: client.create_registered_model("test"),
            lambda: client.get_registered_model("test"),
            lambda: client.create_model_version("test", "source", "run_id"),
            lambda: client.get_model_version("test", 1),
        )
        expected_code = ErrorCode.Name(FEATURE_DISABLED)
        for call in registry_calls:
            with pytest.raises(MlflowException) as raised:
                call()
            assert raised.value.error_code == expected_code
Example #8
Source File: databricks.py    From mlflow with Apache License 2.0 6 votes vote down vote up
def _print_description_and_log_tags(self):
    """Log the Databricks run-page URL and record Databricks identifiers as MLflow tags.

    The following tags are set on ``self._mlflow_run_id``:
    - the run-page URL;
    - the Databricks job run ID;
    - the Databricks host;
    - the job ID, when the Jobs API response includes one.
    """
    _logger.info(
        "=== Launched MLflow run as Databricks job run with ID %s."
        " Getting run status page URL... ===",
        self._databricks_run_id)
    run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
    jobs_page_url = run_info["run_page_url"]
    _logger.info("=== Check the run's status at %s ===", jobs_page_url)
    host_creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)

    # Build a single client and reuse it rather than constructing a fresh
    # MlflowClient for every tag.
    client = tracking.MlflowClient()
    client.set_tag(self._mlflow_run_id, MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
    client.set_tag(self._mlflow_run_id, MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID,
                   self._databricks_run_id)
    client.set_tag(self._mlflow_run_id, MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)
    job_id = run_info.get('job_id')
    # In some releases of Databricks we do not return the job ID. We start including it in DB
    # releases 2.80 and above.
    if job_id is not None:
        client.set_tag(self._mlflow_run_id, MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id)
Example #9
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_log_params():
    """``mlflow.log_params`` stores every value, stringified, on the run."""
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(expected_params)
    logged_params = tracking.MlflowClient().get_run(run_id).data.params
    # Params come back as strings, so the integer 5 is expected as "5".
    assert logged_params == {key: str(value) for key, value in expected_params.items()}
Example #10
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_defaults(mock_store):
    """With only experiment IDs, ``search_runs`` forwards the default search options."""
    MlflowClient().search_runs([1, 2, 3])
    expected_kwargs = dict(
        experiment_ids=[1, 2, 3],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #11
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_create_run_overrides(mock_store):
    """``create_run`` forwards caller tags verbatim and takes the user ID from MLFLOW_USER.

    The same call is made twice, resetting the store mock in between, and both
    calls must produce an identical store request.
    """
    experiment_id = mock.Mock()
    user = mock.Mock()
    start_time = mock.Mock()
    tags = {
        MLFLOW_USER: user,
        MLFLOW_PARENT_RUN_ID: mock.Mock(),
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
        MLFLOW_SOURCE_NAME: mock.Mock(),
        MLFLOW_PROJECT_ENTRY_POINT: mock.Mock(),
        MLFLOW_GIT_COMMIT: mock.Mock(),
        "other-key": "other-value"
    }

    def create_and_check():
        MlflowClient().create_run(experiment_id, start_time, tags)
        mock_store.create_run.assert_called_once_with(
            experiment_id=experiment_id,
            user_id=user,
            start_time=start_time,
            tags=[RunTag(tag_key, tag_value) for tag_key, tag_value in tags.items()],
        )

    create_and_check()
    mock_store.reset_mock()
    create_and_check()
Example #12
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_search_runs_multiple_experiments():
    """``search_runs`` applies metric filters across several experiments at once."""
    experiment_ids = [mlflow.create_experiment("exp__{}".format(exp_id)) for exp_id in range(1, 4)]
    for index, eid in enumerate(experiment_ids, start=1):
        with mlflow.start_run(experiment_id=eid):
            mlflow.log_metric("m0", 1)
            # Key the per-experiment metric on the loop position, not on the ID the
            # store assigned: the filters below name m_1..m_3, which previously only
            # matched when the store happened to hand out IDs "1", "2" and "3".
            mlflow.log_metric("m_{}".format(index), 2)

    client = MlflowClient()
    assert len(client.search_runs(experiment_ids, "metrics.m0 > 0", ViewType.ALL)) == 3

    assert len(client.search_runs(experiment_ids, "metrics.m_1 > 0", ViewType.ALL)) == 1
    assert len(client.search_runs(experiment_ids, "metrics.m_2 = 2", ViewType.ALL)) == 1
    assert len(client.search_runs(experiment_ids, "metrics.m_3 < 4", ViewType.ALL)) == 1
Example #13
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_get_artifact_uri_appends_to_uri_path_component_correctly(
        artifact_location, expected_uri_format):
    """Artifact URIs join the artifact path onto the experiment's artifact location."""
    experiment_name = "get-artifact-uri-test"
    MlflowClient().create_experiment(experiment_name, artifact_location=artifact_location)
    mlflow.set_experiment(experiment_name)
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        for artifact_path in ("path/to/artifact", "/artifact/path", "arty.txt"):
            fluent_uri = mlflow.get_artifact_uri(artifact_path)
            # A leading slash on the artifact path is dropped in the expected URI.
            expected = expected_uri_format.format(run_id=run_id, path=artifact_path.lstrip("/"))
            assert fluent_uri == tracking.artifact_utils.get_artifact_uri(run_id, artifact_path)
            assert fluent_uri == expected
Example #14
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_start_run_exp_id_0():
    """An explicit ``experiment_id=0`` takes precedence over the active experiment."""
    mlflow.set_experiment("some-experiment")
    # Without an explicit ID, the run lands in the experiment that was just set.
    with mlflow.start_run() as run:
        current_exp_id = run.info.experiment_id
        assert current_exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(current_exp_id).name == "some-experiment"
    # Passing 0 explicitly must target the default experiment.
    with mlflow.start_run(experiment_id=0) as run:
        assert run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID
Example #15
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_start_deleted_run():
    """Resuming a deleted run raises MlflowException and leaves no active run behind."""
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_id
    tracking.MlflowClient().delete_run(run_id)
    # ``pytest.raises`` accepts ``match`` (a regex searched in the exception message);
    # the previous ``matches`` keyword is not a valid argument, so the message was
    # never actually checked.
    with pytest.raises(MlflowException, match='because it is in the deleted state.'):
        with mlflow.start_run(run_id=run_id):
            pass
    assert mlflow.active_run() is None
Example #16
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_string_experiment_id(mock_store):
    """A bare string experiment ID reaches the store as a one-element list."""
    MlflowClient().search_runs("abc")
    expected_kwargs = {
        "experiment_ids": ["abc"],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #17
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_log_param():
    """Params logged one at a time are persisted as strings."""
    with start_run() as active_run:
        run_id = active_run.info.run_id
        for key, value in (("name_1", "a"), ("name_2", "b"), ("nested/nested/name", 5)):
            mlflow.log_param(key, value)
    logged_params = tracking.MlflowClient().get_run(run_id).data.params
    assert logged_params == {"name_1": "a", "name_2": "b", "nested/nested/name": "5"}
Example #18
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_log_metric_validation():
    """A non-numeric metric value is rejected and no metric is recorded on the run."""
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(MlflowException) as exc_info:
            mlflow.log_metric("name_1", "apple")
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    recorded_metrics = tracking.MlflowClient().get_run(run_id).data.metrics
    assert len(recorded_metrics) == 0
Example #19
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_set_tags():
    """``mlflow.set_tags`` stores stringified values next to MLflow's automatic tags."""
    exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    # Tags MLflow adds on its own; only their presence is checked, not their values.
    approx_expected_tags = {MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.set_tags(exact_expected_tags)
    logged_tags = tracking.MlflowClient().get_run(run_id).data.tags
    # Validate tags
    assert len(logged_tags) == len(exact_expected_tags) + len(approx_expected_tags)
    assert approx_expected_tags <= logged_tags.keys()
    for tag_key, tag_val in logged_tags.items():
        if tag_key in approx_expected_tags:
            continue
        # Fail with an assertion (not a bare KeyError) if an unexpected tag shows up.
        assert tag_key in exact_expected_tags, "unexpected tag {!r}".format(tag_key)
        assert str(exact_expected_tags[tag_key]) == tag_val
Example #20
Source File: test_tracking.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_log_metrics_uses_common_timestamp_and_step_per_invocation(step_kwarg):
    """All metrics written by one ``log_metrics`` call share a single timestamp and step."""
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_metrics(expected_metrics, step=step_kwarg)
    run_data = tracking.MlflowClient().get_run(run_id).data

    # Every key must round-trip with its exact value.
    assert len(run_data.metrics) == len(expected_metrics)
    for name, logged_value in run_data.metrics.items():
        assert expected_metrics[name] == logged_value

    # When no step is supplied, 0 is expected.
    expected_step = 0 if step_kwarg is None else step_kwarg
    metric_objs = run_data._metric_objs
    shared_timestamp = metric_objs[0].timestamp
    for metric in metric_objs:
        assert metric.timestamp == shared_timestamp
        assert metric.step == expected_step
Example #21
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_filter(mock_store):
    """A positional filter string is forwarded; every other option keeps its default."""
    MlflowClient().search_runs(["a", "b", "c"], "my filter")
    expected_kwargs = dict(
        experiment_ids=["a", "b", "c"],
        filter_string="my filter",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #22
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_view_type(mock_store):
    """A positional view type is forwarded to the store as ``run_view_type``."""
    MlflowClient().search_runs(["a", "b", "c"], "my filter", ViewType.DELETED_ONLY)
    expected_kwargs = dict(
        experiment_ids=["a", "b", "c"],
        filter_string="my filter",
        run_view_type=ViewType.DELETED_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #23
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_max_results(mock_store):
    """A positional ``max_results`` overrides the default page size."""
    MlflowClient().search_runs([5], "my filter", ViewType.ALL, 2876)
    expected_kwargs = {
        "experiment_ids": [5],
        "filter_string": "my filter",
        "run_view_type": ViewType.ALL,
        "max_results": 2876,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #24
Source File: loggers.py    From OpenKiwi with GNU Affero General Public License v3.0 5 votes vote down vote up
def experiment_name(self):
    """Return the experiment name cached in ``self._experiment_name``.

    No tracking-server lookup is performed; the stored attribute is returned as-is.
    """
    return self._experiment_name
Example #25
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_order_by(mock_store):
    """An ``order_by`` keyword is forwarded unchanged to the store."""
    MlflowClient().search_runs([5], order_by=["a", "b"])
    expected_kwargs = dict(
        experiment_ids=[5],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=["a", "b"],
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #26
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_client_search_runs_page_token(mock_store):
    """A ``page_token`` keyword is forwarded unchanged to the store."""
    MlflowClient().search_runs([5], page_token="blah")
    expected_kwargs = {
        "experiment_ids": [5],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": "blah",
    }
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #27
Source File: test_client.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_update_registered_model(mock_registry_store):
    """
    ``update_registered_model`` only updates the description and never renames the model.
    """
    rename_result = "some expected return value."
    update_result = "other expected return value."
    mock_registry_store.rename_registered_model.return_value = rename_result
    mock_registry_store.update_registered_model.return_value = update_result

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    returned = client.update_registered_model(name="orig name", description="new description")

    assert update_result == returned
    mock_registry_store.update_registered_model.assert_called_once_with(
        name="orig name", description="new description")
    mock_registry_store.rename_registered_model.assert_not_called()
Example #28
Source File: test_models_artifact_repo.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_models_artifact_repo_init_with_version_uri(
        host_creds_mock):  # pylint: disable=unused-argument
    """A ``models:/<name>/<version>`` URI wraps a repository at that version's download URI."""
    version_uri = "models:/MyModel/12"
    download_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"
    with mock.patch.object(MlflowClient, "get_model_version_download_uri",
                           return_value=download_location):
        repo = ModelsArtifactRepository(version_uri)
        # The wrapper keeps the models:/ URI; the inner repo points at the download location.
        assert repo.artifact_uri == version_uri
        assert isinstance(repo.repo, DbfsRestArtifactRepository)
        assert repo.repo.artifact_uri == download_location
Example #29
Source File: test_models_artifact_repo.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_models_artifact_repo_uses_repo_download_artifacts():
    """
    ``ModelsArtifactRepository.download_artifacts`` should hand off to the wrapped
    ``self.repo.download_artifacts``.
    """
    model_uri = "models:/MyModel/12"
    with mock.patch.object(MlflowClient, "get_model_version_download_uri",
                           return_value="s3://blah_bucket/"):
        models_repo = ModelsArtifactRepository(model_uri)
        # Replace the inner repository with a mock so the delegation can be observed.
        models_repo.repo = Mock()
        models_repo.download_artifacts('artifact_path', 'dst_path')
        models_repo.repo.download_artifacts.assert_called_once()
Example #30
Source File: test_mlflow_logger.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_integration(dirname):
    """End-to-end: metrics logged through ``MLflowLogger`` during training land in MLflow."""
    n_epochs = 5
    data = list(range(50))
    tracking_uri = os.path.join(dirname, "mlruns")

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    mlflow_logger = MLflowLogger(tracking_uri=tracking_uri)

    true_values = []

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        v = global_step * 0.1
        true_values.append(v)
        logger.log_metrics({"test_value": v}, step=global_step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow

    # Presumably the run started by MLflowLogger; it must be captured before close().
    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri)
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # ``zip`` stops at the shorter sequence, so without this check an empty or
    # truncated metric history would let the comparison loop pass vacuously.
    assert len(stored_values) == len(true_values) == n_epochs
    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value