Hello Prefect community, I'm using the `transacti...
# ask-community
r
Hello Prefect community, I'm using the `transaction` context manager with an `on_rollback` hook. I also have a downstream task with retries. What happens now is that the `on_rollback` hook is invoked after the first attempt of the downstream task fails. What I want is for the `on_rollback` hook to be invoked only if the last retry of the downstream task fails. How can I do that?
This is my code:
Copy code
from time import sleep

from prefect import task, flow, get_run_logger
from prefect.transactions import transaction


@task
def task_1():
    """Log a start message, simulate one second of work, then log completion.

    Its compensating action is the ``task_1_rb`` hook registered with
    ``@task_1.on_rollback``.
    """
    logger = get_run_logger()
    logger.info("Executing task 1")
    sleep(1)
    logger.info("Task 1 completed")


@task_1.on_rollback
def task_1_rb(transaction):
    """Rollback hook for ``task_1``: log the rollback and simulate undo work.

    Args:
        transaction: presumably the Prefect transaction being rolled back
            (confirm against the Prefect docs). It is unused here, and the
            name shadows the module-level ``transaction`` import inside
            this function.
    """
    logger = get_run_logger()
    logger.warning("Rolling back task 1")
    sleep(1)
    logger.info("Task 1 rollback completed")


@task(retries=1)
def task_fail():
    """Always fail, to exercise task retries together with rollback hooks.

    With ``retries=1`` the task gets one retry after its initial attempt.

    Raises:
        RuntimeError: on every attempt.
    """
    logger = get_run_logger()
    logger.info("Executing task that will fail")
    sleep(1)
    logger.error("Task failed")
    raise RuntimeError("This task is expected to fail.")


@flow
def pipeline():
    """Run ``task_1`` and ``task_fail`` inside a single transaction.

    ``task_fail`` always raises, so the transaction fails and ``task_1``'s
    ``on_rollback`` hook runs. The reported problem is that the hook fires
    after ``task_fail``'s first failed attempt instead of after its last retry.
    """
    with transaction():
        task_1()
        task_fail()


if __name__ == "__main__":
    # Run the demo flow directly; task_fail always raises, so this run is
    # expected to end in an exception.
    pipeline()
n
Hi @Roy Wolters — I think you found a rough edge.
Copy code
from time import sleep

from prefect import flow, get_run_logger, task
from prefect.transactions import transaction


@task
def task_1():
    """Log a start message, simulate one second of work, then log completion.

    Its compensating action is the ``task_1_rb`` hook registered with
    ``@task_1.on_rollback``.
    """
    logger = get_run_logger()
    logger.info("Executing task 1")
    sleep(1)
    logger.info("Task 1 completed")


@task_1.on_rollback
def task_1_rb(transaction):
    """Rollback hook for ``task_1``: log the rollback and simulate undo work.

    Args:
        transaction: presumably the Prefect transaction being rolled back
            (confirm against the Prefect docs). It is unused here, and the
            name shadows the module-level ``transaction`` import inside
            this function.
    """
    logger = get_run_logger()
    logger.warning("Rolling back task 1")
    sleep(1)
    logger.info("Task 1 rollback completed")


@task(retries=1)
def task_fail():
    """Always fail, to exercise task retries together with rollback hooks.

    With ``retries=1`` the task gets one retry after its initial attempt.

    Raises:
        RuntimeError: on every attempt.
    """
    logger = get_run_logger()
    logger.info("Executing task that will fail")
    sleep(1)
    logger.error("Task failed")
    raise RuntimeError("This task is expected to fail.")


# Original problem
@flow
def pipeline_problem():
    """Reproduce the issue: rollback fires before ``task_fail``'s retry.

    ``task_1`` and ``task_fail`` share one transaction. ``task_fail`` always
    raises, and ``task_1``'s rollback hook is reported to run after the first
    failed attempt rather than after the final retry.
    """
    with transaction():
        task_1()
        task_fail()  # Rollback fires on first failure


