# ask-community
j
Hi! I'm new to Prefect. I was looking for a flow solution coming from Ray and found this just today. I'm trying to do something similar to Ray's Dataset processing, and I have a few questions:
1. Is there a better way to handle batching and unbatching in Prefect?
2. I've noticed that depending on whether I use `map` or `submit`, the flow graph in the run view may not show as complete (see screenshot) even though the run itself is complete. Is this an issue with how I am implementing the flow? (See the commented section in `process_images`.)
```python
from prefect import task, flow, unmapped
from prefect.futures import wait
from prefect.tasks import task_input_hash

from datetime import timedelta
import os
import numpy as np
import torch
from pathlib import Path
from typing import List
from PIL import Image
import time

def _cleanup(temp_dir):
    """Recursively delete a directory and everything inside it."""
    for f in temp_dir.iterdir():
        if f.is_file():
            f.unlink()
        elif f.is_dir():
            _cleanup(f)
    temp_dir.rmdir()


# Generate dummy RGB images on disk so the flow has something to process
dummy_data = np.random.rand(20, 3, 1024, 1024)

temp_dir = Path("temp")

if temp_dir.exists():
    _cleanup(temp_dir)

input_dir = temp_dir / "input"
input_dir.mkdir(parents=True)

output_dir = temp_dir / "output"
output_dir.mkdir(parents=True)

for i, x in enumerate(dummy_data):
    im = Image.fromarray((x * 255).astype(np.uint8).transpose(1, 2, 0))
    im.save(input_dir / f"{i}.jpg")


# Define constants
BATCH_SIZE = 8
INPUT_DIR = input_dir
OUTPUT_DIR = output_dir


@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(minutes=5), viz_return_value=INPUT_DIR.glob("*.jpg"))
def read_image_list(input_dir: Path):
    return list(input_dir.glob("*.jpg"))


@task
def preprocess_image(image_path):
    image = Image.open(image_path)
    image = image.resize((224, 224))  # Example resize for a model
    image_array = np.array(image) / 255.0  # Normalize to 0-1
    return image_array.transpose(2, 0, 1)  # HWC to CHW for PyTorch


@task
def load_model():
    model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=True)
    model.eval()
    return model


@task
def run_inference(batch, model):
    batch_array = np.stack(batch)
    inputs = torch.tensor(batch_array, dtype=torch.float32)
    # with torch.no_grad():
    #     outputs = model(inputs)
    # return outputs.argmax(dim=1).tolist()
    time.sleep(0.1)
    return inputs


@task
def save_results(results: torch.Tensor, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    result_paths = []
    for i, result in enumerate(results):
        # Index each image so results within a batch don't overwrite each other
        result_path = os.path.join(output_dir, f"{i}.png")
        img = result.numpy().transpose(1, 2, 0)  # CHW to HWC
        img = Image.fromarray((img * 255).astype(np.uint8))
        img.save(result_path)
        result_paths.append(result_path)

    return result_paths


# Flow definition
@flow(log_prints=True)
def process_images(input_dir=INPUT_DIR, output_dir=OUTPUT_DIR, batch_size=BATCH_SIZE):
    model = load_model()

    image_list = read_image_list.submit(input_dir).result()

    # Preprocess images in parallel
    preprocessed_images = preprocess_image.map(image_list)

    # Batch the images
    batches = [preprocessed_images[i : i + batch_size] for i in range(0, len(preprocessed_images), batch_size)]

    # BUG?: Comment this out to see the run view completion issue
    preds = [run_inference.submit(batch, model).result() for batch in batches]

    # BUG?: Uncomment this to see the run view completion issue
    # preds = run_inference.map(batches, model)

    # BUG?: Comment this out to see the run view completion issue
    results = [save_results.submit(pred, output_dir) for pred in preds]

    # BUG?: Uncomment this to see the run view completion issue
    # results = save_results.map(preds, output_dir)

    done, not_done = wait(results)

    print(f"Done: {len(done)}, Not Done: {len(not_done)}")


# Run the flow
if __name__ == "__main__":
    process_images()
```
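For reference, here is a minimal, self-contained sketch of the `.map()` + `unmapped()` batching pattern on its own. The task and flow names are illustrative, not taken from the script above, and it assumes a Prefect version where `prefect.futures.wait` is available (as imported above):

```python
# A minimal sketch of batch fan-out/fan-in with .map() and unmapped().
# `double_batch` and `batched_flow` are illustrative names only.
from prefect import flow, task, unmapped
from prefect.futures import wait


@task
def double_batch(batch: list[int], offset: int) -> list[int]:
    # Process one batch; `offset` is the same static value for every mapped call.
    return [x * 2 + offset for x in batch]


@flow
def batched_flow(items: list[int], batch_size: int = 4):
    # "Batching" is plain list slicing, just like in process_images above.
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

    # .map() submits one task run per batch and returns the futures, so the
    # dependency graph links each downstream task to its upstream batch.
    # unmapped() marks `offset` as a static argument that is not iterated over.
    futures = double_batch.map(batches, unmapped(10))

    # Wait for all mapped runs, then flatten ("unbatch") the per-batch results.
    wait(futures)
    return [x for fut in futures for x in fut.result()]


if __name__ == "__main__":
    print(batched_flow(list(range(10))))
```

As far as I know there is no built-in Dataset-style batching primitive in Prefect, so slicing into lists, mapping over the batches, and flattening the `.result()` values back out is the usual batch/unbatch approach.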
j
I think that image is a UI issue, not your code; I've seen that before in my own jobs. @Craig Harshbarger
j
ok, good to know, thank you. Another question: How can I run flows in parallel?
c
I believe this is an issue with how the client tracks dependencies, but I'm not 100% sure. Opening an issue would be a great place to start 👍
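On the earlier question about running flows in parallel: one common pattern is to make the subflows async and start them concurrently with `asyncio.gather`. A minimal sketch, assuming Prefect 2.x or later; the flow names are illustrative:

```python
# A minimal sketch of concurrent subflows using async flows and asyncio.gather.
# `child` and `parent` are illustrative names, not from the script above.
import asyncio

from prefect import flow


@flow
async def child(n: int) -> int:
    # Each call to an async subflow becomes its own subflow run.
    await asyncio.sleep(1)
    return n * n


@flow
async def parent(values: list[int]) -> list[int]:
    # gather() starts all subflow runs at once instead of awaiting them one by one.
    return await asyncio.gather(*(child(v) for v in values))


if __name__ == "__main__":
    print(asyncio.run(parent([1, 2, 3])))
```

For flow runs that should be fully independent (their own infrastructure, retries, and so on), triggering deployments, for example with `prefect.deployments.run_deployment`, is another option.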