Hi! Each mapped task is getting executed twice in ...
# ask-community
s
Hi! Each mapped task is getting executed twice in my flow, What might be the reason?
Copy code
INFO
CloudTaskRunner
Task 'collate-quarterly-crdlotcookie[5]': Starting task run...
INFO
CloudTaskRunner
Task 'collate-quarterly-crdlotcookie[5]': Starting task run...
I restarted the flow in the middle of the run, and it created another CloudTaskRunner.
a
Are you on Prefect Cloud? If so, there is a feature called Version Locking that enforces that your work runs once and only once. You can enable it using GraphQL or directly from the UI.
Copy code
mutation {
  enable_flow_version_lock(input: { flow_id: "your-flow-id-here" }) {
    success
  }
}
Additionally, if you want to find out the root cause of this, can you share your flow code? I saw a similar issue once when the user was calling flow.run() within the script, which can be avoided by leveraging the CLI to run and register flows.
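For example, here's a minimal sketch (assuming Prefect 1.x and a hypothetical flow.py) of keeping flow.run() out of module scope so that importing or registering the file can't trigger an extra local run:
Copy code
# sketch only: guard local test runs so they don't fire on import/registration
if __name__ == "__main__":
    flow.run()

# then register through the CLI instead of running the script, e.g.
#   prefect register --project my-project -p flow.py   # "my-project" is a placeholder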
s
Copy code
from unittest import skip
from prefect import Flow, Parameter, task, context
from prefect.storage import GitHub
from lakeview import LakeviewRunJobTask
from prefect.run_configs import KubernetesRun
from prefect.engine.results import S3Result
from datetime import datetime
from prefect.executors import LocalDaskExecutor
from dateutil.relativedelta import relativedelta


CID = "c022"
FLOW_NAME = "europe-monthly-lotame-datamart"
REGION = "eu-west-1"
CLUSTER = "aqfer-prod-eks-Ireland"
TASK_TAGS = ["c022-europe-monthly-lotame-datamart"]

LOCATION = "results/"+CID+"/"+FLOW_NAME+"/"+"{task_run_id}.prefect"
S3_RESULT = S3Result(bucket='com.aqfer.prod.prefect', location=LOCATION)


@task
def reduce_map():
    pass


@task(result=S3_RESULT)
def compute_input_args(event_month):
    format = '%Y%m'
    date_format = '%Y%m%d'
    if event_month is None:
        date = datetime.now()+relativedelta(months=-1)
    else:
        date = datetime.strptime(event_month, format)
    months = []
    cg_dates = []
    for x in range(5):
        m = date+relativedelta(months=-x)
        months += [m.strftime(format)]
    for y in range(1, 15):
        c = date.replace(day=1)+relativedelta(days=-y)
        cg_dates += [c.strftime(date_format)]
    cg_dates.append(months[0]+'*')
    for z in range(14):
        c = date.replace(day=1)+relativedelta(months=1)+relativedelta(days=z)
        cg_dates += [c.strftime(date_format)]
    return {"event_month": months[0], "prev_five_months": '{{{}}}'.format(",".join(months)), "cg_dates": '{{{}}}'.format(",".join(cg_dates)), "prev_three_months": '{{{}}}'.format(",".join(months[0:3]))}


@task(result=S3_RESULT)
def generate_arg_map(event_month, prev_three_months):
    m = []
    for c in ['AT', 'BE', 'CH', 'CZ', 'DE', 'DK', 'ES', 'FI', 'FR', 'GB', 'GR', 'HU', 'IE', 'IT', 'NL', 'NO', 'PL', 'PT', 'RO', 'SE', 'SK']:
        m += [{"event_month": event_month, "country": c,
               "prev_three_months": prev_three_months}]
    return m


collate_monthly_crossdevice = LakeviewRunJobTask(cid=CID, job="europe-collate-monthly-crossdevice",
                                                 cluster=CLUSTER, poll_interval=300, name="collate-monthly-crossdevice", timeout=10800, result=S3_RESULT, tags=TASK_TAGS, skip=True)
collate_quarterly_crossdevice = LakeviewRunJobTask(cid=CID, job="europe-collate-quarterly-crossdevice",
                                                   cluster=CLUSTER, poll_interval=300, name="collate-quarterly-crossdevice", timeout=10800, result=S3_RESULT, tags=TASK_TAGS, skip=True)

import_adform_lotame = LakeviewRunJobTask(cid=CID, job="europe-import-adform-lotame",
                                          cluster=CLUSTER, poll_interval=300, name="import-adform-lotame", timeout=10800, result=S3_RESULT, tags=TASK_TAGS, skip=True)
collate_monthly_lotame = LakeviewRunJobTask(cid=CID, job="europe-collate-monthly-lotame",
                                            cluster=CLUSTER, poll_interval=300, concurrent=True, name="collate-monthly-lotame", timeout=10800, result=S3_RESULT, task_run_name="collate-monthly-lotame-{args[country]}", tags=TASK_TAGS, skip=True)
collate_quarterly_lotame_cookie = LakeviewRunJobTask(cid=CID, job="europe-collate-quarterly-lotcookie",
                                                     cluster=CLUSTER, poll_interval=300, concurrent=True, name="collate-quarterly-lotame-cookie", timeout=10800, result=S3_RESULT, task_run_name="collate-quarterly-lotame-cookie-{args[country]}", tags=TASK_TAGS, skip=True)
collate_quarterly_lotame_mobile = LakeviewRunJobTask(cid=CID, job="europe-collate-quarterly-lotmobile",
                                                     cluster=CLUSTER, poll_interval=300, concurrent=True, name="collate-quarterly-lotame-mobile", timeout=10800, result=S3_RESULT, task_run_name="collate-quarterly-lotame-mobile-{args[country]}", tags=TASK_TAGS, skip=True)

collate_quarterly_crdlotcookie = LakeviewRunJobTask(cid=CID, job="europe-collate-quarterly-crdlotcookie",
                                                    cluster=CLUSTER, poll_interval=300, concurrent=True, name="collate-quarterly-crdlotcookie", timeout=18000, result=S3_RESULT, task_run_name="collate-quarterly-crdlotcookie-{args[country]}", tags=TASK_TAGS)
collate_quarterly_crdlotmobile = LakeviewRunJobTask(cid=CID, job="europe-collate-quarterly-crdlotmobile",
                                                    cluster=CLUSTER, poll_interval=300, concurrent=True, name="collate-quarterly-crdlotmobile", timeout=18000, result=S3_RESULT, task_run_name="collate-quarterly-crdlotmobile-{args[country]}", tags=TASK_TAGS)

with Flow(FLOW_NAME) as flow:
    event_month = Parameter("event_month", default=None)
    a = compute_input_args(event_month)
    t1 = collate_monthly_crossdevice(args=a)
    t2 = collate_quarterly_crossdevice(args=a, upstream_tasks=[t1])
    t3 = import_adform_lotame(args=a)
    arg_map = generate_arg_map(a["event_month"], a["prev_three_months"])
    t4 = collate_monthly_lotame.map(args=arg_map)
    r1 = reduce_map(upstream_tasks=[t4])
    t5 = collate_quarterly_lotame_cookie.map(args=arg_map)
    t5.set_upstream(r1)
    r2 = reduce_map(upstream_tasks=[t5])
    t6 = collate_quarterly_lotame_mobile.map(args=arg_map)
    t6.set_upstream(r2)
    r3 = reduce_map(upstream_tasks=[t6])
    t7 = collate_quarterly_crdlotcookie.map(args=arg_map)
    t8 = collate_quarterly_crdlotmobile.map(args=arg_map)
    t7.set_upstream(r3)
    t8.set_upstream(r3)

flow.storage = GitHub(
    repo="aqfer/product-deployments",
    path="datalake/cids/{}/flows/{}.py".format(CID, FLOW_NAME),
    access_token_secret="GITHUB_ACCESS_TOKEN"
)
flow.run_config = KubernetesRun(
    labels=[REGION],
)

flow.executor = LocalDaskExecutor()
a
I think what causes the issue is the "args" argument on your mapped tasks - it may cause an issue with the *args set on a Task. Can you define the argument explicitly on your LakeviewRunJobTask? Here is the gist of a flow I couldn’t get to visualize because of that argument - https://gist.github.com/2c6e1860555d6d8773f54926cd4a3637
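To illustrate what I mean by an explicit argument (just a rough sketch - job_args is a placeholder name, not something from your codebase):
Copy code
from prefect import Task

class MyJobTask(Task):
    # sketch: expose job parameters under an unambiguous keyword instead of "args",
    # so the mapped keyword can't be mixed up with the *args a task/class
    # conventionally accepts at initialization
    def run(self, job_args: dict = None):
        ...

# the mapped call in the flow would then read:
#   my_job_task.map(job_args=arg_map)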
s
Copy code
from prefect import Task
from prefect.utilities.tasks import defaults_from_attrs
from prefect.client import Secret
import requests
import time
import json
from datetime import datetime
from prefect import Client
from typing import Any, Dict, Optional
from prefect.engine import signals


class LakeviewRunJobTask(Task):
    AUTH_ENDPOINT = "auth.api.aqfer.net"
    LAKEVIEW_ENDPOINT = "lakeview.api.aqfer.net"
    ACCESS_TOKEN = "ACCESS_TOKEN"
    REFRESH_TOKEN = "REFRESH_TOKEN"

    def __init__(
        self,
        cid: str,
        job: str,
        args: Optional[Dict[Any, Any]] = None,
        cluster: Optional[str] = None,
        concurrent: bool = False,
        poll_interval: int = 300,
        skip: bool = False,
        **kwargs
    ):
        self.cid = cid
        self.job = job
        self.args = args
        self.cluster = cluster
        self.concurrent = concurrent
        self.poll_interval = poll_interval
        self.skip = skip
        super().__init__(**kwargs)

    def _read_access_token(self):
        try:
            access_token = Secret(self.ACCESS_TOKEN).get()
        except KeyError:
            self.logger.info(
                "Access token does not exist in the Secret store, creating one")
            self.access_token_expiry = 0
        else:
            self.access_token = access_token["access_token"]
            self.access_token_expiry = access_token["access_token_expiry"]
            self.request_headers = {
                "Authorization": "Bearer {}".format(self.access_token)
            }

    def __get_access_token_expiry(self):
        expiry = datetime.utcfromtimestamp(self.access_token_expiry).strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        return expiry

    def __write_access_token(self):
        variable = {
            "access_token": self.access_token,
            "access_token_expiry": self.access_token_expiry,
        }
        client = Client()
        client.set_secret(name=self.ACCESS_TOKEN, value=json.dumps(variable))
        self.logger.info(
            "Updated access_token to variable store expiring at %s",
            self.__get_access_token_expiry(),
        )

    def _refresh_access_token(self):
        self._read_access_token()
        if self.access_token_expiry > int(time.time()):
            return
        refresh_token = Secret(self.REFRESH_TOKEN).get()
        data = {"grant_type": "refresh_token",
                "refresh_token": refresh_token}
        endpoint = ("https://{}/v1/access_token".format(self.AUTH_ENDPOINT))
        r = requests.post(url=endpoint, json=data)
        self.logger.info(r)
        if r.status_code != 200:
            self.logger.error("Failed to get access_token %s", r.text)
            raise Exception("Error fetching access token")
        rj = r.json()
        self.access_token = rj["jwt_token"]
        self.access_token_expiry = (
            rj["expires_in"] + int(time.time()) - 100
        )
        self.request_headers = {
            "Authorization": "Bearer {}".format(self.access_token)
        }
        self.__write_access_token()

    def _create_job(self, cid, job, args, cluster, concurrent):
        self._refresh_access_token()
        self.logger.info("Launching Job: cid=%s, job_name=%s", cid, job)
        endpoint = (
            "https://{}/v1/cids/{}/jobs/{}/executions".format(
                self.LAKEVIEW_ENDPOINT, cid, job
            )
        )
        body = {}
        if args:
            body["parameters"] = args

        params = {}
        if cluster:
            params["cluster"] = cluster
        if concurrent:
            params["concurrent"] = "true"

        r = requests.post(
            url=endpoint, json=body, params=params, headers=self.request_headers
        )
        if r.status_code != 200:
            self.logger.error(
                "Failed to launch job - Response code: %s, Response: %s",
                r.status_code,
                r.text,
            )
            raise Exception("Failed to launch job")
        rj = r.json()
        if rj["status"] != "RUNNING":
            self.logger.error(
                "Failed to launch job - Response code: %s, Response: %s",
                r.status_code,
                r.text,
            )
            raise Exception(
                "Failed to launch job, expected status=RUNNING but found status="
                + rj["status"]
            )
        execution_id = rj["execution_id"].split("-")[-1]
        self.logger.info(
            "Launched Job: cid=%s, job_name=%s, execution_id=%s",
            cid,
            job,
            execution_id,
        )
        return execution_id

    def _get_job(self, cid, job, poll_interval):
        self.logger.info(
            "Polling enabled! with poll_interval=%d",
            poll_interval
        )
        status = "RUNNING"
        while status == "RUNNING":
            time.sleep(poll_interval)
            self._refresh_access_token()
            self.logger.info(
                "Getting Job status: cid=%s, job_name=%s, execution_id=%s",
                cid,
                job,
                self.execution_id,
            )
            endpoint = (
                "<https://lakeview.api.aqfer.net/v1/cids/{}/jobs/{}/runs/{}/status>".format(
                    cid, job, self.execution_id
                )
            )
            r = requests.get(url=endpoint, headers=self.request_headers)
            if r.status_code != 200:
                self.logger.error(
                    "Failed to get job status - Response code: %s, Response: %s",
                    r.status_code,
                    r.text,
                )
                raise Exception("Failed to get job status")
            rj = r.json()
            status = rj["status"].upper()
            self.logger.info(
                "Retrieved job status: cid=%s, job_name=%s, execution_id=%s, status=%s",
                cid,
                job,
                self.execution_id,
                status,
            )
        if status == "SUCCEEDED":
            pass
        else:
            raise Exception("Job run didn't succeed, status = " + status)

    def _get_metrics(self, cid, job):
        self._refresh_access_token()
        self.logger.info(
            "Getting Job metrics: cid=%s, job_name=%s, execution_id=%s",
            cid,
            job,
            self.execution_id,
        )
        endpoint = (
            "<https://lakeview.api.aqfer.net/v1/cids/{}/jobs/{}/runs/{}>".format(
                cid, job, self.execution_id
            )
        )
        r = requests.get(url=endpoint, headers=self.request_headers)
        if r.status_code != 200:
            self.logger.error(
                "Failed to get job status - Response code: %s, Response: %s",
                r.status_code,
                r.text,
            )
            raise Exception("Failed to get job status")
        rj = r.json()
        return rj

    @defaults_from_attrs("cid", "job", "cluster", "concurrent", "args", "poll_interval", "skip")
    def run(self, cid: str = None, job: str = None, cluster: str = None, concurrent: bool = False, args: dict = None, poll_interval: int = None, skip: bool = False):
        if skip:
            raise signals.SKIP()
        self.execution_id = self._create_job(
            cid, job, args, cluster, concurrent)
        self._get_job(cid, job, poll_interval)
        rj = self._get_metrics(cid, job)
        return rj
Here is the class
Should I rename args to something else?
a
That would be a good way to do it, since args typically refers to the additional positional arguments you can pass on class/task initialization:
Copy code
def __init__(self, *args, **kwargs):
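Applied to your class, the rename would just touch the stored attribute, the run signature, and the map calls - something like this (a sketch; job_args is only a placeholder name):
Copy code
# in LakeviewRunJobTask.__init__: self.job_args = job_args  (instead of self.args)

@defaults_from_attrs("cid", "job", "cluster", "concurrent", "job_args", "poll_interval", "skip")
def run(self, cid: str = None, job: str = None, cluster: str = None,
        concurrent: bool = False, job_args: dict = None,
        poll_interval: int = None, skip: bool = False):
    ...

# and in the flow:
t4 = collate_monthly_lotame.map(job_args=arg_map)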
s
Ok