Dnyaneshwar
07/09/2020, 2:52 PM
@task
def get_list_1():
    l = list()
    l.append((t1, t2))
    l.append((t3, t4))
    return l

@task
def function_1(arg1, arg2):
    # do something
    return arg2, False  # or True

# The other two tasks have a similar structure.

with Flow('test_flow') as flow:
    l_1 = get_list_1()
    res_1 = function_1.map(l_1)
    l_2 = get_list_2()
    get_list_2.set_upstream(function_1)
    res_2 = function_2.map(l_2)
I want to run this flow with a DaskExecutor, where the Dask cluster has been created with the YarnCluster API from dask.
The number of workers assigned to the scheduler will be limited.
However, the lists l_1 and l_2 might run into hundreds of tuples.
To avoid any worker being killed because of memory or other issues, I want to pass slices of l_1 (or l_2) to the functions instead of the full list.
How many slices there are will depend on the length of the output lists.
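A minimal sketch of this chunking idea, assuming Prefect 0.x's task/Flow API; the chunk task, its chunk_size, and the dask_yarn setup at the end are illustrative assumptions rather than code from this thread:

from dask_yarn import YarnCluster  # assumed, per the YarnCluster mention above
from prefect import Flow, task
from prefect.engine.executors import DaskExecutor


@task
def get_list_1():
    # stand-in for the real task; returns a long list of tuples
    return [(i, i + 1) for i in range(500)]


@task
def chunk(items, chunk_size=50):
    # split the full list into slices so each mapped run gets a bounded amount of data
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


@task
def function_1(pairs):
    # process one slice of tuples at a time
    return [(b, True) for _, b in pairs]


with Flow("test_flow_chunked") as flow:
    l_1 = get_list_1()
    chunks_1 = chunk(l_1)             # slicing happens inside a task, at runtime
    res_1 = function_1.map(chunks_1)  # each mapped run receives one slice, not the full list

cluster = YarnCluster()  # illustrative; configure environment and worker memory as needed
flow.run(executor=DaskExecutor(address=cluster.scheduler_address))

Because the chunk task runs at flow runtime, it can use len and slicing on the materialized list even though the task output is not a list at build time.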
Can anyone help with how to handle this? Or maybe there is a better solution than passing slices to the map. Thanks.
Dylan
07/09/2020, 3:01 PM
Dnyaneshwar
07/09/2020, 4:00 PM
I set debug=True in the DaskExecutor and that is not giving me any more output.
I can try storing the results locally. However, I don't see how to do the splits. It seems that the output of a task is not a list, and I am not able to use the usual len and range functions on it.
Can you tell me how I can process this?
Dylan
07/09/2020, 4:04 PM
Something like LIMIT 100 OFFSET 100 in SQL?
The output of a task is not a list. Prefect defers computation to runtime and then passes along Prefect-native objects. To inspect a task's output, you'll need to call the task.run() method directly or inspect the output inside of a task.
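A small illustration of that point, using a hypothetical get_list_1 task; inside the Flow block the call returns a Task object, while calling .run() executes the underlying function and returns the real list:

from prefect import Flow, task


@task
def get_list_1():
    # hypothetical stand-in for the real task
    return [(1, 2), (3, 4)]


with Flow("inspect_demo") as flow:
    l_1 = get_list_1()
    print(type(l_1))  # a Task object at flow-build time, not a list

# .run() executes the task's function directly and returns the actual list
values = get_list_1.run()
print(len(values), values[0])

Chunking with len or range therefore has to happen inside a task (or via task.run() during local debugging), where the actual list is available.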
Dnyaneshwar
07/09/2020, 4:19 PM
Dylan
07/09/2020, 4:20 PM
With a SELECT COUNT(*) query, you can also pre-prepare the chunks, pass a list of offsets, and map those to produce a list of references to cloud storage:

import json
from datetime import timedelta

import pandas as pd
import pendulum
import prefect
from google.cloud import storage as gcs
from prefect import task

@task(
name="Get New Record Counts",
checkpoint=True,
max_retries=5,
retry_delay=timedelta(minutes=3),
)
def get_new_record_counts(table_data, batch_size, pg_connection_string):
logger = prefect.context.get("logger")
# extract table name & last updated timestamp
document_id, table = table_data
table["document_id"] = document_id
if not table["last_updated"]:
table["last_updated"] = pendulum.now().subtract(years=20).to_datetime_string()
time_sort_field = table_time_sort_field(table["table"])
# get count of new records
result = pd.read_sql(
con=pg_connection_string,
sql=f"""
SELECT COUNT(*)
FROM {table["schema"]}."{table["table"]}"
WHERE {time_sort_field} > '{table["last_updated"]}';
""",
)
count = result["count"][0]
<http://logger.info|logger.info>(f"{count} new records")
offsets = list(range(0, count, batch_size))
table_offsets = []
for offset in offsets:
table_offset = table.copy()
table_offset["offset"] = offset
table_offsets.append(table_offset)
<http://logger.info|logger.info>(table_offsets)
# table_offsets = [{ doucument_id: "", table: "", schema: "", last_updated: "timestamp", offset: 0 }]
return table_offsets
@task(
    name="Flatten Table Offset Lists",
    checkpoint=True,
    max_retries=5,
    retry_delay=timedelta(minutes=1),
)
def flatten_table_offset_lists(table_offset_lists):
    logger = prefect.context.get("logger")
    logger.info(f"Table offset lists: {table_offset_lists}")
    flattened_table_offsets = []
    for table_offset_list in table_offset_lists:
        logger.info(f"Table offset list: {table_offset_list}")
        for table_offset in table_offset_list:
            flattened_table_offsets.append(table_offset)
    return flattened_table_offsets
@task(
    name="Extract from Postgres",
    checkpoint=True,
    max_retries=5,
    retry_delay=timedelta(minutes=1),
)
def extract_from_postgres(table_offset, batch_size, pg_connection_string):
    logger = prefect.context.get("logger")
    now = pendulum.now()
    logger.info(f"Table Data {table_offset}")
    last_updated = table_offset["last_updated"]
    offset = table_offset["offset"]
    schema = table_offset["schema"]
    table = table_offset["table"]
    time_sort_field = table_time_sort_field(table_offset["table"])
    logger.info(f"Last updated {last_updated}")
    logger.info(f"Time sort field {time_sort_field}")
    data_frame = pd.read_sql(
        con=pg_connection_string,
        sql=f"""
            SELECT *
            FROM {schema}."{table}"
            WHERE {time_sort_field} > '{last_updated}'
            ORDER BY {time_sort_field} ASC, id
            LIMIT {batch_size}
            OFFSET {offset};
        """,
    )
    logger.info(f"new batch retrieved, {data_frame.count()} rows")
    if "tenant_id" in data_frame.columns:
        table_offset["clustering_fields"] = ["tenant_id"]
    # determine max updated
    table_offset["new_max_updated"] = pendulum.instance(data_frame["updated"].max())
    logger.info(f"max updated {table_offset['new_max_updated']}")
    now = pendulum.now()
    blob_name = f"transfer_to_bigquery/{now.to_date_string()}_{now.to_time_string()}_{table_offset['schema']}_{table_offset['table']}_{table_offset['offset']}.json"
    logger.info(f"attempting to upload: {table_offset} \n {blob_name}")
    # specifically ensure that JSON fields are cast as strings
    json_columns = table_offset.get("json_columns")
    if json_columns:
        transform_config = {}
        for column in data_frame.columns:
            if column in json_columns:
                transform_config[column] = json.dumps
            else:
                transform_config[column] = lambda value: value
        data_frame = data_frame.transform(func=transform_config)
    # add the is_deleted column after the id field
    data_frame.insert(1, "is_deleted", False)
    json_string = data_frame.to_json(
        date_format="iso",
        compression=None,
        orient="records",
        default_handler=str,
        lines=True,
    )
    storage_client = gcs.Client(project="prefect-data-warehouse")
    bucket = storage_client.bucket("prefect_data_warehouse")
    blob = bucket.blob(blob_name=blob_name)
    blob.upload_from_string(json_string)
    gcs_uri = f"gs://prefect_data_warehouse/{blob_name}"
    logger.info(f"upload complete {gcs_uri}")
    table_offset["gcs_uri"] = gcs_uri
    return table_offset
Dnyaneshwar
07/09/2020, 4:24 PM
Dylan
07/09/2020, 4:25 PM
Here's the flow that maps over the Get Record Counts task:

with Flow(
name="ELT to Data Warehouse",
result_handler=GCSResultHandler(
bucket="data_warehouse_flow_result_handler",
credentials_secret="DATA_WAREHOUSE_GOOGLE_CLOUD_CREDENTIALS",
),
environment=environment,
storage=storage,
schedule=schedule,
state_handlers=[state_handler],
) as flow:
# setup
BATCH_COUNT = Parameter(name="batch_count", default=0)
BATCH_SIZE = Parameter(name="batch_size", default=10000)
BQ_DATASET = Parameter(name="bq_dataset", default="prefect_cloud__production")
TABLE_NAMES = Parameter(name="table_names", default=None, required=False)
FIRESTORE_COLLECTION = Parameter(
name="firestore_collection", default="tables_metadata"
)
PG_CONNECTION_STRING = PrefectSecret()
# get table metadata
table_meta_data = load_table_metadata(
table_names=TABLE_NAMES, firestore_collection=FIRESTORE_COLLECTION
)
# get counts
table_offset_lists = get_new_record_counts.map(
table_data=table_meta_data,
batch_size=unmapped(BATCH_SIZE),
batch_count=unmapped(BATCH_COUNT),
pg_connection_string=unmapped(PG_CONNECTION_STRING),
)
# flatten
flattened_table_offsets = flatten_table_offset_listst(
table_offset_lists=table_offset_lists
)
# map over each table and etl
table_offsets_and_uris = extract_from_postgres.map(
table_offset=flattened_table_offsets,
batch_size=unmapped(BATCH_SIZE),
pg_connection_string=unmapped(PG_CONNECTION_STRING),
)
## do other stuff with the lists
Dnyaneshwar
07/10/2020, 1:42 PM
I used the LocalResult class for storing the output from the task and it is working. This will store the rows coming from the DB in a cloudpickle file.
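For reference, a minimal sketch of attaching a LocalResult to a task in Prefect 0.x; the directory and the task bodies here are illustrative assumptions, not code from this thread:

from prefect import Flow, task
from prefect.engine.results import LocalResult


@task(checkpoint=True, result=LocalResult(dir="/tmp/prefect-results"))  # hypothetical dir
def get_rows():
    # stand-in for the real DB query; the return value is pickled to the result dir
    return [(1, "a"), (2, "b")]


@task
def process(pair):
    # stand-in downstream task
    return pair[0] * 2


with Flow("local_result_demo") as flow:
    rows = get_rows()
    doubled = process.map(rows)  # mapped over the upstream task's output as usual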
Will the map for the next function take input from this file automatically, or do we have to pass it explicitly?
Dylan
07/10/2020, 3:38 PM