Python mlflow.tracking() Examples

The following are 16 code examples of mlflow.tracking(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mlflow, or try the search function.
Example #1
Source File:    From gordo with GNU Affero General Public License v3.0 6 votes vote down vote up
def get_workspace_kwargs() -> dict:
    """Get AzureML keyword arguments from environment

    The values are obtained by calling ``get_kwargs_from_secret`` with the
    environment variable name ``AZUREML_WORKSPACE_STR``. The name of this
    environment variable is set in the Argo workflow template.

    NOTE(review): the description of the expected value format was lost from
    the original docstring. Presumably the value carries the three fields
    listed in the call below, in that order - confirm the exact format
    against ``get_kwargs_from_secret``.

    Returns
    -------
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking.
    """
    return get_kwargs_from_secret(
        "AZUREML_WORKSPACE_STR", ["subscription_id", "resource_group", "workspace_name"]
    )
Example #2
Source File:    From mlflow with Apache License 2.0 6 votes vote down vote up
def __init__(self, artifact_uri):
    """Initialize the repository for an artifact location under ``dbfs:/``.

    Validates the URI, extracts the run id from it, and stores the path of
    ``artifact_uri`` relative to the artifact root of that run.

    :param artifact_uri: Artifact location. Must start with ``dbfs:/`` and be
                         accepted by ``is_databricks_acled_artifacts_uri``.
    :raises MlflowException: If either URI validation fails.
    """
    super(DatabricksArtifactRepository, self).__init__(artifact_uri)
    # NOTE(review): both ``raise`` calls below were cut off after the ``message``
    # argument in the extracted snippet. They are closed here without guessing at
    # further arguments (the original may also pass an ``error_code`` - confirm).
    if not artifact_uri.startswith('dbfs:/'):
        raise MlflowException(message='DatabricksArtifactRepository URI must start with dbfs:/')
    if not is_databricks_acled_artifacts_uri(artifact_uri):
        raise MlflowException(message=('Artifact URI incorrect. Expected path prefix to be'
                                       ' databricks/mlflow-tracking/path/to/artifact/..'))
    self.run_id = self._extract_run_id(self.artifact_uri)

    # Fetch the artifact root for the MLflow Run associated with `artifact_uri` and compute
    # the path of `artifact_uri` relative to the MLflow Run's artifact root
    # (the `run_relative_artifact_repo_root_path`). All operations performed on this artifact
    # repository will be performed relative to this computed location
    artifact_repo_root_path = extract_and_normalize_path(artifact_uri)
    run_artifact_root_uri = self._get_run_artifact_root(self.run_id)
    run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri)
    run_relative_root_path = posixpath.relpath(
        path=artifact_repo_root_path, start=run_artifact_root_path
    )
    # If the paths are equal, then use empty string over "./" for ListArtifact compatibility.
    self.run_relative_artifact_repo_root_path = (
        "" if run_artifact_root_path == artifact_repo_root_path else run_relative_root_path
    )
Example #3
Source File:    From OpenKiwi with GNU Affero General Public License v3.0 5 votes vote down vote up
def run_uuid(self):
    """Return the ``run_uuid`` recorded in the info of the active MLflow run."""
    current_run = mlflow.tracking.fluent.active_run()
    run_info = current_run.info
    return run_info.run_uuid
Example #4
Source File:    From OpenKiwi with GNU Affero General Public License v3.0 5 votes vote down vote up
def experiment_id(self):
    """Return the ``experiment_id`` recorded in the info of the active MLflow run."""
    current_run = mlflow.tracking.fluent.active_run()
    run_info = current_run.info
    return run_info.experiment_id
Example #5
Source File:    From OpenKiwi with GNU Affero General Public License v3.0 5 votes vote down vote up
def _is_remote(self):
    """Return ``True`` when the tracking URI is not classified as local.

    The classification is delegated to ``mlflow.tracking.utils._is_local_uri``.
    """
    # NOTE(review): the argument and closing parenthesis of this call were lost
    # in the extracted snippet. The currently configured tracking URI is the
    # presumed input - confirm against the original project.
    tracking_uri = mlflow.tracking.get_tracking_uri()
    return not mlflow.tracking.utils._is_local_uri(tracking_uri)
Example #6
Source File:    From gordo with GNU Affero General Public License v3.0 5 votes vote down vote up
def get_run_id(client: MlflowClient, experiment_name: str, model_key: str) -> str:
    """
    Create a new run for the given model_key in the experiment experiment_name.

    The experiment is looked up by name and created if it does not exist yet.
    A new run tagged with ``model_key`` is then always created in it.

    The model key corresponds to a unique configuration of the model. The corresponding
    run must be manually stopped using the `mlflow.tracking.MlflowClient.set_terminated`
    method.

    Parameters
    ----------
    client: mlflow.tracking.MlflowClient
        Client with tracking uri set to AzureML if configured.
    experiment_name: str
        Name of experiment to log to.
    model_key: str
        Unique ID of model configuration.

    Returns
    -------
    run_id: str
        Unique ID of MLflow run to log to.
    """
    experiment = client.get_experiment_by_name(experiment_name)

    # A falsy lookup result means the experiment does not exist yet.
    if experiment:
        experiment_id = experiment.experiment_id
    else:
        experiment_id = client.create_experiment(experiment_name)

    run = client.create_run(experiment_id, tags={"model_key": model_key})
    return run.info.run_id
