# prefect-community
Dnyaneshwar:
Hello, I have a flow which looks like below:
```python
@task
def get_list_1():
    # t1..t4 stand in for values computed elsewhere
    l = list()
    l.append((t1, t2))
    l.append((t3, t4))
    return l

@task
def function_1(arg1, arg2):
    # do something
    return arg2, True  # or False

# The other two tasks (get_list_2, function_2) have a similar structure.

with Flow('test_flow') as flow:
    l_1 = get_list_1()
    res_1 = function_1.map(l_1)

    l_2 = get_list_2()
    get_list_2.set_upstream(function_1)

    res_2 = function_2.map(l_2)
```
I want to run this flow with a DaskExecutor, where the Dask cluster has been created with the YarnCluster API from Dask. The number of workers assigned to the scheduler will be limited; however, the lists l_1 and l_2 might run into hundreds of tuples. To avoid any worker being killed because of memory or other issues, I want to pass slices of l_1 (or l_2) to the functions instead of the full list. How many slices will depend on the length of the output lists. Can anyone help with how to handle this? Or maybe there is a better solution than passing slices to the map. Thanks.
Dylan:
Hi @Dnyaneshwar! I'll say from my own experience running a similar flow on Kubernetes that I don't generally run into memory issues when passing around large numbers of items like that. Prefect is pretty smart about working with Dask to pass around task inputs as needed. I'd give it a shot with a naive implementation first and see if you experience memory issues. If you do, you might try pulling parts of that list at a time, storing them in some sort of cloud storage (S3 or GCS), and then returning a list of references to slices of the list. That way your Dask workers only pull the items they're actively working on into memory.
I tend to use buckets that store objects for a week (just in case) for this purpose and it doesn’t cost much
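To make that concrete, here's a rough sketch of the pattern, assuming Prefect 1.x and the google-cloud-storage client; the bucket name, chunk size, and the `get_rows`/`upload_chunks`/`process_chunk` tasks are illustrative, not from this thread:
```python
import json

from google.cloud import storage as gcs
from prefect import task, Flow


@task
def get_rows():
    # stand-in for the task that pulls hundreds of rows from the DB
    return [{"id": i} for i in range(1000)]


@task
def upload_chunks(rows, chunk_size=100, bucket_name="my-scratch-bucket"):
    # split the full list into chunks, upload each one, and return only the blob names
    client = gcs.Client()
    bucket = client.bucket(bucket_name)
    blob_names = []
    for i in range(0, len(rows), chunk_size):
        name = f"chunks/chunk_{i}.json"
        bucket.blob(name).upload_from_string(json.dumps(rows[i:i + chunk_size]))
        blob_names.append(name)
    return blob_names


@task
def process_chunk(blob_name, bucket_name="my-scratch-bucket"):
    # each mapped task run pulls only its own chunk into worker memory
    client = gcs.Client()
    chunk = json.loads(client.bucket(bucket_name).blob(blob_name).download_as_string())
    # do something with the ~100 rows in `chunk`
    return len(chunk)


with Flow("chunked-flow") as flow:
    blob_names = upload_chunks(get_rows())
    results = process_chunk.map(blob_names)
```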
Dnyaneshwar:
Hi @Dylan, thanks for the quick response. I tried the naive implementation and it throws a KilledWorker error. I tried adding `debug=True` to the DaskExecutor and that isn't giving me any more output. I can try storing the results locally; however, I don't see how to do the splits. It seems that the output of a `task` is not a list, so I can't use the usual `len` and `range` functions on it. Can you tell me how I can process this?
Dylan:
Hey Dnyaneshwar, where are you retrieving the initial list from?
Is it possible to only retrieve chunks of the list at a time, such as with `LIMIT 100 OFFSET 100` in SQL?
Storing the result locally while using Dask workers is a tricky proposition because you're not always sure whether a worker will have access to the same local file storage, which is why I suggested a cloud-based solution.
The output of a `task` is not a list. Prefect defers computation to runtime and then passes along Prefect-native objects. To inspect a task's output, you'll need to call the `task.run()` method directly or inspect the output inside of a task.
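For example (a sketch assuming Prefect 1.x; `make_chunks` and the sample data are made up): calling `.run()` outside a flow returns the real list, while inside the `with Flow` block you only get a `Task` object, so any `len`/`range` logic has to live inside a task that runs on the real data at runtime.
```python
from prefect import task, Flow


@task
def get_list_1():
    return [("a", 1), ("b", 2), ("c", 3)]


@task
def make_chunks(items, chunk_size=2):
    # len()/range() work here because this body executes at flow runtime, on the real list
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


# inspect a task's real output while developing, outside of any Flow context
print(get_list_1.run())  # [('a', 1), ('b', 2), ('c', 3)]

with Flow("chunking-example") as flow:
    l_1 = get_list_1()         # at flow-build time this is a Task object, not a list
    chunks = make_chunks(l_1)  # so the chunking is deferred into a task
```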
Dnyaneshwar:
Hi @Dylan, the initial list is coming from a DB; it stores a set of rows. The second list is similar in nature. I think I will try to form the chunks as you suggested, by limiting the number of rows in the task and returning a list of chunks, which will be easier to iterate over in the flow.
Dylan:
If it's possible to use a `SELECT COUNT(*)` query, you can also pre-prepare the chunks, pass a list of offsets, and map those to produce a list of references to cloud storage.
Something like:
```python
# imports for this excerpt; table_time_sort_field is a helper defined elsewhere in the same module
import json
from datetime import timedelta

import pandas as pd
import pendulum
import prefect
from google.cloud import storage as gcs
from prefect import task

@task(
    name="Get New Record Counts",
    checkpoint=True,
    max_retries=5,
    retry_delay=timedelta(minutes=3),
)
def get_new_record_counts(table_data, batch_size, batch_count, pg_connection_string):
    logger = prefect.context.get("logger")

    # extract table name & last updated timestamp
    document_id, table = table_data
    table["document_id"] = document_id

    if not table["last_updated"]:
        table["last_updated"] = pendulum.now().subtract(years=20).to_datetime_string()

    time_sort_field = table_time_sort_field(table["table"])

    # get count of new records
    result = pd.read_sql(
        con=pg_connection_string,
        sql=f"""
        SELECT COUNT(*)
        FROM {table["schema"]}."{table["table"]}"
        WHERE {time_sort_field} > '{table["last_updated"]}';
        """,
    )

    count = result["count"][0]
    logger.info(f"{count} new records")

    offsets = list(range(0, count, batch_size))
    table_offsets = []

    for offset in offsets:
        table_offset = table.copy()
        table_offset["offset"] = offset
        table_offsets.append(table_offset)

    logger.info(table_offsets)
    # table_offsets = [{ document_id: "", table: "", schema: "", last_updated: "timestamp", offset: 0 }]
    return table_offsets


@task(
    name="Flatten Table Offset Lists",
    checkpoint=True,
    max_retries=5,
    retry_delay=timedelta(minutes=1),
)
def flatten_table_offset_lists(table_offset_lists):
    logger = prefect.context.get("logger")
    logger.info(f"Table offset lists: {table_offset_lists}")
    flattened_table_offsets = []

    for table_offset_list in table_offset_lists:
        logger.info(f"Table offset list: {table_offset_list}")
        for table_offset in table_offset_list:
            flattened_table_offsets.append(table_offset)

    return flattened_table_offsets


@task(
    name="Extract from Postgres",
    checkpoint=True,
    max_retries=5,
    retry_delay=timedelta(minutes=1),
)
def extract_from_postgres(table_offset, batch_size, pg_connection_string):
    logger = prefect.context.get("logger")
    now = pendulum.now()

    logger.info(f"Table Data {table_offset}")

    last_updated = table_offset["last_updated"]
    offset = table_offset["offset"]
    schema = table_offset["schema"]
    table = table_offset["table"]
    time_sort_field = table_time_sort_field(table_offset["table"])

    logger.info(f"Last updated {last_updated}")
    logger.info(f"Time sort field {time_sort_field}")

    data_frame = pd.read_sql(
        con=pg_connection_string,
        sql=f"""
        SELECT *
        FROM {schema}."{table}"
        WHERE {time_sort_field} > '{last_updated}'
        ORDER BY {time_sort_field} ASC, id
        LIMIT {batch_size}
        OFFSET {offset};
        """,
    )
    logger.info(f"new batch retrieved, {data_frame.count()} rows")

    if "tenant_id" in data_frame.columns:
        table_offset["clustering_fields"] = ["tenant_id"]

    # determine max updated
    table_offset["new_max_updated"] = pendulum.instance(data_frame["updated"].max())
    logger.info(f"max updated {table_offset['new_max_updated']}")

    now = pendulum.now()
    blob_name = f"transfer_to_bigquery/{now.to_date_string()}_{now.to_time_string()}_{table_offset['schema']}_{table_offset['table']}_{table_offset['offset']}.json"

    logger.info(f"attempting to upload: {table_offset} \n {blob_name}")

    # specifically ensure that JSON fields are cast as strings
    json_columns = table_offset.get("json_columns")

    if json_columns:
        transform_config = {}

        for column in data_frame.columns:
            if column in json_columns:
                transform_config[column] = json.dumps
            else:
                transform_config[column] = lambda value: value

        data_frame = data_frame.transform(func=transform_config)

    # add the is_deleted column after the id field
    data_frame.insert(1, "is_deleted", False)

    json_string = data_frame.to_json(
        date_format="iso",
        compression=None,
        orient="records",
        default_handler=str,
        lines=True,
    )

    storage_client = gcs.Client(project="prefect-data-warehouse")
    bucket = storage_client.bucket("prefect_data_warehouse")
    blob = bucket.blob(blob_name=blob_name)
    blob.upload_from_string(json_string)

    gcs_uri = f"<gs://prefect_data_warehouse/{blob_name}>"

    <http://logger.info|logger.info>(f"upload complete {gcs_uri}")

    table_offset["gcs_uri"] = gcs_uri

    return table_offset
```
Which I freely admit is ripped almost wholesale from one of my flows haha
Dnyaneshwar:
Thanks. 🙂 I will go through this. Hope this works!
Dylan:
I store table metadata in a Firestore collection and extract that first, which I pass to the `Get Record Counts` task, and pull based on the timestamp in the configuration.
Hope this helps!
For more context:
```python
# imports for this excerpt; environment, storage, schedule, state_handler, and
# load_table_metadata are defined elsewhere in the same project
from prefect import Flow, Parameter, unmapped
from prefect.engine.result_handlers import GCSResultHandler
from prefect.tasks.secrets import PrefectSecret

with Flow(
    name="ELT to Data Warehouse",
    result_handler=GCSResultHandler(
        bucket="data_warehouse_flow_result_handler",
        credentials_secret="DATA_WAREHOUSE_GOOGLE_CLOUD_CREDENTIALS",
    ),
    environment=environment,
    storage=storage,
    schedule=schedule,
    state_handlers=[state_handler],
) as flow:
    # setup
    BATCH_COUNT = Parameter(name="batch_count", default=0)
    BATCH_SIZE = Parameter(name="batch_size", default=10000)
    BQ_DATASET = Parameter(name="bq_dataset", default="prefect_cloud__production")
    TABLE_NAMES = Parameter(name="table_names", default=None, required=False)
    FIRESTORE_COLLECTION = Parameter(
        name="firestore_collection", default="tables_metadata"
    )

    PG_CONNECTION_STRING = PrefectSecret()

    # get table metadata
    table_meta_data = load_table_metadata(
        table_names=TABLE_NAMES, firestore_collection=FIRESTORE_COLLECTION
    )
    # get counts
    table_offset_lists = get_new_record_counts.map(
        table_data=table_meta_data,
        batch_size=unmapped(BATCH_SIZE),
        batch_count=unmapped(BATCH_COUNT),
        pg_connection_string=unmapped(PG_CONNECTION_STRING),
    )
    # flatten
    flattened_table_offsets = flatten_table_offset_lists(
        table_offset_lists=table_offset_lists
    )

    # map over each table and etl
    table_offsets_and_uris = extract_from_postgres.map(
        table_offset=flattened_table_offsets,
        batch_size=unmapped(BATCH_SIZE),
        pg_connection_string=unmapped(PG_CONNECTION_STRING),
    )

    ## do other stuff with the lists
```
Dnyaneshwar:
Hi @Dylan, I tried the `LocalResult` class for storing the output from the task and it is working. This will store the rows coming from the DB in a cloudpickle file. Will the `map` for the next function take its input from this file automatically, or do we have to pass it explicitly?
Dylan:
@Dnyaneshwar the next task should take that input automatically 👍
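For reference, a minimal sketch of that setup, assuming Prefect 1.x; the result directory and task names here are illustrative:
```python
from prefect import task, Flow
from prefect.engine.results import LocalResult


# checkpoint this task's return value to local disk
# (when running with Prefect Core alone, also set PREFECT__FLOWS__CHECKPOINTING=true)
@task(checkpoint=True, result=LocalResult(dir="/tmp/prefect-results"))
def get_rows():
    return [{"id": i} for i in range(500)]


@task
def process_row(row):
    return row["id"] * 2


with Flow("local-result-example") as flow:
    rows = get_rows()                # persisted to /tmp/prefect-results
    doubled = process_row.map(rows)  # the downstream map still receives the data automatically
```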