# ask-community
张强
Is there a base class like BaseFlow and BaseTask for setting default parameters, such as `max_retries`, or default `state_handlers`?
Kevin Kho
Hey @张强, subclassing these is the right way to do it. For `@task`, you can write your own decorator that sets the defaults. Some of these, like `max_retries`, can be set in `config.toml`, but `skip_on_upstream_skip` and `state_handlers` can't.
张强
Okay, I got it.
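Demo:
```python
import datetime

from prefect import Flow, Task

# flow_state_handler, flow_terminal_state_handler and my_state_handler are
# your own handler functions, defined elsewhere.

class BaseFlow(Flow):
    def __init__(self, *args, **kwargs):
        # Copy the caller's list (if any) so appending doesn't mutate it,
        # and so an explicit state_handlers=None doesn't break .append().
        kwargs["state_handlers"] = list(kwargs.get("state_handlers") or [])
        kwargs["state_handlers"].append(flow_state_handler)
        kwargs.setdefault("terminal_state_handler", flow_terminal_state_handler)
        super().__init__(*args, **kwargs)

class BaseTask(Task):
    def __init__(self, **kwargs):
        # setdefault only fills in options the caller didn't pass explicitly.
        kwargs.setdefault("max_retries", 3)
        kwargs.setdefault("retry_delay", datetime.timedelta(seconds=1))
        kwargs.setdefault("skip_on_upstream_skip", False)
        kwargs.setdefault("state_handlers", [my_state_handler])
        super().__init__(**kwargs)
```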
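Because of the `setdefault` calls, each default applies only when the caller leaves that option out, so an individual task can still override it. To see this without a Prefect install, the sketch below runs the same `BaseTask` against a hypothetical stand-in `Task` class; `my_state_handler` and `AddOne` are hypothetical names too. With Prefect you would subclass the real `prefect.Task` instead.

```python
import datetime

# Hypothetical stand-in for prefect.Task so this sketch runs without Prefect.
# It accepts a few of the same keyword options and simply stores them.
class Task:
    def __init__(self, max_retries=None, retry_delay=None,
                 skip_on_upstream_skip=True, state_handlers=None):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.skip_on_upstream_skip = skip_on_upstream_skip
        self.state_handlers = state_handlers or []


def my_state_handler(obj, old_state, new_state):
    # Hypothetical handler: passes the new state through unchanged.
    return new_state


class BaseTask(Task):
    def __init__(self, **kwargs):
        # setdefault fills an option only when the caller didn't pass it.
        kwargs.setdefault("max_retries", 3)
        kwargs.setdefault("retry_delay", datetime.timedelta(seconds=1))
        kwargs.setdefault("skip_on_upstream_skip", False)
        kwargs.setdefault("state_handlers", [my_state_handler])
        super().__init__(**kwargs)


class AddOne(BaseTask):
    # Concrete tasks subclass BaseTask and define run() as usual.
    def run(self, x):
        return x + 1


default_task = AddOne()                # picks up every default
no_retry_task = AddOne(max_retries=0)  # an explicit argument wins
```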
Kevin Kho
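Here is an example of the decorator (someone from the community showed me this earlier):
```python
import traceback
from functools import partial, wraps

from prefect import Flow, task


def custom_task(func=None, **task_init_kwargs):
    # Used as @custom_task(...) with options: return a decorator that
    # remembers the options and waits for the function.
    if func is None:
        return partial(custom_task, **task_init_kwargs)

    @wraps(func)  # copies func's __name__, so the task keeps its name
    def safe_func(**kwargs):
        try:
            return func(**kwargs)
        except Exception as e:
            print(f"Full Traceback: {traceback.format_exc()}")
            # "from None" keeps the original stack trace out of the logged error
            raise RuntimeError(type(e)) from None

    # task_init_kwargs (max_retries, state_handlers, ...) configure the task itself.
    return task(safe_func, **task_init_kwargs)


@custom_task
def abc(x):
    return x


with Flow("custom-decorator-test") as flow:
    abc.map([1, 2, 3, 4, 5])
```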
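The `raise ... from None` line is what keeps the original traceback out of the re-raised error. In plain Python terms (no Prefect involved), `from None` sets `__suppress_context__`, so when the `RuntimeError` is formatted, the original exception and its traceback are not printed after it. A minimal sketch; `risky` and `safe_call` are hypothetical names:

```python
import traceback


def risky(x):
    return 1 / x  # raises ZeroDivisionError for x == 0


def safe_call(func, **kwargs):
    try:
        return func(**kwargs)
    except Exception as e:
        # Print the full traceback once, ourselves...
        print(f"Full Traceback: {traceback.format_exc()}")
        # ...then re-raise a short error with the chained context suppressed.
        raise RuntimeError(type(e)) from None


try:
    safe_call(risky, x=0)
except RuntimeError as err:
    caught = err

formatted = "".join(
    traceback.format_exception(type(caught), caught, caught.__traceback__)
)
```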
j
@Kevin Kho this might be due to my lack of Python skill, but if I wanted to add
```python
kwargs.setdefault("max_retries", 3)
kwargs.setdefault("retry_delay", datetime.timedelta(seconds=1))
kwargs.setdefault("skip_on_upstream_skip", False)
kwargs.setdefault("state_handlers", [my_state_handler])
```
to your example, would it go under `safe_func`, like this?
```python
def safe_func(**kwargs):
    try:
        kwargs.setdefault("max_retries", 3)
        kwargs.setdefault("retry_delay", datetime.timedelta(seconds=1))
        kwargs.setdefault("skip_on_upstream_skip", False)
        kwargs.setdefault("state_handlers", [my_state_handler])
        return func(**kwargs)
    # ... except block unchanged
```
Kevin Kho
I think that might work.
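One caveat on that placement: inside `safe_func`, `kwargs` are the task's runtime inputs (here `x`), so anything added there is passed straight to `func`, and `abc(x)` would fail at run time with an unexpected keyword argument `max_retries`. The task options are the `task_init_kwargs` that `custom_task` hands to `task(...)`, so the defaults belong there instead. A sketch of that placement follows; to keep it runnable without Prefect, `task` is a hypothetical stand-in that only records the options it receives (use `from prefect import task` in a real flow), and `my_state_handler` and `double` are hypothetical too.

```python
import datetime
import traceback
from functools import partial, wraps


# Hypothetical stand-in for prefect.task: records the task options and
# returns the wrapped function unchanged.
def task(fn, **task_init_kwargs):
    fn.task_init_kwargs = task_init_kwargs
    return fn


def my_state_handler(obj, old_state, new_state):  # hypothetical handler
    return new_state


def custom_task(func=None, **task_init_kwargs):
    if func is None:
        return partial(custom_task, **task_init_kwargs)

    # Task-level defaults go on task_init_kwargs; options passed explicitly
    # to @custom_task(...) still take precedence.
    task_init_kwargs.setdefault("max_retries", 3)
    task_init_kwargs.setdefault("retry_delay", datetime.timedelta(seconds=1))
    task_init_kwargs.setdefault("skip_on_upstream_skip", False)
    task_init_kwargs.setdefault("state_handlers", [my_state_handler])

    @wraps(func)
    def safe_func(**kwargs):
        # kwargs are the runtime inputs only (e.g. x), never task options.
        try:
            return func(**kwargs)
        except Exception as e:
            print(f"Full Traceback: {traceback.format_exc()}")
            raise RuntimeError(type(e)) from None

    return task(safe_func, **task_init_kwargs)


@custom_task
def abc(x):
    return x


@custom_task(max_retries=1)
def double(x):
    return x * 2
```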