Python mlflow.tracking Examples
The following are 16 code examples of the mlflow.tracking module, collected from open-source projects. The original project and source file for each example are listed above it. You may also want to check out the other available functions and classes of the mlflow module.
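Before the project examples, here is a minimal, self-contained sketch of the module's central class, mlflow.tracking.MlflowClient. The tracking URI, experiment name, and metric values are placeholders, not taken from any of the projects below.

from mlflow.tracking import MlflowClient

# Create an experiment and a run, log one metric, then mark the run finished.
# "file:./mlruns" is a placeholder local tracking URI.
client = MlflowClient(tracking_uri="file:./mlruns")
experiment_id = client.create_experiment("demo-experiment")
run = client.create_run(experiment_id)
client.log_metric(run.info.run_id, "accuracy", 0.9)
client.set_terminated(run.info.run_id)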
Example #1
Source File: mlflow.py From gordo with GNU Affero General Public License v3.0
def get_workspace_kwargs() -> dict:
    """Get AzureML keyword arguments from environment

    The name of this environment variable is set in the Argo workflow
    template, and its value should be in the format:
    `<subscription_id>:<resource_group>:<workspace_name>`.

    Returns
    -------
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    """
    return get_kwargs_from_secret(
        "AZUREML_WORKSPACE_STR",
        ["subscription_id", "resource_group", "workspace_name"],
    )
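get_kwargs_from_secret is a gordo helper not shown here. As a rough, hypothetical sketch of the parsing its docstring describes, the environment variable could be split like this (the function name below is invented for illustration):

import os

# Hypothetical re-implementation of the behavior described above: split
# "<subscription_id>:<resource_group>:<workspace_name>" into a dict.
def parse_workspace_str() -> dict:
    values = os.environ["AZUREML_WORKSPACE_STR"].split(":")
    keys = ["subscription_id", "resource_group", "workspace_name"]
    return dict(zip(keys, values))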
Example #2
Source File: databricks_artifact_repo.py From mlflow with Apache License 2.0
def __init__(self, artifact_uri):
    super(DatabricksArtifactRepository, self).__init__(artifact_uri)
    if not artifact_uri.startswith('dbfs:/'):
        raise MlflowException(
            message='DatabricksArtifactRepository URI must start with dbfs:/',
            error_code=INVALID_PARAMETER_VALUE)
    if not is_databricks_acled_artifacts_uri(artifact_uri):
        raise MlflowException(
            message=('Artifact URI incorrect. Expected path prefix to be'
                     ' databricks/mlflow-tracking/path/to/artifact/..'),
            error_code=INVALID_PARAMETER_VALUE)
    self.run_id = self._extract_run_id(self.artifact_uri)

    # Fetch the artifact root for the MLflow Run associated with `artifact_uri` and compute
    # the path of `artifact_uri` relative to the MLflow Run's artifact root
    # (the `run_relative_artifact_repo_root_path`). All operations performed on this artifact
    # repository will be performed relative to this computed location.
    artifact_repo_root_path = extract_and_normalize_path(artifact_uri)
    run_artifact_root_uri = self._get_run_artifact_root(self.run_id)
    run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri)
    run_relative_root_path = posixpath.relpath(
        path=artifact_repo_root_path, start=run_artifact_root_path
    )
    # If the paths are equal, then use empty string over "./" for ListArtifact compatibility.
    self.run_relative_artifact_repo_root_path = \
        "" if run_artifact_root_path == artifact_repo_root_path else run_relative_root_path
Example #3
Source File: loggers.py From OpenKiwi with GNU Affero General Public License v3.0
def run_uuid(self):
    return mlflow.tracking.fluent.active_run().info.run_uuid
Example #4
Source File: loggers.py From OpenKiwi with GNU Affero General Public License v3.0
def experiment_id(self):
    return mlflow.tracking.fluent.active_run().info.experiment_id
Example #5
Source File: loggers.py From OpenKiwi with GNU Affero General Public License v3.0
def _is_remote(self):
    return not mlflow.tracking.utils._is_local_uri(
        mlflow.get_tracking_uri()
    )
Example #6
Source File: mlflow.py From gordo with GNU Affero General Public License v3.0
def get_run_id(client: MlflowClient, experiment_name: str, model_key: str) -> str:
    """
    Get an existing or create a new run for the given model_key and experiment_name.

    The model key corresponds to a unique configuration of the model. The
    corresponding run must be manually stopped using the
    `mlflow.tracking.MlflowClient.set_terminated` method.

    Parameters
    ----------
    client: mlflow.tracking.MlflowClient
        Client with tracking uri set to AzureML if configured.
    experiment_name: str
        Name of experiment to log to.
    model_key: str
        Unique ID of model configuration.

    Returns
    -------
    run_id: str
        Unique ID of MLflow run to log to.
    """
    experiment = client.get_experiment_by_name(experiment_name)
    experiment_id = (
        getattr(experiment, "experiment_id")
        if experiment
        else client.create_experiment(experiment_name)
    )
    return client.create_run(experiment_id, tags={"model_key": model_key}).info.run_id
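On its own, the function might be used as in the following sketch; the tracking URI, experiment name, and model key are placeholder values:

from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="file:./mlruns")  # placeholder URI
run_id = get_run_id(client, experiment_name="my-experiment", model_key="abc123")
# ... log params and metrics against run_id ...
client.set_terminated(run_id)  # as the docstring notes, runs must be stopped manually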
Example #7
Source File: mlflow.py From gordo with GNU Affero General Public License v3.0
@contextmanager
def mlflow_context(
    name: str,
    model_key: str = uuid4().hex,
    workspace_kwargs: dict = {},
    service_principal_kwargs: dict = {},
):
    """
    Generate MLflow logger function with either a local or AzureML backend

    Parameters
    ----------
    name: str
        The name of the log group to log to (e.g. a model name).
    model_key: str
        Unique ID of logging run.
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    service_principal_kwargs: dict
        AzureML ServicePrincipalAuthentication keyword arguments. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`

    Example
    -------
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...     mlflow.set_tracking_uri(f"file:{tmp_dir}")
    ...     with mlflow_context("log_group", "unique_key", {}, {}) as (mlflow_client, run_id):
    ...         log_machine(machine)  # doctest: +SKIP
    """
    mlflow_client = get_mlflow_client(workspace_kwargs, service_principal_kwargs)
    run_id = get_run_id(mlflow_client, experiment_name=name, model_key=model_key)

    logger.info(
        f"MLflow client configured to use "
        f"{'AzureML' if workspace_kwargs else 'local backend'}"
    )

    yield mlflow_client, run_id

    mlflow_client.set_terminated(run_id)
Example #8
Source File: runs.py From mlflow with Apache License 2.0
def commands():
    """
    Manage runs. To manage runs of experiments associated with a tracking server, set the
    MLFLOW_TRACKING_URI environment variable to the URL of the desired server.
    """
    pass
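As the docstring says, this CLI group resolves its backend from the MLFLOW_TRACKING_URI environment variable. A minimal sketch, with a placeholder server URL:

import os
import mlflow

# Point MLflow (and the `mlflow runs` CLI, when launched from this environment)
# at a tracking server; the URL is a placeholder.
os.environ["MLFLOW_TRACKING_URI"] = "http://localhost:5000"
# Equivalent in-process configuration:
mlflow.set_tracking_uri("http://localhost:5000")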
Example #9
Source File: runs.py From mlflow with Apache License 2.0
def list_run(experiment_id, view):
    """
    List all runs of the specified experiment in the configured tracking server.
    """
    store = _get_store()
    view_type = ViewType.from_string(view) if view else ViewType.ACTIVE_ONLY
    runs = store.search_runs([experiment_id], None, view_type)
    table = []
    for run in runs:
        tags = {k: v for k, v in run.data.tags.items()}
        run_name = tags.get(MLFLOW_RUN_NAME, "")
        table.append([conv_longdate_to_str(run.info.start_time),
                      run_name,
                      run.info.run_id])
    print(tabulate(sorted(table, reverse=True), headers=["Date", "Name", "ID"]))
Example #10
Source File: databricks_artifact_repo.py From mlflow with Apache License 2.0
def _extract_run_id(artifact_uri):
    """
    The artifact_uri is expected to be
    dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

    Once the path from the input uri is extracted and normalized, it is
    expected to be of the form
    databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

    Hence the run_id is the 4th element of the normalized path.

    :return: run_id extracted from the artifact_uri
    """
    artifact_path = extract_and_normalize_path(artifact_uri)
    return artifact_path.split('/')[3]
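Following the docstring, a hypothetical normalized path shows which segment is taken as the run_id:

# Hypothetical path of the form databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
artifact_path = "databricks/mlflow-tracking/3/0123456789abcdef/artifacts/model"
run_id = artifact_path.split('/')[3]  # -> "0123456789abcdef"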
Example #11
Source File: databricks_artifact_repo.py From mlflow with Apache License 2.0
def _call_endpoint(self, service, api, json_body):
    db_profile = get_db_profile_from_uri(mlflow.tracking.get_tracking_uri())
    db_creds = get_databricks_host_creds(db_profile)
    endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
    response_proto = api.Response()
    return call_endpoint(db_creds, endpoint, method, json_body, response_proto)
Example #12
Source File: test_databricks.py From mlflow with Apache License 2.0
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks, set_tag_mock):  # pylint: disable=unused-argument
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
            mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
            as db_api_req_mock:
        # Test bad tracking URI
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        mlflow_service = mlflow.tracking.MlflowClient()
        assert (len(mlflow_service.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID))
                == 0)
        tracking_uri_mock.return_value = "http://"
        # Test misspecified parameters
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec
        with pytest.raises(ExecutionException):
            mlflow.projects.run(TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                                backend_config=None)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
Example #13
Source File: test_databricks.py From mlflow with Apache License 2.0
def test_get_tracking_uri_for_run():
    mlflow.set_tracking_uri("http://some-uri")
    assert databricks._get_tracking_uri_for_run() == "http://some-uri"
    mlflow.set_tracking_uri("databricks://profile")
    assert databricks._get_tracking_uri_for_run() == "databricks"
    mlflow.set_tracking_uri(None)
    with mock.patch.dict(os.environ, {mlflow.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}):
        assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri"
Example #14
Source File: test_mlflow_logger.py From ignite with BSD 3-Clause "New" or "Revised" License
def test_integration(dirname):
    n_epochs = 5
    data = list(range(50))
    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)
    mlflow_logger = MLflowLogger(tracking_uri=os.path.join(dirname, "mlruns"))
    true_values = []

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        v = global_step * 0.1
        true_values.append(v)
        logger.log_metrics({"{}".format("test_value"): v}, step=global_step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow
    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value
Example #15
Source File: test_mlflow_logger.py From ignite with BSD 3-Clause "New" or "Revised" License
def test_integration_as_context_manager(dirname):
    n_epochs = 5
    data = list(range(50))
    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    true_values = []

    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:
        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            v = global_step * 0.1
            true_values.append(v)
            logger.log_metrics({"{}".format("test_value"): v}, step=global_step)

        mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        import mlflow
        active_run = mlflow.active_run()

        trainer.run(data, max_epochs=n_epochs)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value
Example #16
Source File: test_mlflow_logger.py From ignite with BSD 3-Clause "New" or "Revised" License
def test_mlflow_bad_metric_name_handling(dirname):
    import mlflow

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:
        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0})

        with pytest.warns(UserWarning,
                          match=r"MLflowLogger output_handler encountered an invalid metric name"):
            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

        for i, v in enumerate(true_values):
            engine.state.epoch += 1
            engine.state.metrics["metric 0"] = v
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

    for t, s in zip([1000.0] + true_values, stored_values):
        assert t == s.value