Python mlflow.tracking.MlflowClient() Examples
The following are 30 code examples of mlflow.tracking.MlflowClient().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module mlflow.tracking, or try the search function.
Example #1
Source File: loggers.py From OpenKiwi with GNU Affero General Public License v3.0 | 6 votes |
def _retrieve_mlflow_experiment_id(name, create=False):
    """Resolve an MLflow experiment name to its experiment id.

    Args:
        name: Experiment name to look up. A falsy value short-circuits the
            lookup.
        create: When truthy, a missing experiment is created through
            ``mlflow.create_experiment`` instead of raising.

    Returns:
        The experiment id, or ``None`` when ``name`` is falsy.

    Raises:
        Exception: The experiment does not exist and ``create`` is falsy.
    """
    if not name:
        return None

    found = MlflowClient().get_experiment_by_name(name)
    if found:
        return found.experiment_id

    if create:
        return mlflow.create_experiment(name)

    raise Exception(
        'Experiment "{}" not found in {}'.format(
            name, mlflow.get_tracking_uri()
        )
    )
Example #2
Source File: test_mlflow.py From optuna with MIT License | 6 votes |
def test_metric_name(tmpdir: py.path.local) -> None:
    """Check that the callback logs the objective under the configured metric name.

    Runs a three-trial study with an ``MLflowCallback`` pointed at a file
    store under ``tmpdir``, then reads back the first run of the first
    experiment and asserts the custom metric name is among its metrics.
    """
    tracking_file_name = "file:{}".format(tmpdir)
    metric_name = "my_metric_name"

    mlflc = MLflowCallback(tracking_uri=tracking_file_name, metric_name=metric_name)
    study = optuna.create_study(study_name="my_study")
    study.optimize(_objective_func, n_trials=3, callbacks=[mlflc])

    mlfl_client = MlflowClient(tracking_file_name)
    experiment_id = mlfl_client.list_experiments()[0].experiment_id
    first_run_id = mlfl_client.list_run_infos(experiment_id)[0].run_id
    logged = mlfl_client.get_run(first_run_id).to_dictionary()

    assert metric_name in logged["data"]["metrics"]
Example #3
Source File: mlflow.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def __init__(
    self,
    experiment_name: str = 'default',
    tracking_uri: Optional[str] = None,
    tags: Optional[Dict[str, Any]] = None,
    save_dir: Optional[str] = None,
):
    """Set up the logger and its ``MlflowClient``.

    Args:
        experiment_name: Stored on the instance as ``experiment_name``.
        tracking_uri: URI handed to ``MlflowClient``. When falsy and
            ``save_dir`` is given, a ``file:`` URI is built from ``save_dir``.
        tags: Stored on the instance as ``tags``.
        save_dir: Directory used to build the fallback tracking URI.

    Raises:
        ImportError: ``_MLFLOW_AVAILABLE`` is falsy.
    """
    if not _MLFLOW_AVAILABLE:
        raise ImportError('You want to use `mlflow` logger which is not installed yet,'
                          ' install it with `pip install mlflow`.')
    super().__init__()

    # Fall back to a local file store only when no explicit URI was supplied.
    if not tracking_uri and save_dir:
        double_sep = os.sep * 2
        tracking_uri = f'file:{double_sep}{save_dir}'

    self._mlflow_client = MlflowClient(tracking_uri)
    self.experiment_name = experiment_name
    self._run_id = None
    self.tags = tags
Example #4
Source File: test_models_artifact_repo.py From mlflow with Apache License 2.0 | 6 votes |
def test_models_artifact_repo_init_with_stage_uri(
        host_creds_mock):  # pylint: disable=unused-argument
    """A ``models:/<name>/<stage>`` URI is resolved through the patched client.

    ``get_latest_versions`` and ``get_model_version_download_uri`` are
    patched on ``MlflowClient``; the repository must keep the model URI as
    its own ``artifact_uri`` and wrap a ``DbfsRestArtifactRepository`` that
    points at the patched download location.
    """
    model_uri = "models:/MyModel/Production"
    artifact_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"
    model_version_detailed = ModelVersion(
        "MyModel", "10", "2345671890", "234567890", "some description",
        "UserID", "Production", "source", "run12345")

    latest_versions = mock.patch.object(
        MlflowClient, "get_latest_versions",
        return_value=[model_version_detailed])
    download_uri = mock.patch.object(
        MlflowClient, "get_model_version_download_uri",
        return_value=artifact_location)

    with latest_versions, download_uri:
        repo_under_test = ModelsArtifactRepository(model_uri)
        assert repo_under_test.artifact_uri == model_uri
        assert isinstance(repo_under_test.repo, DbfsRestArtifactRepository)
        assert repo_under_test.repo.artifact_uri == artifact_location