Example #7
Source File:    From gordo with GNU Affero General Public License v3.0 5 votes vote down vote up
def mlflow_context(
    name: str,
    model_key: str = uuid4().hex,
    workspace_kwargs: dict = {},
    service_principal_kwargs: dict = {},
):
    """
    Generate MLflow logger function with either a local or AzureML backend

    Builds a client with ``get_mlflow_client``, creates a run with
    ``get_run_id``, yields both, and terminates the run once the caller is
    done with them.

    NOTE(review): the body is a generator, and the usage example below relies
    on it being wrapped as a context manager (e.g. with
    ``contextlib.contextmanager``) where it is defined - confirm.
    NOTE(review): the default ``model_key`` is evaluated once when the function
    is defined, so every call that omits it shares the same key. The dict
    defaults are shared objects too and must not be mutated.

    Parameters
    ----------
    name: str
        The name of the log group to log to (e.g. a model name).
    model_key: str
        Unique ID of logging run.
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking.
        Passed to ``get_mlflow_client``.
    service_principal_kwargs: dict
        AzureML ServicePrincipalAuthentication keyword arguments. Passed to
        ``get_mlflow_client``.

    Example
    -------
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...     mlflow.set_tracking_uri(f"file:{tmp_dir}")
    ...     with mlflow_context("log_group", "unique_key", {}, {}) as (mlflow_client, run_id):
    ...         log_machine(machine) # doctest: +SKIP
    """
    # Imported locally so the block does not depend on a module-level logger.
    import logging

    mlflow_client = get_mlflow_client(workspace_kwargs, service_principal_kwargs)
    run_id = get_run_id(mlflow_client, experiment_name=name, model_key=model_key)

    # NOTE(review): the call wrapping this message was lost in the extracted
    # snippet; an info-level log record is the presumed intent.
    logging.getLogger(__name__).info(
        f"MLflow client configured to use {'AzureML' if workspace_kwargs else 'local backend'}"
    )

    yield mlflow_client, run_id

    # The run was created explicitly through the client, so it has to be
    # stopped explicitly as well.
    # NOTE(review): this line is not in the extracted snippet - confirm the
    # original terminates the run here.
    mlflow_client.set_terminated(run_id)

Example #8
Source File:    From mlflow with Apache License 2.0 5 votes vote down vote up
def commands():
    """
    Manage runs. To manage runs of experiments associated with a tracking server, set the
    MLFLOW_TRACKING_URI environment variable to the URL of the desired server.
    """
    # Intentionally empty: the function only carries the description above.
Example #9
Source File:    From mlflow with Apache License 2.0 5 votes vote down vote up
def list_run(experiment_id, view):
    """
    List all runs of the specified experiment in the configured tracking server.
    """
    store = _get_store()
    # Fall back to active runs only when no view was requested.
    view_type = ViewType.from_string(view) if view else ViewType.ACTIVE_ONLY
    runs = store.search_runs([experiment_id], None, view_type)
    # NOTE(review): the operands of the tag lookup and of the row construction
    # were lost in the extracted snippet. The attributes used here
    # (``run.data.tags``, ``run.info.start_time``, ``run.info.run_id``) are
    # reconstructed from the "Date", "Name", "ID" headers - confirm.
    table = [
        [
            conv_longdate_to_str(run.info.start_time),
            run.data.tags.get(MLFLOW_RUN_NAME, ""),
            run.info.run_id,
        ]
        for run in runs
    ]
    # Rows are sorted in descending order before printing.
    print(tabulate(sorted(table, reverse=True), headers=["Date", "Name", "ID"]))
Example #10
Source File:    From mlflow with Apache License 2.0 5 votes vote down vote up
def _extract_run_id(artifact_uri):
    """
    The artifact_uri is expected to be
    dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
    Once the path from the input uri is extracted and normalized, it is
    expected to be of the form
    databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

    Hence the run_id is the 4th element of the normalized path.

    NOTE(review): the two example layouts above were lost in the extracted
    snippet and are reconstructed from the "4th element" statement - confirm.

    :param artifact_uri: URI to extract the run id from.
    :return: run_id extracted from the artifact_uri
    """
    artifact_path = extract_and_normalize_path(artifact_uri)
    # Index 3 is the 4th "/"-separated element of the normalized path.
    return artifact_path.split('/')[3]
