Avi A
04/06/2021, 11:21 AM
import os

import prefect
from prefect import task
from mlflow.tracking import MlflowClient

# DEFAULT_MLFLOW_TRACKING_URI, DEFAULT_BUCKET_PATH, ARTIFACTS_PATH and
# get_or_create_mlflow_experiment_id are project-specific and defined elsewhere
@task
def init_mlflow_run(experiment_id=None):
    client = MlflowClient(DEFAULT_MLFLOW_TRACKING_URI)
    if experiment_id is None:
        flow_name = prefect.context.get('flow_name')
        artifact_location = os.path.join(DEFAULT_BUCKET_PATH, "mlflow-data", flow_name)
        experiment_id = get_or_create_mlflow_experiment_id(flow_name, artifact_location)
    run = client.create_run(
        experiment_id=experiment_id,
        tags=dict(prefect_run_id=prefect.context.get('flow_run_id'))
    )
    run_id = run.info.run_id
    for k, v in prefect.context.get('parameters', {}).items():
        client.log_param(run_id, k, v)
    return run_id
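get_or_create_mlflow_experiment_id isn't shown here; a minimal sketch of what such a helper could look like, reusing the MlflowClient import and constants above and the standard get_experiment_by_name / create_experiment client calls (the actual helper may differ):

def get_or_create_mlflow_experiment_id(name, artifact_location):
    # look the experiment up by name and create it if it doesn't exist yet
    client = MlflowClient(DEFAULT_MLFLOW_TRACKING_URI)
    experiment = client.get_experiment_by_name(name)
    if experiment is not None:
        return experiment.experiment_id
    # create_experiment returns the id of the newly created experiment
    return client.create_experiment(name, artifact_location=artifact_location)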
And then whenever I want to log something, I call the following task, which also accepts the run_id that the init task generated. It’s a bit customized to my use case, but I think you can tweak it to serve you better (and perhaps generalize it and contribute it back to Prefect).
# in addition to the imports above
import json
import pickle
import shutil
from typing import Optional

from prefect.engine.signals import PrefectStateSignal
from prefect.utilities.logging import get_logger

class IterablesAsListsJSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, frozenset):
            return list(o)
        return super().default(o)

# always_run: this task executes even when upstream tasks fail, so the MLflow
# run still gets terminated; failed upstream inputs can arrive as state
# signals, which the isinstance checks below filter out
@task(trigger=prefect.triggers.always_run)
def mlflow_logging(
        run_id: Optional[str] = None,
        params: Optional[dict] = None,
        metrics: Optional[dict] = None,
        artifacts: Optional[dict] = None,
        terminate=True
):
    if run_id is None:
        get_logger().warning("Got empty run_id. Skipping MLFlow logging")
        return None
    client = MlflowClient(DEFAULT_MLFLOW_TRACKING_URI)
    if params is not None and not isinstance(params, PrefectStateSignal):
        for k, v in params.items():
            client.log_param(run_id, k, v)
    if metrics is not None and not isinstance(metrics, PrefectStateSignal):
        for k, v in metrics.items():
            client.log_metric(run_id, k, v)
    if artifacts is not None and not isinstance(artifacts, PrefectStateSignal):
        shutil.rmtree(ARTIFACTS_PATH, ignore_errors=True)
        os.makedirs(ARTIFACTS_PATH)  # TODO: figure out how to upload stuff without actually saving to disk
        for filename, data in artifacts.items():
            try:
                with open(os.path.join(ARTIFACTS_PATH, f"{filename}.json"), "w") as json_f:
                    json.dump(data, json_f, cls=IterablesAsListsJSONEncoder)
            except TypeError as e:
                get_logger().warning("Failed to save data as JSON. Saving as pickle")
                with open(os.path.join(ARTIFACTS_PATH, filename), "wb") as f:
                    pickle.dump(data, f)
        client.log_artifacts(run_id, ARTIFACTS_PATH)
    if terminate:
        client.set_terminated(run_id)
    return run_id
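A minimal sketch of how the two tasks might be wired together in a Prefect 0.x flow (train_model and its metric dict are hypothetical placeholders; only init_mlflow_run and mlflow_logging come from the snippets above):

from prefect import Flow, task

@task
def train_model():
    # stand-in for real training logic
    return {"accuracy": 0.93}

with Flow("mlflow-logging-example") as flow:
    run_id = init_mlflow_run()
    metrics = train_model()
    # thanks to the always_run trigger this still runs, and terminates the
    # MLflow run, even if train_model fails
    mlflow_logging(run_id=run_id, metrics=metrics)

Keeping run creation in its own task means downstream tasks only ever need the run_id string, which Prefect passes around like any other task result.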
Kevin Kho
…show-us-what-you-got channel about your work.