# prefect-community
Chris O'Brien
Is it expected behaviour that a merge of two skipped branches still executes? How would I have the merge respect the two upstream skips, or is this not possible?
```python
from prefect import task, Flow
from prefect.tasks.control_flow.conditional import merge, switch

@task
def three_outcomes():
    return "dead_branch"

@task
def fail_branch():
    print("i fail")

@task
def pass_branch():
    print("i pass")

@task
def dead_branch():
    print("im dead")

@task
def do_final_thing():
    print("final")

with Flow("example") as flow:
    # switch skips every branch whose key doesn't match the value returned by
    # three_outcomes; here that value is "dead_branch", so pass_branch and
    # fail_branch are both skipped.
    switch(three_outcomes, dict(dead_branch=dead_branch, pass_branch=pass_branch, fail_branch=fail_branch))
    # merge receives two skipped inputs, yet it (and so do_final_thing) still runs.
    do_final_thing.set_upstream(merge(pass_branch, fail_branch))

flow_state = flow.run()
flow.visualize(flow_state=flow_state)  # needs the optional graphviz dependency
```
Nicholas
Hi @Chris O'Brien! The way the merge task currently works is that it hardcodes `skip_on_upstream_skip=False`, since it's intended to receive skipped tasks. However, you bring up a great point: if ALL its ancestors are skipped, it doesn't make sense for it to still run. I've opened an issue for that (https://github.com/PrefectHQ/prefect/issues/1768), but in the meantime you can handle this by implementing a custom trigger for your merge task:
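```python
from prefect.engine import signals

def merge_trigger(upstream_states):
    # Skip the merge when every upstream branch was skipped.
    if all(state.is_skipped() for state in upstream_states):
        raise signals.SKIP("All upstreams skipped")
    # Otherwise behave like all_successful. A Skipped state counts as
    # successful, so this only fires on a genuine upstream failure.
    elif not all(state.is_successful() for state in upstream_states):
        raise signals.TRIGGERFAIL(
            'Trigger was "all_successful" but some of the upstream tasks failed.'
        )
    return True
```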
which would cause the merge task to also be skipped as intended (or failed if one of the ancestors wasn't successful). We have a great section on triggers in the docs (https://docs.prefect.io/core/concepts/tasks.html#triggers) but if you have any problems, feel free to let us know! 😄
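The trigger logic can be exercised without running a flow. The sketch below is a variant of `merge_trigger` that also accepts a mapping: the trigger above iterates `upstream_states` directly, which suits the Prefect version used in this thread, but later Prefect 0.x releases pass triggers a dict of upstream edge to state (an assumption worth checking against `prefect.triggers` in your installed version), and iterating that dict would yield edges rather than states. To keep it runnable without Prefect, `FakeState`, `SKIP` and `TRIGGERFAIL` are hypothetical stand-ins; in real code you would use Prefect's state objects and `prefect.engine.signals`.

```python
import collections.abc
from typing import Iterable, Mapping, Union


class SKIP(Exception):
    """Hypothetical stand-in for prefect.engine.signals.SKIP."""


class TRIGGERFAIL(Exception):
    """Hypothetical stand-in for prefect.engine.signals.TRIGGERFAIL."""


class FakeState:
    """Hypothetical stand-in for a Prefect State, with only the methods the trigger calls."""

    def __init__(self, kind: str) -> None:
        self.kind = kind  # "success", "skipped" or "failed"

    def is_skipped(self) -> bool:
        return self.kind == "skipped"

    def is_successful(self) -> bool:
        # In Prefect a Skipped state is also a Success, so it counts here too.
        return self.kind in ("success", "skipped")


def merge_trigger(
    upstream_states: Union[Iterable[FakeState], Mapping[object, FakeState]]
) -> bool:
    # Normalise: a {edge: state} mapping contributes its values,
    # any other collection is treated as the states themselves.
    if isinstance(upstream_states, collections.abc.Mapping):
        states = list(upstream_states.values())
    else:
        states = list(upstream_states)
    if all(state.is_skipped() for state in states):
        raise SKIP("All upstreams skipped")
    if not all(state.is_successful() for state in states):
        raise TRIGGERFAIL(
            'Trigger was "all_successful" but some of the upstream tasks failed.'
        )
    return True


# One branch ran, the other was skipped: the merge should run.
print(merge_trigger([FakeState("skipped"), FakeState("success")]))  # True

# Both branches skipped (the situation in the question), passed as a mapping.
try:
    merge_trigger({"edge_a": FakeState("skipped"), "edge_b": FakeState("skipped")})
except SKIP as exc:
    print("skipped:", exc)  # skipped: All upstreams skipped
```

Normalising to a list up front also means the generator expressions can safely iterate the states twice.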
Chris O'Brien
Thanks Nicholas, works a charm! Had to create a `custom_merge` to do it, as I couldn’t see a way of getting the existing merge to take a trigger?
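```python
from typing import Any

from prefect import Task
from prefect.engine import signals
from prefect.engine.result import NoResultType


def merge_trigger(upstream_states):
    # Skip the merge when every upstream branch was skipped.
    if all(state.is_skipped() for state in upstream_states):
        raise signals.SKIP("All upstreams skipped")
    # Otherwise fail if any upstream failed (Skipped counts as successful).
    elif not all(state.is_successful() for state in upstream_states):
        raise signals.TRIGGERFAIL(
            'Trigger was "all_successful" but some of the upstream tasks failed.'
        )
    return True


class Merge(Task):
    """Local Merge task class, so it can be constructed with a custom trigger."""

    def __init__(self, **kwargs) -> None:
        # A merge must still run when some of its inputs were skipped.
        if kwargs.setdefault("skip_on_upstream_skip", False):
            raise ValueError("Merge tasks must have `skip_on_upstream_skip=False`.")
        super().__init__(**kwargs)

    def run(self, **task_results: Any) -> Any:
        # Return the first upstream value (in sorted key order) that produced a result.
        return next(
            (
                v
                for k, v in sorted(task_results.items())
                if not isinstance(v, NoResultType)
            ),
            None,
        )


def custom_merge(*tasks: Task) -> Task:
    # Like merge(), but the Merge task is built with merge_trigger.
    return Merge(trigger=merge_trigger).bind(
        **{"task_{}".format(i + 1): t for i, t in enumerate(tasks)}
    )
```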
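One detail worth knowing if you reuse this `Merge.run`: `custom_merge` names its inputs `task_1`, `task_2` and so on, and `run` walks them in `sorted()` order, which is lexicographic, so with ten or more inputs `task_10` is checked before `task_2`. That only matters when more than one input actually produced a result (in a `switch`, normally only one branch runs). The pure-Python sketch below reproduces that selection logic and shows a numeric sort that follows argument order; `NO_RESULT` is a hypothetical sentinel standing in for Prefect's `NoResult`, and `first_result` is not a Prefect function.

```python
from typing import Any, Dict

NO_RESULT = object()  # hypothetical stand-in for Prefect's NoResult


def first_result(task_results: Dict[str, Any], numeric: bool = False) -> Any:
    """Return the first value that is not NO_RESULT, scanning keys in sorted order.

    numeric=False mirrors Merge.run (plain string sort); numeric=True sorts
    "task_<n>" keys by n, i.e. by the order the tasks were passed in.
    """
    if numeric:
        keys = sorted(task_results, key=lambda k: int(k.rsplit("_", 1)[1]))
    else:
        keys = sorted(task_results)
    return next(
        (task_results[k] for k in keys if task_results[k] is not NO_RESULT), None
    )


# Eleven inputs, named the same way custom_merge names them; only task_2
# and task_10 produced results.
results = {"task_{}".format(i + 1): NO_RESULT for i in range(11)}
results["task_2"] = "second input"
results["task_10"] = "tenth input"

print(first_result(results))                # tenth input
print(first_result(results, numeric=True))  # second input
```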