Example #5
Source File: __init__.py From mlflow with Apache License 2.0 | 6 votes |
def _wait_for(submitted_run_obj):
    """Block until the submitted run finishes and record its final status.

    The run is marked "FINISHED" when ``submitted_run_obj.wait()`` is truthy
    and "FAILED" otherwise. A ``KeyboardInterrupt`` cancels the run, marks it
    "FAILED" and is re-raised.

    Raises:
        ExecutionException: ``submitted_run_obj.wait()`` returned a falsy value.
        KeyboardInterrupt: Re-raised after the run has been cancelled.
    """
    run_id = submitted_run_obj.run_id
    active_run = None
    # Note: an interrupt arriving before the try block below means the run's
    # status is never reported to the tracking server.
    try:
        if run_id is not None:
            active_run = tracking.MlflowClient().get_run(run_id)

        succeeded = submitted_run_obj.wait()
        if not succeeded:
            _maybe_set_run_terminated(active_run, "FAILED")
            raise ExecutionException("Run (ID '%s') failed" % run_id)

        _logger.info("=== Run (ID '%s') succeeded ===", run_id)
        _maybe_set_run_terminated(active_run, "FINISHED")
    except KeyboardInterrupt:
        _logger.error("=== Run (ID '%s') interrupted, cancelling run ===", run_id)
        submitted_run_obj.cancel()
        _maybe_set_run_terminated(active_run, "FAILED")
        raise
Example #6
Source File: test_tracking.py From mlflow with Apache License 2.0 | 6 votes |
def test_parent_create_run():
    """Exercise nested runs and the parent-run-id tag they carry.

    A run is created, its id is exported through ``_RUN_ID_ENV_VAR`` and the
    run is resumed. Starting another run without ``nested=True`` must raise;
    nested child and grandchild runs must each be tagged with their direct
    parent's run id.
    """
    with mlflow.start_run() as parent_run:
        parent_run_id = parent_run.info.run_id

    os.environ[_RUN_ID_ENV_VAR] = parent_run_id
    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        with pytest.raises(Exception, match='To start a nested run'):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

    def verify_has_parent_id_tag(child_id, expected_parent_id):
        # Read the tags back from the tracking store rather than from the run object.
        stored_run = tracking.MlflowClient().get_run(child_id)
        assert stored_run.data.tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id

    child_id = child_run.info.run_id
    verify_has_parent_id_tag(child_id, parent_run.info.run_id)
    verify_has_parent_id_tag(grand_child_run.info.run_id, child_run.info.run_id)
    assert mlflow.active_run() is None
Example #7
Source File: test_client.py From mlflow with Apache License 2.0 | 6 votes |
def test_client_registry_operations_raise_exception_with_unsupported_registry_store():
    """
    Model Registry operations on an ``MlflowClient`` must raise an
    ``MlflowException`` with the ``FEATURE_DISABLED`` error code when the
    registry URI points at a store without Model Registry support
    (e.g., FileStore).
    """
    with TempDir() as tmp:
        client = MlflowClient(registry_uri=tmp.path())
        registry_calls = [
            client._get_registry_client,
            lambda: client.create_registered_model("test"),
            lambda: client.get_registered_model("test"),
            lambda: client.create_model_version("test", "source", "run_id"),
            lambda: client.get_model_version("test", 1),
        ]
        for registry_call in registry_calls:
            with pytest.raises(MlflowException) as exc:
                registry_call()
            assert exc.value.error_code == ErrorCode.Name(FEATURE_DISABLED)
