haf  04/05/2021, 5:11 PM
Kevin Kho  04/05/2021
Avi A  04/06/2021, 10:23 AM
JupyterTask )
Avi A  04/06/2021, 11:21 AM
```python
import os

import prefect
from mlflow.tracking import MlflowClient
from prefect import task


@task
def init_mlflow_run(experiment_id=None):
    client = MlflowClient(DEFAULT_MLFLOW_TRACKING_URI)
    if experiment_id is None:
        flow_name = prefect.context.get('flow_name')
        artifact_location = os.path.join(DEFAULT_BUCKET_PATH, "mlflow-data", flow_name)
        experiment_id = get_or_create_mlflow_experiment_id(flow_name, artifact_location)
    # Tag the MLflow run with the Prefect flow run id so the two can be cross-referenced
    run = client.create_run(
        experiment_id=experiment_id,
        tags=dict(prefect_run_id=prefect.context.get('flow_run_id'))
    )
    run_id = run.info.run_id
    # Record all Prefect flow parameters as MLflow params on the new run
    for k, v in prefect.context.get('parameters', {}).items():
        client.log_param(run_id, k, v)
    return run_id
```
And then whenever I want to log something, I call the following task, which also accepts the run_id that the init task generated. It’s a bit customized to my use case, but I think you can tweak it to serve you better (and perhaps generalize it and contribute it back to P)
```python
import json
import os
import pickle
import shutil
from typing import Optional

import prefect
from mlflow.tracking import MlflowClient
from prefect import task
from prefect.engine.signals import PrefectStateSignal
from prefect.utilities.logging import get_logger


class IterablesAsListsJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes frozensets as lists."""

    def default(self, o):
        if isinstance(o, frozenset):
            return list(o)
        return super().default(o)


@task(trigger=prefect.triggers.always_run)
def mlflow_logging(
    run_id: Optional[str] = None,
    params: Optional[dict] = None,
    metrics: Optional[dict] = None,
    artifacts: Optional[dict] = None,
    terminate=True,
):
    if run_id is None:
        get_logger().warning("Got empty run_id. Skipping MLflow logging")
        return None
    client = MlflowClient(DEFAULT_MLFLOW_TRACKING_URI)
    # With the always_run trigger, failed upstream tasks arrive as
    # PrefectStateSignal values, so skip anything that isn't real data
    if params is not None and not isinstance(params, PrefectStateSignal):
        for k, v in params.items():
            client.log_param(run_id, k, v)
    if metrics is not None and not isinstance(metrics, PrefectStateSignal):
        for k, v in metrics.items():
            client.log_metric(run_id, k, v)
    if artifacts is not None and not isinstance(artifacts, PrefectStateSignal):
        shutil.rmtree(ARTIFACTS_PATH, ignore_errors=True)
        os.makedirs(ARTIFACTS_PATH)  # TODO: figure out how to upload stuff without actually saving to disk
        for filename, data in artifacts.items():
            try:
                with open(os.path.join(ARTIFACTS_PATH, f"{filename}.json"), "w") as json_f:
                    json.dump(data, json_f, cls=IterablesAsListsJSONEncoder)
            except TypeError:
                get_logger().warning("Failed to save data as JSON. Saving as pickle")
                with open(os.path.join(ARTIFACTS_PATH, filename), "wb") as f:
                    pickle.dump(data, f)
        client.log_artifacts(run_id, ARTIFACTS_PATH)
    if terminate:
        client.set_terminated(run_id)
    return run_id
```
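The JSON-then-pickle fallback inside mlflow_logging can be tried on its own without a tracking server. A minimal stdlib sketch of that serialization pattern (dump_artifacts is a hypothetical helper name; the MLflow client calls and ARTIFACTS_PATH constant are left out):

```python
import json
import os
import pickle
import tempfile


class IterablesAsListsJSONEncoder(json.JSONEncoder):
    # Same encoder as above: frozensets become JSON lists
    def default(self, o):
        if isinstance(o, frozenset):
            return list(o)
        return super().default(o)


def dump_artifacts(artifacts: dict, artifacts_path: str) -> list:
    """Write each artifact as JSON if possible, falling back to pickle.

    Returns the filenames written (hypothetical helper mirroring the
    artifact loop in mlflow_logging above).
    """
    os.makedirs(artifacts_path, exist_ok=True)
    written = []
    for filename, data in artifacts.items():
        try:
            with open(os.path.join(artifacts_path, f"{filename}.json"), "w") as json_f:
                json.dump(data, json_f, cls=IterablesAsListsJSONEncoder)
            written.append(f"{filename}.json")
        except TypeError:
            # Not JSON-serializable (e.g. an arbitrary object): pickle instead
            with open(os.path.join(artifacts_path, filename), "wb") as f:
                pickle.dump(data, f)
            written.append(filename)
    return written


with tempfile.TemporaryDirectory() as tmp:
    files = dump_artifacts({"tags": frozenset({"a", "b"}), "blob": object()}, tmp)
    # The frozenset went to JSON; the raw object fell back to pickle
    print(sorted(files))  # ['blob', 'tags.json']
```

One caveat with this pattern (present in the original too): if json.dump raises after the file is opened, a partial `.json` file is left behind alongside the pickle.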
Kevin Kho
…show-us-what-you-got channel about your work.