# Workaround until fix is merged
@flow
def pipeline_workaround():
    """
    Workaround: keep retrying tasks outside the transaction scope.

    ``task_1`` runs in its own transaction, which finishes as soon as the
    ``with`` block exits. ``task_fail`` then runs outside any transaction, so
    its retries cannot trigger a premature rollback.

    Caveat: because ``task_1``'s transaction has already completed, a final
    failure of ``task_fail`` does NOT roll back ``task_1`` either. Any
    compensating action has to be triggered manually in the ``except`` branch.

    Raises:
        Exception: whatever ``task_fail`` raises once its retries are exhausted.
    """
    with transaction():
        task_1()

    # Run task_fail outside the transaction so its retries work properly.
    try:
        task_fail()
    except Exception as e:
        # Only reached after every retry has failed; this is where you would
        # trigger compensating actions (e.g. undoing task_1) if needed.
        get_run_logger().error("Task failed after all retries: %s", e)
        raise


if __name__ == "__main__":
    # Both demo flows are expected to fail, because task_fail always raises.
    # Report each failure and continue so that both demos run. Catch Exception
    # rather than using a bare except, so Ctrl-C / SystemExit still abort.
    print("\n=== PROBLEM: Rollback fires before retry ===")
    try:
        pipeline_problem()
    except Exception as exc:
        print(f"pipeline_problem failed as expected: {exc!r}")

    print("\n=== WORKAROUND: Keep retrying tasks outside transaction ===")
    try:
        pipeline_workaround()
    except Exception as exc:
        print(f"pipeline_workaround failed as expected: {exc!r}")
I think the rollback hook shouldn't run until all the retries are exhausted; we'll get a PR up for this.
Thanks for the message!
r
Hi Nate, thanks for your response. I'm looking forward to the fix!
It seems that your suggestion doesn't give the behavior I'm looking for: now `task_1` is not rolled back at all. I found that the code below gives the behavior I want:
Copy code
from time import sleep

from prefect import flow, get_run_logger, task
from prefect.transactions import transaction


@task
def task_1():
    """Log a start message, simulate one second of work, then log completion.

    Its compensating action is the ``task_1_rb`` hook registered with
    ``@task_1.on_rollback``.
    """
    logger = get_run_logger()
    logger.info("Executing task 1")
    sleep(1)
    logger.info("Task 1 completed")


@task_1.on_rollback
def task_1_rb(transaction):
    """Rollback hook for ``task_1``: log the rollback and simulate undo work.

    Args:
        transaction: presumably the Prefect transaction being rolled back
            (confirm against the Prefect docs). It is unused here, and the
            name shadows the module-level ``transaction`` import inside
            this function.
    """
    logger = get_run_logger()
    logger.warning("Rolling back task 1")
    sleep(1)
    logger.info("Task 1 rollback completed")


@task(retries=1)
def task_fail():
    """Always fail, to exercise task retries together with rollback hooks.

    With ``retries=1`` the task gets one retry after its initial attempt.

    Raises:
        RuntimeError: on every attempt.
    """
    logger = get_run_logger()
    logger.info("Executing task that will fail")
    sleep(1)
    logger.error("Task failed")
    raise RuntimeError("This task is expected to fail.")


# Original problem
@flow
def pipeline_problem():
    """Reproduce the issue: rollback fires before ``task_fail``'s retry.

    ``task_1`` and ``task_fail`` share one transaction. ``task_fail`` always
    raises, and ``task_1``'s rollback hook is reported to run after the first
    failed attempt rather than after the final retry.
    """
    with transaction():
        task_1()
        task_fail()  # Rollback fires on first failure


# Workaround until fix is merged
@flow
def pipeline_workaround():
    """
    Workaround: keep retrying tasks in a nested transaction scope.

    ``task_fail`` runs in a transaction nested inside the one that wraps
    ``task_1``. As reported in this thread, the rollback hook of ``task_1``
    then fires only after ``task_fail`` has exhausted its retries, instead of
    after its first failed attempt.

    Raises:
        Exception: whatever ``task_fail`` raises after its final retry.
    """
    with transaction():
        task_1()
        with transaction():
            # Run task_fail in a nested transaction so retries work properly.
            task_fail()


if __name__ == "__main__":
    # Both demo flows are expected to fail, because task_fail always raises.
    # Report each failure instead of silently discarding it, then continue so
    # that both demos run.
    print("\n=== PROBLEM: Rollback fires before retry ===")
    try:
        pipeline_problem()
    except Exception as exc:
        print(f"pipeline_problem failed as expected: {exc!r}")

    print("\n=== WORKAROUND: Keep retrying tasks in nested transaction ===")
    try:
        pipeline_workaround()
    except Exception as exc:
        print(f"pipeline_workaround failed as expected: {exc!r}")