Example #8
Source File: databricks.py From mlflow with Apache License 2.0 | 6 votes |
def _print_description_and_log_tags(self):
    """Log the Databricks job run URL and tag the MLflow run with job details.

    Fetches the job run info for ``self._databricks_run_id``, logs its status
    page URL and sets the run URL, shell job run id and webapp URL tags on
    ``self._mlflow_run_id``. The shell job id tag is set only when the run
    info contains a ``job_id``.
    """
    def _tag(key, value):
        # A fresh client per tag, exactly as each individual call site did.
        tracking.MlflowClient().set_tag(self._mlflow_run_id, key, value)

    _logger.info(
        "=== Launched MLflow run as Databricks job run with ID %s."
        " Getting run status page URL... ===",
        self._databricks_run_id)
    run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
    jobs_page_url = run_info["run_page_url"]
    _logger.info("=== Check the run's status at %s ===", jobs_page_url)
    host_creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)

    _tag(MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
    _tag(MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id)
    _tag(MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)

    job_id = run_info.get('job_id')
    # In some releases of Databricks we do not return the job ID. We start including it in DB
    # releases 2.80 and above.
    if job_id is not None:
        _tag(MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id)
Example #9
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_log_params():
    """``mlflow.log_params`` stores every value as its string form."""
    logged_params = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(logged_params)

    stored_params = tracking.MlflowClient().get_run(run_id).data.params
    # Validate params: the integer 5 is read back as the string "5".
    assert stored_params == {key: str(value) for key, value in logged_params.items()}
Example #10
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_defaults(mock_store):
    """``search_runs`` with only experiment ids forwards the default arguments to the store."""
    client = MlflowClient()
    client.search_runs([1, 2, 3])

    expected_kwargs = dict(
        experiment_ids=[1, 2, 3],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #11
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_create_run_overrides(mock_store):
    """``create_run`` forwards the user from the tags and every tag as a ``RunTag``.

    The same call is made twice, with the store mock reset in between, and
    each call must reach the store exactly once with identical arguments.
    """
    experiment_id = mock.Mock()
    user = mock.Mock()
    start_time = mock.Mock()
    tags = {
        MLFLOW_USER: user,
        MLFLOW_PARENT_RUN_ID: mock.Mock(),
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
        MLFLOW_SOURCE_NAME: mock.Mock(),
        MLFLOW_PROJECT_ENTRY_POINT: mock.Mock(),
        MLFLOW_GIT_COMMIT: mock.Mock(),
        "other-key": "other-value"
    }

    def create_and_verify():
        MlflowClient().create_run(experiment_id, start_time, tags)
        mock_store.create_run.assert_called_once_with(
            experiment_id=experiment_id,
            user_id=user,
            start_time=start_time,
            tags=[RunTag(key, value) for key, value in tags.items()],
        )

    create_and_verify()
    mock_store.reset_mock()
    create_and_verify()
Example #12
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_search_runs_multiple_experiments():
    """Search across three experiments and check how many runs each filter matches.

    Every experiment gets one run that logs the shared metric ``m0`` and a
    metric named after the experiment id.
    """
    experiment_ids = []
    for exp_id in range(1, 4):
        experiment_ids.append(mlflow.create_experiment("exp__{}".format(exp_id)))

    for eid in experiment_ids:
        with mlflow.start_run(experiment_id=eid):
            mlflow.log_metric("m0", 1)
            mlflow.log_metric("m_{}".format(eid), 2)

    expected_match_counts = (
        ("metrics.m0 > 0", 3),
        ("metrics.m_1 > 0", 1),
        ("metrics.m_2 = 2", 1),
        ("metrics.m_3 < 4", 1),
    )
    for filter_string, expected_count in expected_match_counts:
        matches = MlflowClient().search_runs(experiment_ids, filter_string, ViewType.ALL)
        assert len(matches) == expected_count
Example #13
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_get_artifact_uri_appends_to_uri_path_component_correctly(
        artifact_location, expected_uri_format):
    """``mlflow.get_artifact_uri`` appends the artifact path to the experiment's location.

    For each sample path the fluent API result must equal both
    ``tracking.artifact_utils.get_artifact_uri`` and ``expected_uri_format``
    filled in with the run id and the path stripped of leading slashes.
    """
    experiment_name = "get-artifact-uri-test"
    MlflowClient().create_experiment(experiment_name, artifact_location=artifact_location)
    mlflow.set_experiment(experiment_name)

    sample_paths = ("path/to/artifact", "/artifact/path", "arty.txt")
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        for artifact_path in sample_paths:
            artifact_uri = mlflow.get_artifact_uri(artifact_path)
            assert artifact_uri == tracking.artifact_utils.get_artifact_uri(run_id, artifact_path)
            expected_uri = expected_uri_format.format(
                run_id=run_id, path=artifact_path.lstrip("/"))
            assert artifact_uri == expected_uri
Example #14
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_start_run_exp_id_0():
    """An explicit ``experiment_id=0`` overrides the experiment set via ``set_experiment``."""
    experiment_name = "some-experiment"
    mlflow.set_experiment(experiment_name)

    # A run started without arguments lands in the experiment we just set.
    with mlflow.start_run() as active_run:
        exp_id = active_run.info.experiment_id
        assert exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(exp_id).name == experiment_name

    # A run started with experiment ID 0 must honor that ID.
    with mlflow.start_run(experiment_id=0) as active_run:
        assert active_run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID
Example #15
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_start_deleted_run():
    """Resuming a deleted run must fail and leave no active run behind.

    A run is created, deleted through ``MlflowClient.delete_run`` and then
    resumed by id; ``mlflow.start_run`` is expected to raise an
    ``MlflowException`` whose message mentions the deleted state.
    """
    run_id = None
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_id
    tracking.MlflowClient().delete_run(run_id)

    # ``match`` (not ``matches``) is the keyword pytest.raises uses to check
    # the exception message; the misspelled keyword never checked the text.
    with pytest.raises(MlflowException, match='because it is in the deleted state.'):
        with mlflow.start_run(run_id=run_id):
            pass
    assert mlflow.active_run() is None
Example #16
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_string_experiment_id(mock_store):
    """A single string experiment id reaches the store wrapped in a list."""
    client = MlflowClient()
    client.search_runs("abc")

    expected_kwargs = dict(
        experiment_ids=["abc"],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #17
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_log_param():
    """``mlflow.log_param`` stores each value as its string form."""
    params_to_log = (
        ("name_1", "a"),
        ("name_2", "b"),
        ("nested/nested/name", 5),
    )
    with start_run() as active_run:
        run_id = active_run.info.run_id
        for key, value in params_to_log:
            mlflow.log_param(key, value)

    stored_params = tracking.MlflowClient().get_run(run_id).data.params
    # Validate params: the integer 5 is read back as the string "5".
    assert stored_params == {"name_1": "a", "name_2": "b", "nested/nested/name": "5"}
Example #18
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_log_metric_validation():
    """Logging a non-numeric metric value is rejected and nothing is stored."""
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(MlflowException) as e:
            mlflow.log_metric("name_1", "apple")

    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert e.value.error_code == expected_code

    stored_metrics = tracking.MlflowClient().get_run(run_id).data.metrics
    assert len(stored_metrics) == 0
Example #19
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_set_tags():
    """``mlflow.set_tags`` stores user tags as strings next to the automatic tags.

    The stored tag count must equal the user tags plus the three automatic
    tags (user, source name, source type); only user tag values are compared.
    """
    exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    approx_expected_tags = {MLFLOW_USER, MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE}

    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.set_tags(exact_expected_tags)

    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate tags
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_val in finished_run.data.tags.items():
        # Automatic tags are only counted, their values are not checked.
        if tag_key not in approx_expected_tags:
            assert str(exact_expected_tags[tag_key]) == tag_val
Example #20
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes |
def test_log_metrics_uses_common_timestamp_and_step_per_invocation(step_kwarg):
    """One ``log_metrics`` call gives all its metrics the same timestamp and step.

    The expected step is ``step_kwarg``, or 0 when ``step_kwarg`` is ``None``.
    """
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_metrics(expected_metrics, step=step_kwarg)

    finished_run = tracking.MlflowClient().get_run(run_id)

    # Metric keys and values must match what was logged.
    assert len(finished_run.data.metrics) == len(expected_metrics)
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value

    # Every metric object must share the first one's timestamp and the expected step.
    common_timestamp = finished_run.data._metric_objs[0].timestamp
    if step_kwarg is not None:
        expected_step = step_kwarg
    else:
        expected_step = 0
    for metric_obj in finished_run.data._metric_objs:
        assert metric_obj.timestamp == common_timestamp
        assert metric_obj.step == expected_step
Example #21
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_filter(mock_store):
    """The filter string passed to ``search_runs`` is forwarded to the store unchanged."""
    client = MlflowClient()
    client.search_runs(["a", "b", "c"], "my filter")

    expected_kwargs = dict(
        experiment_ids=["a", "b", "c"],
        filter_string="my filter",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #22
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_view_type(mock_store):
    """The view type passed to ``search_runs`` is forwarded to the store unchanged."""
    client = MlflowClient()
    client.search_runs(["a", "b", "c"], "my filter", ViewType.DELETED_ONLY)

    expected_kwargs = dict(
        experiment_ids=["a", "b", "c"],
        filter_string="my filter",
        run_view_type=ViewType.DELETED_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #23
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_max_results(mock_store):
    """The ``max_results`` value passed to ``search_runs`` is forwarded to the store unchanged."""
    client = MlflowClient()
    client.search_runs([5], "my filter", ViewType.ALL, 2876)

    expected_kwargs = dict(
        experiment_ids=[5],
        filter_string="my filter",
        run_view_type=ViewType.ALL,
        max_results=2876,
        order_by=None,
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #24
Source File: loggers.py From OpenKiwi with GNU Affero General Public License v3.0 | 5 votes |
def experiment_name(self):
    """Return the experiment name stored on this instance.

    The value is read from ``self._experiment_name``; no MLflow lookup is
    performed.
    """
    return self._experiment_name
Example #25
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_order_by(mock_store):
    """The ``order_by`` list passed to ``search_runs`` is forwarded to the store unchanged."""
    client = MlflowClient()
    client.search_runs([5], order_by=["a", "b"])

    expected_kwargs = dict(
        experiment_ids=[5],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=["a", "b"],
        page_token=None,
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #26
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_client_search_runs_page_token(mock_store):
    """The ``page_token`` passed to ``search_runs`` is forwarded to the store unchanged."""
    client = MlflowClient()
    client.search_runs([5], page_token="blah")

    expected_kwargs = dict(
        experiment_ids=[5],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token="blah",
    )
    mock_store.search_runs.assert_called_once_with(**expected_kwargs)
Example #27
Source File: test_client.py From mlflow with Apache License 2.0 | 5 votes |
def test_update_registered_model(mock_registry_store):
    """
    Update registered model no longer supports name change.

    The client must return what the store's ``update_registered_model``
    returns and must not call ``rename_registered_model``.
    """
    rename_result = "some expected return value."
    update_result = "other expected return value."
    mock_registry_store.rename_registered_model.return_value = rename_result
    mock_registry_store.update_registered_model.return_value = update_result

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    res = client.update_registered_model(
        name="orig name",
        description="new description")

    assert update_result == res
    mock_registry_store.update_registered_model.assert_called_once_with(
        name="orig name",
        description="new description")
    mock_registry_store.rename_registered_model.assert_not_called()
Example #28
Source File: test_models_artifact_repo.py From mlflow with Apache License 2.0 | 5 votes |
def test_models_artifact_repo_init_with_version_uri(
        host_creds_mock):  # pylint: disable=unused-argument
    """A ``models:/<name>/<version>`` URI is resolved through the patched client.

    With ``get_model_version_download_uri`` patched on ``MlflowClient``, the
    repository must keep the model URI as its own ``artifact_uri`` and wrap a
    ``DbfsRestArtifactRepository`` that points at the patched location.
    """
    model_uri = "models:/MyModel/12"
    artifact_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"

    download_uri = mock.patch.object(
        MlflowClient, "get_model_version_download_uri",
        return_value=artifact_location)

    with download_uri:
        repo_under_test = ModelsArtifactRepository(model_uri)
        assert repo_under_test.artifact_uri == model_uri
        assert isinstance(repo_under_test.repo, DbfsRestArtifactRepository)
        assert repo_under_test.repo.artifact_uri == artifact_location
Example #29
Source File: test_models_artifact_repo.py From mlflow with Apache License 2.0 | 5 votes |
def test_models_artifact_repo_uses_repo_download_artifacts():
    """
    ``ModelsArtifactRepository`` should delegate `download_artifacts` to its
    ``self.repo.download_artifacts`` function.
    """
    model_uri = "models:/MyModel/12"
    artifact_location = "s3://blah_bucket/"

    download_uri = mock.patch.object(
        MlflowClient, "get_model_version_download_uri",
        return_value=artifact_location)

    with download_uri:
        repo_under_test = ModelsArtifactRepository(model_uri)
        # Swap the wrapped repository for a mock so the delegation can be observed.
        repo_under_test.repo = Mock()
        repo_under_test.download_artifacts('artifact_path', 'dst_path')
        repo_under_test.repo.download_artifacts.assert_called_once()
Example #30
Source File: test_mlflow_logger.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_integration(dirname):
    """Log one metric per epoch through ``MLflowLogger`` and read it back.

    A handler attached to ``Events.EPOCH_COMPLETED`` logs ``test_value`` for
    each of the five epochs. After the run, the metric history fetched with
    ``MlflowClient`` must hold one entry per logged value, and each entry
    must approximately equal the value that was logged.
    """
    n_epochs = 5
    data = list(range(50))
    tracking_uri = os.path.join(dirname, "mlruns")

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)
    mlflow_logger = MLflowLogger(tracking_uri=tracking_uri)

    true_values = []

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        v = global_step * 0.1
        true_values.append(v)
        logger.log_metrics({"test_value": v}, step=global_step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow

    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri)
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # zip() stops at the shorter input, so without this check an empty or
    # partial metric history would let the loop below pass vacuously.
    assert len(stored_values) == len(true_values)
    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value