oleander-whirlwind (10/06/2025, 9:31 AM)
Marvin (10/06/2025, 9:34 AM)
- By default, Prefect keeps task results in memory for the duration of the flow run (cache_result_in_memory=True), so downstream tasks can use the original Python object without serialization.
- If you use a process-based runner or need to persist/restore results, then objects must be serialized with the configured serializer (pickle-based by default). Large or GPU-resident models can be slow to serialize or fail entirely.
- Prefect executes all tasks for a flow run in the same Python process by default unless you opt into a different task runner (e.g., ProcessPoolTaskRunner). Subflows are separate flow runs and cannot share in-memory objects.
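To make those defaults explicit on the task that returns the model, a minimal sketch (both cache_result_in_memory and persist_result are standard @task keyword arguments):
from prefect import task

# Keep the live object in process memory only; never write it to result storage.
# cache_result_in_memory=True is already the default; it is spelled out here for emphasis.
@task(cache_result_in_memory=True, persist_result=False)
def load_model():
    model = ...  # build/load your PyTorch model here
    return model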
What this means for your PyTorch model
- Works well (no serialization):
- Same flow run, same process, in-memory passing (default or ThreadPoolTaskRunner).
- Keep persist_result=False for the task that returns the model.
- Risky/expensive (requires serialization):
- ProcessPoolTaskRunner (child processes) or any cross-process boundary.
- Persisting task results (persist_result=True) or resuming from persisted results.
- Passing models into subflows (separate flow runs).
Recommended patterns
1) Easiest: keep it all in one task
- Load the model and run inference in the same task to avoid passing the model between tasks.
from prefect import flow, task

@task
def predict_batch(batch):
    import torch
    model = ...  # build/load the model once here
    model.eval()
    with torch.no_grad():
        return model(batch).cpu().numpy()

@flow
def main_flow(batches):
    return [predict_batch.submit(b) for b in batches]
2) If you want separate “load” and “predict” tasks, stay in one process
- Use thread-based execution and avoid persisting the model result.
from prefect import flow, task
from prefect.task_runners import ThreadPoolTaskRunner

@task(persist_result=False)
def load_model():
    import torch
    model = ...
    model.eval()
    return model  # stays in memory only

@task
def predict(model, batch):
    import torch
    with torch.no_grad():
        return model(batch).cpu().numpy()

@flow(task_runner=ThreadPoolTaskRunner())
def main_flow(batch):
    model = load_model()
    preds = predict(model, batch)
    return preds
Notes:
- ThreadPoolTaskRunner keeps tasks in the same process, so the model isn’t serialized.
- Avoid ProcessPoolTaskRunner for this pattern.
3) Singleton/lazy-init inside tasks (no passing at all)
- Initialize the model once per process and reuse it.
from functools import lru_cache
from prefect import flow, task

@lru_cache(maxsize=1)
def get_model():
    import torch
    m = ...
    m.eval()
    return m

@task
def predict(batch):
    import torch
    model = get_model()
    with torch.no_grad():
        return model(batch).cpu().numpy()

@flow
def main_flow(batches):
    return [predict.submit(b) for b in batches]
4) If you must cross processes/machines or want durability
- Save the model externally and pass a lightweight reference (path/URI). Load it where you need it (e.g., torch.load with map_location='cpu' or as appropriate).
- This avoids passing giant Python objects around; a minimal state_dict-based sketch follows.
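A rough sketch of that pattern using a state_dict checkpoint on local disk (build_model is a stand-in for however you construct your architecture, and the /tmp path is illustrative; use run-scoped or shared storage in practice):
import torch
from prefect import flow, task

@task(persist_result=False)
def save_checkpoint() -> str:
    model = build_model()                 # stand-in for your model constructor
    path = "/tmp/model_state.pt"          # illustrative path
    torch.save(model.state_dict(), path)  # weights only, no pickled class definition
    return path

@task
def predict_from_checkpoint(path: str, batch):
    model = build_model()                 # rebuild the architecture, then load the weights
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        return model(batch).cpu().numpy()

@flow
def checkpoint_flow(batch):
    return predict_from_checkpoint(save_checkpoint(), batch)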
Restrictions and gotchas
- Process boundaries: Returning a model from a task running in a separate process requires pickling; this can fail for GPU tensors or be very slow/large.
- Persistence: If you set persist_result=True (or end up persisting for retries/resume), Prefect will serialize your model. Consider leaving it False for non-serializable objects.
oleander-whirlwind (10/06/2025, 9:40 AM)
Marvin (10/06/2025, 9:45 AM)
Example A: Same machine, separate processes (ProcessPoolTaskRunner + a local file path)
from pathlib import Path
from prefect import flow, task
from prefect.context import get_run_context
from prefect.task_runners import ProcessPoolTaskRunner

# Simple torch model for demo
def build_model():
    import torch
    import torch.nn as nn
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    model.eval()
    return model

@task
def load_model_to_disk() -> str:
    import torch
    ctx = get_run_context()
    run_id = ctx.task_run.flow_run_id  # inside a task this is a TaskRunContext; use the parent flow run id
    # Use a run-scoped path so concurrent runs don't collide
    out_dir = Path("/tmp") / f"prefect-model-{run_id}"
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "model.pt"
    model = build_model()
    # Option 1 (simple, requires this code to be importable where you load):
    torch.save(model, path)  # pickles the full model object
    # Option 2 (portable): torch.jit.trace/script and save a TorchScript file
    # example:
    # example_input = torch.randn(1, 4)
    # ts_model = torch.jit.trace(model, example_input)
    # ts_path = out_dir / "model.ts"
    # ts_model.save(str(ts_path))
    # return str(ts_path)
    return str(path)

@task
def predict_from_disk(model_path: str, batch: list[float]) -> float:
    import torch
    import numpy as np
    # If you saved with torch.save(model, ...):
    # (newer PyTorch releases may require weights_only=False to unpickle a full model)
    model = torch.load(model_path, map_location="cpu")
    # If you saved a TorchScript file instead:
    # model = torch.jit.load(model_path, map_location="cpu")
    x = torch.tensor(np.array(batch, dtype=np.float32)).view(1, -1)
    with torch.no_grad():
        y = model(x).item()
    return y

@flow(task_runner=ProcessPoolTaskRunner())
def multiprocess_flow(batch: list[float]) -> float:
    model_path = load_model_to_disk()
    return predict_from_disk(model_path, batch)

if __name__ == "__main__":
    print(multiprocess_flow([1.0, 2.0, 3.0, 4.0]))
Notes:
- ProcessPoolTaskRunner runs tasks in separate processes on the same machine. Passing the raw model object would require pickling, which is slow and fragile for large models; passing a file path is robust.
- TorchScript (commented option) makes the artifact more portable and avoids Python pickling entirely.
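If you go the TorchScript route, a minimal save/load round trip looks roughly like this (using the demo build_model from the example above; the /tmp path is illustrative):
import torch

model = build_model()
example_input = torch.randn(1, 4)

# Trace to TorchScript and save a self-contained artifact (no Python pickling of the class).
ts_model = torch.jit.trace(model, example_input)
ts_model.save("/tmp/model.ts")

# Later, on any machine with PyTorch installed, no model class definition needed:
loaded = torch.jit.load("/tmp/model.ts", map_location="cpu")
with torch.no_grad():
    print(loaded(example_input))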
Example B: Two different machines (DaskTaskRunner + shared storage like S3)
- The “load” task runs on one worker/machine, writes a model artifact to S3 (or any shared storage), and returns an s3://... URI.
- The “predict” task runs on another worker/machine, downloads and evaluates it.
- You’ll need a Dask cluster with multiple workers on different machines and prefect[dask] + boto3 installed.
# pip/uv deps you’ll need:
# uv run --with 'prefect[dask]' --with boto3 python your_script.py
import os
from pathlib import Path
import uuid
from prefect import flow, task
from prefect_dask.task_runners import DaskTaskRunner

# Configure your S3 destination via env or hardcode for demo
S3_BUCKET = os.getenv("S3_BUCKET", "my-bucket")
S3_PREFIX = os.getenv("S3_PREFIX", "artifacts/models")

def build_model():
    import torch
    import torch.nn as nn
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    model.eval()
    return model

@task
def load_model_to_s3() -> str:
    import torch
    import boto3
    # Create a unique key per run
    key = f"{S3_PREFIX}/model-{uuid.uuid4().hex}.pt"
    # Save locally first
    local_dir = Path("/tmp/prefect-model")
    local_dir.mkdir(parents=True, exist_ok=True)
    local_path = local_dir / Path(key).name
    model = build_model()
    torch.save(model, local_path)  # or save TorchScript for portability
    # Upload to shared storage and return a lightweight URI
    s3 = boto3.client("s3")
    s3.upload_file(str(local_path), S3_BUCKET, key)
    return f"s3://{S3_BUCKET}/{key}"
oleander-whirlwind (10/06/2025, 9:49 AM)
Marvin (10/06/2025, 9:58 AM)
Prerequisites:
- Point the DaskTaskRunner at a reachable Dask scheduler, e.g. address="tcp://<scheduler-host>:8786".
- Install extras: prefect[dask] and boto3.
- Set S3_BUCKET and (optionally) S3_PREFIX env vars for the S3 example.
# deps (examples):
# uv run --with 'prefect[dask]' --with boto3 python this_file.py
import os
import uuid
from pathlib import Path
from typing import List
from prefect import flow, task
from prefect.task_runners import ThreadPoolTaskRunner  # ProcessPoolTaskRunner is also available if needed

# -----------------------
# Shared utility: model
# -----------------------
def build_model():
    import torch
    import torch.nn as nn
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    model.eval()
    return model

# -----------------------
# 1) Local: in-memory passing within a single process
# -----------------------
@task(persist_result=False)
def load_model_in_memory():
    # Returns the live Python object (no persistence / no pickling)
    return build_model()

@task
def predict_in_memory(model, batch: List[float]) -> float:
    import torch
    import numpy as np
    x = torch.tensor(np.array(batch, dtype=np.float32)).view(1, -1)
    with torch.no_grad():
        y = model(x).item()
    return y

@flow(name="local_in_memory_flow", task_runner=ThreadPoolTaskRunner())
def local_in_memory_flow(batch: List[float]) -> float:
    # ThreadPoolTaskRunner keeps tasks in the same process
    model = load_model_in_memory()
    return predict_in_memory(model, batch)

# -----------------------
# 2) Distributed: two machines via Dask + shared storage (S3)
# -----------------------
S3_BUCKET = os.getenv("S3_BUCKET", "my-bucket")
S3_PREFIX = os.getenv("S3_PREFIX", "artifacts/models")

@task
def load_model_to_s3() -> str:
    """
    Build the model on one worker/machine, save it to S3, and return an s3:// URI.
    """
    import torch
    import boto3
    key = f"{S3_PREFIX}/model-{uuid.uuid4().hex}.pt"
    tmp_dir = Path("/tmp/prefect-model")
    tmp_dir.mkdir(parents=True, exist_ok=True)
    local_path = tmp_dir / Path(key).name
    model = build_model()
    # Optionally use TorchScript for portability:
    # example_input = torch.randn(1, 4)
    # ts_model = torch.jit.trace(model, example_input)
    # ts_model.save(str(local_path))
    # Otherwise, pickle the full object:
    torch.save(model, local_path)
    s3 = boto3.client("s3")
    s3.upload_file(str(local_path), S3_BUCKET, key)
    return f"s3://{S3_BUCKET}/{key}"

@task
def predict_from_s3(model_uri: str, batch: List[float]) -> float:
    """
    Download the model on another worker/machine and run inference.
    """
    import boto3
    import torch
    import numpy as np
    from urllib.parse import urlparse
    u = urlparse(model_uri)
    bucket = u.netloc
    key = u.path.lstrip("/")
    tmp_dir = Path("/tmp/prefect-model")
    tmp_dir.mkdir(parents=True, exist_ok=True)
    local_path = tmp_dir / Path(key).name
    s3 = boto3.client("s3")
    s3.download_file(bucket, key, str(local_path))
    # If saved with torch.save(model, ...):
    # (newer PyTorch releases may require weights_only=False to unpickle a full model)
    model = torch.load(local_path, map_location="cpu")
    # If you saved TorchScript instead:
    # model = torch.jit.load(local_path, map_location="cpu")
    x = torch.tensor(np.array(batch, dtype=np.float32)).view(1, -1)
    with torch.no_grad():
        y = model(x).item()
    return y

# Define the distributed flow with DaskTaskRunner
# IMPORTANT: requires prefect[dask] installed and a reachable Dask scheduler
from prefect_dask.task_runners import DaskTaskRunner

@flow(
    name="distributed_cross_machine_flow",
oleander-whirlwind (10/06/2025, 11:10 AM)
Marvin (10/06/2025, 11:11 AM)
from prefect import flow, task
from prefect.task_runners import ThreadPoolTaskRunner

@task
def run_container(image: str, command: list[str]) -> dict:
    import docker
    client = docker.from_env()  # per-thread client
    container = client.containers.run(image, command, detach=True)
    try:
        result = container.wait()  # {'StatusCode': int}
        logs = container.logs().decode()
        return {"id": container.id, "status": result["StatusCode"], "logs": logs}
    finally:
        # cleanup
        try:
            container.remove(force=True)
        except Exception:
            pass

@task
def aggregate(results: list[dict]) -> dict:
    # All results are just Python objects; shared in the same process
    return {
        "count": len(results),
        "exit_codes": [r["status"] for r in results],
        "first_log": results[0]["logs"][:200] if results else "",
    }

@flow(task_runner=ThreadPoolTaskRunner(max_workers=4))
def docker_threaded_flow():
    futures = [
        run_container.submit("alpine:3.19", ["sh", "-c", "echo hi && sleep 1 && echo bye"]),
        run_container.submit("alpine:3.19", ["sh", "-c", "echo task2 && sleep 2"]),
        run_container.submit("alpine:3.19", ["sh", "-c", "echo task3 && sleep 3"]),
        run_container.submit("alpine:3.19", ["sh", "-c", "echo task4 && sleep 1"]),
    ]
    results = [f.result() for f in futures]
    return aggregate(results)

if __name__ == "__main__":
    print(docker_threaded_flow())
Bottom line: ThreadPoolTaskRunner is compatible with launching many containers in parallel, and results can be shared in-memory across tasks in the process. Just ensure Docker access, create per-thread clients, and handle cleanup. If you need multi-machine distribution of those container-launch tasks, switch to DaskTaskRunner or RayTaskRunner and use shared storage or cluster-native object stores/actors for coordination.
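For the multi-machine case, a rough sketch of the same pattern on Dask, assuming the prefect-dask collection, a scheduler at tcp://<scheduler-host>:8786 (placeholder), a Docker daemon on every worker, and the run_container/aggregate tasks defined above:
from prefect import flow
from prefect_dask.task_runners import DaskTaskRunner

@flow(task_runner=DaskTaskRunner(address="tcp://<scheduler-host>:8786"))
def docker_distributed_flow():
    # Each submit may land on a different Dask worker, so every worker machine
    # needs Docker installed and reachable via docker.from_env().
    futures = [
        run_container.submit("alpine:3.19", ["sh", "-c", f"echo task{i} && sleep 1"])
        for i in range(4)
    ]
    # Only small result dicts travel back to the flow process for aggregation.
    return aggregate([f.result() for f in futures])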