    Theo Platt

    10 months ago
    I'm seeing some odd behavior with mapped tasks that call AWS Batch jobs. Essentially I have one mapped task that kicks off N batch jobs and returns a list of batch job ids. The next mapped task takes those job ids as input and runs AWSClientWait to monitor the status of each of those jobs. Here's the code for that mapped task -
    import prefect
    from prefect import task
    from prefect.tasks.aws.client_waiter import AWSClientWait

    @task
    def wait_batch(job_id, delay, max_attempts):

        logger = prefect.context.get('logger')

        logger.info(f"Waiting for job to complete: {job_id}")

        waiter = AWSClientWait(
            client='batch',
            waiter_name='JobComplete',
        )
        waiter.run(
            waiter_kwargs={
                'jobs': [job_id],
                'WaiterConfig': {
                    'Delay': delay,
                    'MaxAttempts': max_attempts
                }
            },
        )

        logger.info(f"Job complete: {job_id}")

        return job_id
    But what we sometimes see is one or more batch jobs failing, which then somehow stops the other mapped tasks from returning from this AWSClientWait call... and so the mapped task keeps running even though all the jobs have either failed or completed. Any ideas?
    Kevin Kho

    10 months ago
    Hey @Theo Platt, what is the behavior you are seeing?
    Theo Platt

    10 months ago
    (hit return too soon - edited post)
    Kevin Kho

    10 months ago
    Do you call the AWS Batch jobs with the Prefect task?
    It looks good so far, what do your mapped calls look like?
    Theo Platt

    10 months ago
    Thanks @Kevin Kho. Here's a reduced/redacted version of how we call the batch tasks (as a map)
    import prefect
    from prefect import task
    from prefect.tasks.aws.batch import BatchSubmit

    @task(name="My Batch call")
    def run_batch(data):
        logger = prefect.context.get('logger')

        logger.info(f'{data=}')

        batchjob = BatchSubmit(job_name='myjob_batch', job_definition='myjob_batch', job_queue='prefect_efs_queue_spot')
        job_id = batchjob.run(batch_kwargs={
            'containerOverrides': {
                'environment': [
                    {
                        'name': 'data',
                        'value': data
                    }
                ]
            }
        })

        return job_id
    Kevin Kho

    10 months ago
    I mean in the flow block. I'm checking whether you used unmapped.
    Theo Platt

    10 months ago
    Ah right - ok - here's a redacted version of that (hopefully no errors)
    #
    # Kick off all the jobs in parallel
    #
    batch_job_ids = run_batch.map(
        data=somedata
    )
    
    #
    # Poll for the completion of all jobs
    #
    exit_codes = wait_batch.map(
        batch_job_ids,
        delay=unmapped(5),
        max_attempts=unmapped(1000)
    )
    exit_codes.set_upstream(unmapped(batch_job_ids))
    Kevin Kho

    10 months ago
    What is your executor?
    Theo Platt

    10 months ago
    Local Dask
    executor=LocalDaskExecutor(scheduler="processes")
    (which is a problem I need to solve, as the instance running the flow has four cores, so my waiting step is only really looking at 4 jobs at a time)
    Ah... could that be why? Maybe by the time it gets to check the last four, the Batch jobs have long disappeared and so AWSClientWait has nothing to query?
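    Since each wait is I/O-bound (it just polls AWS), one way to lift that four-at-a-time ceiling would be a threaded executor with an explicit worker count - a minimal sketch, with the num_workers value purely illustrative:
    from prefect.executors import LocalDaskExecutor

    # Threads suit a polling step: the mapped tasks spend their time waiting
    # on AWS, so they can safely outnumber the CPU cores.
    executor = LocalDaskExecutor(scheduler="threads", num_workers=50)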
    Kevin Kho

    10 months ago
    Everything really looks good to me, and I looked at the task code as well. The task is written in a thread-safe way, and the boto client seems to be created on the fly, so clients aren’t reused between batches. My only thought is maybe using waiter_definition instead of waiter_kwargs.
    Your thought is a good one to explore. Is doing 4 waits stable?
    Theo Platt

    10 months ago
    I only see the problem when some of the Batch jobs have failed, but it may have happened in larger runs - often it will be 50+ parallel Batch jobs. Do you know what would happen if AWSClientWait got called with a job id that is no longer running?
    Kevin Kho

    10 months ago
    That I do not know
    I guess you can try right away, right? Just call AWSClientWait on a job id you have that is done?
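    A quick sketch of that experiment (the job id is a placeholder for a real, already-finished one):
    from prefect.tasks.aws.client_waiter import AWSClientWait

    # Point the waiter at a job that has already completed and see
    # whether it returns immediately or errors out.
    waiter = AWSClientWait(client='batch', waiter_name='JobComplete')
    waiter.run(
        waiter_kwargs={
            'jobs': ['<job-id-that-already-finished>'],
            'WaiterConfig': {'Delay': 5, 'MaxAttempts': 2},
        },
    )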
    Theo Platt

    10 months ago
    will do
    Or... I guess it would work if I didn't do a mapped task for the waiting and just passed a list of job ids to waiter_kwargs, as below:
    waiter = AWSClientWait(
        client='batch',
        waiter_name='JobComplete',
    )
    waiter.run(
        waiter_kwargs={
            'jobs': list_of_job_ids,
            'WaiterConfig': {
                'Delay': delay,
                'MaxAttempts': max_attempts
            }
        },
    )
    @Kevin Kho to close this thread - that last idea above of passing through a list of job ids works and solves my problem. However, it has a limit of 100 jobs it will wait for, even though in Batch you can queue up thousands and it will process them as resources allow, depending on the maximum you give the compute environment.
    Kevin Kho

    10 months ago
    Ah I see. Thanks for circling back on this
    Theo Platt

    9 months ago
    @Kevin Kho and others, in case you run into the same problem, here's a fix we came up with. It's not perfect, but it solves the problem of waiting for more than 100 batch jobs: we chunk the job ids into groups of 100 and wait for each group in turn in a for loop. That's adequate because we have to wait until the very last job finishes before moving on, and it doesn't matter whether that job is in the first 100 or the last 100.
    import prefect
    from prefect import task
    from prefect.tasks.aws.client_waiter import AWSClientWait

    @task
    def wait_batches(job_ids, delay, max_attempts):

        logger = prefect.context.get('logger')

        if len(job_ids) > 0:
            logger.info(f"Waiting for job(s) to complete: {job_ids}")

            waiter = AWSClientWait(
                client='batch',
                waiter_name='JobComplete',
            )

            # AWS imposes a 100-job-id limit on batch.describe_jobs() in boto3,
            # which is what the JobComplete waiter polls under the hood.
            aws_waiter_limit: int = 100
            job_id_tranches = [job_ids[pos:pos + aws_waiter_limit] for pos in range(0, len(job_ids), aws_waiter_limit)]
            for tranche in job_id_tranches:

                logger.info(f"Tranche: {tranche[0]}, {tranche[-1]}, {len(tranche)}")
                waiter.run(
                    waiter_kwargs={
                        'jobs': tranche,
                        'WaiterConfig': {
                            'Delay': delay,
                            'MaxAttempts': max_attempts
                        }
                    },
                )

        return
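    In the flow block, the mapped wait then becomes a single reduce-style call over the whole list - a sketch based on the earlier flow snippet (unmapped is no longer needed since nothing is mapped):
    # Kick off all the jobs in parallel, as before
    batch_job_ids = run_batch.map(data=somedata)

    # One task waits on the full list of job ids,
    # chunking into tranches of 100 internally
    wait_batches(job_ids=batch_job_ids, delay=5, max_attempts=1000)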
    Kevin Kho

    7 months ago
    @Marvin archive “AWS Batch and Wait”