# prefect-community
Is the expected behaviour on a merge of two skipped branches that it still executes? How would I have the Merge respect the two upstream skips? Or is this not possible?
from prefect import task, Flow
from prefect.triggers import any_failed, some_failed
from prefect.tasks.control_flow.conditional import ifelse, merge, switch

@task
def three_outcomes():
    # Selects which switch case runs; the other two branches are skipped.
    return "dead_branch"

@task
def fail_branch():
    print("i fail")

@task
def pass_branch():
    print("i pass")

@task
def dead_branch():
    print("im dead")

@task
def do_final_thing():
    pass  # body was not included in the original message

with Flow("example") as flow:
    switch(three_outcomes, dict(dead_branch=dead_branch, pass_branch=pass_branch, fail_branch=fail_branch))
    do_final_thing.set_upstream(merge(pass_branch, fail_branch))

flow_state = flow.run()  # assumed completion; the original line was cut off
Hi @Chris O'Brien! The way the merge task currently operates is that it has a hardcoded value of skip_on_upstream_skip=False, since it's intended to receive skipped tasks. However, you bring up a great point: if ALL its ancestors are skipped, it doesn't make sense that it should still run. I've opened an issue for that, but in the meantime you can handle this by implementing a custom trigger for your merge task:
from prefect.engine import signals

def merge_trigger(upstream_states):
    # Skip the merge when every upstream branch was skipped.
    if all(state.is_skipped() for state in upstream_states):
        raise signals.SKIP("All upstreams skipped")
    # Otherwise behave like "all_successful" (skipped states count as successful).
    elif not all(state.is_successful() for state in upstream_states):
        raise signals.TRIGGERFAIL(
            'Trigger was "all_successful" but some of the upstream tasks failed.'
        )
    return True
which would cause the merge task to also be skipped as intended (or failed if one of the ancestors wasn't successful). We have a great section on triggers in the docs, but if you have any problems, feel free to let us know! 😄
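To make the three outcomes of that trigger concrete, here is a minimal sketch that runs the same decision logic without Prefect installed. `StubState`, `Skip`, `TriggerFail`, and `outcome` are hypothetical stand-ins invented for this illustration, not Prefect classes; the one behaviour mirrored from Prefect 0.x is that a skipped state also counts as successful.

```python
from dataclasses import dataclass


class Skip(Exception):
    """Hypothetical stand-in for the SKIP signal."""


class TriggerFail(Exception):
    """Hypothetical stand-in for the TRIGGERFAIL signal."""


@dataclass
class StubState:
    """Hypothetical stand-in for a task state: 'success', 'skipped' or 'failed'."""
    kind: str

    def is_skipped(self) -> bool:
        return self.kind == "skipped"

    def is_successful(self) -> bool:
        # Assumption mirrored from Prefect 0.x: Skipped is a kind of Success.
        return self.kind in ("success", "skipped")


def merge_trigger(upstream_states):
    # Same branching as the trigger in the thread, using the stub signals.
    if all(s.is_skipped() for s in upstream_states):
        raise Skip("All upstreams skipped")
    elif not all(s.is_successful() for s in upstream_states):
        raise TriggerFail("some of the upstream tasks failed")
    return True


def outcome(kinds):
    """Report what the merge task would do for the given upstream state kinds."""
    try:
        merge_trigger([StubState(k) for k in kinds])
        return "run"
    except Skip:
        return "skip"
    except TriggerFail:
        return "fail"


print(outcome(["skipped", "skipped"]))  # both branches dead
print(outcome(["success", "skipped"]))  # one branch ran
print(outcome(["failed", "skipped"]))   # the live branch failed
```

The middle case is the reason the first check has to come before the second: a mix of one successful and one skipped branch passes the "all successful" test, so only the "all skipped" case needs special handling.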
Thanks Nicholas, works a charm! Had to create a custom merge task to do it, as I couldn't see a way of getting the existing merge to take a trigger?
from typing import Any

from prefect import Task
from prefect.engine import signals
from prefect.engine.result import NoResultType  # import path as in Prefect 0.x

def merge_trigger(upstream_states):
    if all(state.is_skipped() for state in upstream_states):
        raise signals.SKIP("All upstreams skipped")
    elif not all(state.is_successful() for state in upstream_states):
        raise signals.TRIGGERFAIL(
            'Trigger was "all_successful" but some of the upstream tasks failed.'
        )
    return True

class Merge(Task):
    def __init__(self, **kwargs) -> None:
        if kwargs.setdefault("skip_on_upstream_skip", False):
            raise ValueError("Merge tasks must have `skip_on_upstream_skip=False`.")
        super().__init__(**kwargs)

    def run(self, **task_results: Any) -> Any:
        # Return the first upstream value that is a real result, or None.
        return next(
            (
                v
                for k, v in sorted(task_results.items())
                if not isinstance(v, NoResultType)
            ),
            None,
        )

def custom_merge(*tasks: Task) -> Task:
    return Merge(trigger=merge_trigger).bind(**{"task_{}".format(i + 1): t for i, t in enumerate(tasks)})
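The selection rule inside the `run` method above can be tried in isolation. This is a rough sketch under stated assumptions: `NO_RESULT` is a hypothetical sentinel standing in for Prefect's "no result" value, and `first_real_result` is an illustrative name, not part of Prefect.

```python
# Hypothetical sentinel standing in for the value a skipped branch passes along.
NO_RESULT = object()


def first_real_result(**task_results):
    """Return the first non-empty value, scanning keyword names in sorted order."""
    return next(
        (v for _, v in sorted(task_results.items()) if v is not NO_RESULT),
        None,  # fallback when every branch produced no result
    )


# Only the second branch produced a value, so that value is chosen.
print(first_real_result(task_1=NO_RESULT, task_2="from pass_branch"))

# No branch produced a value, so the fallback None is returned.
print(first_real_result(task_1=NO_RESULT, task_2=NO_RESULT))
```

Because the keys are sorted as strings, the `task_1`, `task_2`, ... names produced by `custom_merge` decide which branch wins when more than one has a value.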