Example #11
Source File:    From mlflow with Apache License 2.0 5 votes vote down vote up
def _call_endpoint(self, service, api, json_body):
    """Send ``json_body`` to the endpoint registered for ``service`` / ``api``.

    Host credentials are resolved from the profile in the current tracking URI.
    The endpoint and method come from ``_SERVICE_AND_METHOD_TO_INFO``, and a
    fresh ``api.Response()`` is handed to ``call_endpoint``.

    :return: whatever ``call_endpoint`` returns.
    """
    host_creds = get_databricks_host_creds(
        get_db_profile_from_uri(mlflow.tracking.get_tracking_uri())
    )
    path, verb = _SERVICE_AND_METHOD_TO_INFO[service][api]
    return call_endpoint(host_creds, path, verb, json_body, api.Response())
Example #12
Source File:    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks, set_tag_mock):  # pylint: disable=unused-argument
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    # NOTE(review): the extracted snippet lost the second context manager and
    # both ``mlflow.projects.run(`` call heads. The patch target and the
    # ``backend_config`` arguments below are reconstructions - confirm them
    # against the original test before relying on this block.
    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
        mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
            as db_api_req_mock:
        # Test bad tracking URI
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        mlflow_service = mlflow.tracking.MlflowClient()
        assert (len(mlflow_service.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID))
                == 0)
        tracking_uri_mock.return_value = "http://"
        # Test misspecified parameters
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert db_api_req_mock.call_count == 0
        # Test bad cluster spec
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                backend_config=None)
        assert db_api_req_mock.call_count == 0
        # Test that validations pass with good tracking URIs
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
Example #13
Source File:    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_get_tracking_uri_for_run():
    """Check the tracking URI reported for a run under different configurations.

    NOTE(review): the extracted snippet asserted two different results from the
    same call back to back, which cannot both hold without a state change. The
    ``mlflow.set_tracking_uri`` calls are reconstructed - confirm the exact
    values against the original test.
    """
    mlflow.set_tracking_uri("http://some-uri")
    assert databricks._get_tracking_uri_for_run() == "http://some-uri"
    # Presumably any databricks-style URI is reported as plain "databricks".
    mlflow.set_tracking_uri("databricks://profile")
    assert databricks._get_tracking_uri_for_run() == "databricks"
    # Clear the explicit URI so the environment variable below is consulted.
    mlflow.set_tracking_uri(None)
    with mock.patch.dict(os.environ, {mlflow.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}):
        assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri"
Example #14
Source File:    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_integration(dirname):
    """Log one metric per epoch through ``MLflowLogger`` and read it back.

    A handler attached to ``Events.EPOCH_COMPLETED`` logs ``test_value`` and
    records the logged value locally. After training, the metric history stored
    under ``dirname/mlruns`` is compared with the recorded values.
    """
    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    mlflow_logger = MLflowLogger(tracking_uri=os.path.join(dirname, "mlruns"))

    true_values = []

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        v = global_step * 0.1
        # Remember what was logged; without this the final comparison loop
        # iterates over nothing and the test passes vacuously.
        true_values.append(v)
        logger.log_metrics({"test_value": v}, step=global_step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow

    # Grab the run before training so its id is available afterwards.
    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    # NOTE(review): presumably ends the run opened by the logger - confirm that
    # the original test closes the logger here.
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # zip() truncates silently, so pin the lengths first.
    assert len(true_values) == n_epochs
    assert len(stored_values) == len(true_values)
    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value
Example #15
Source File:    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_integration_as_context_manager(dirname):
    """Same round trip as a plain logger, with ``MLflowLogger`` used in a ``with`` block.

    A handler attached to ``Events.EPOCH_COMPLETED`` logs ``test_value`` and
    records the logged value locally. After the block exits, the metric history
    stored under ``dirname/mlruns`` is compared with the recorded values.
    """
    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    true_values = []

    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            v = global_step * 0.1
            # Remember what was logged; without this the final comparison loop
            # iterates over nothing and the test passes vacuously.
            true_values.append(v)
            logger.log_metrics({"test_value": v}, step=global_step)

        mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        import mlflow

        # Grab the run while the logger is still open so its id can be used
        # after the block exits.
        active_run = mlflow.active_run()

        trainer.run(data, max_epochs=n_epochs)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # zip() truncates silently, so pin the lengths first.
    assert len(true_values) == n_epochs
    assert len(stored_values) == len(true_values)
    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value
Example #16
Source File:    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_mlflow_bad_metric_name_handling(dirname):
    """Check that an invalid metric name triggers a warning while a valid one is stored.

    The engine state carries two metrics, ``"metric:0 in %"`` and ``"metric 0"``.
    The handler is invoked once with the initial state and once per entry of
    ``true_values``, and the stored history of ``"training metric 0"`` is then
    compared with the values that were set.
    """
    import mlflow

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        # Grab the run while the logger is still open so its id can be used
        # after the block exits.
        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0})

        with pytest.warns(UserWarning, match=r"MLflowLogger output_handler encountered an invalid metric name"):

            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

            for v in true_values:
                engine.state.epoch += 1
                engine.state.metrics["metric 0"] = v
                handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

    # The first stored value is the initial 1000.0, followed by the updates.
    for t, s in zip([1000.0] + true_values, stored_values):
        assert t == s.value