Hello! We are currently using the prefect-ray integration…
# ask-community
a
Hello! We are currently using the prefect-ray integration to connect Prefect and Ray (ray.io). When creating a flow of flows, the task runner does not seem to be passed down from the parent flow to its child flows (subflows): if my parent flow is defined with the Ray task runner, my child flows don’t run on Ray but inside the Prefect Kubernetes (k8s) pod instead. Am I doing something wrong here?
Here’s an example:
import asyncio
from typing import List, Optional

from prefect import flow, task, get_run_logger
from prefect_ray.task_runners import RayTaskRunner

from hypo.flows.examples import Num, echo
from hypo.platform.prefect.task_runners import get_ray_runner

def get_ray_runner(address: str = "ray://ray.mls.gh.st:10001",
                   working_dir: str = ".",
                   extra_pip: Optional[List[str]] = None):
    """Build a ``RayTaskRunner`` configured with a Ray runtime environment.

    NOTE: this local definition shadows the ``get_ray_runner`` imported from
    ``hypo.platform.prefect.task_runners`` at the top of the file.

    Args:
        address: Ray cluster address passed to ``RayTaskRunner`` (a Ray Client
            ``ray://host:port`` URL, without any surrounding ``<...>``).
        working_dir: directory used as the runtime environment's
            ``working_dir``.
        extra_pip: pip packages for the runtime environment's ``pip`` entry;
            ``None`` (the default) means an empty list.

    Returns:
        A ``RayTaskRunner`` whose ``init_kwargs`` carry the runtime environment.
    """
    # None sentinel instead of a shared mutable default list.
    if extra_pip is None:
        extra_pip = []
    return RayTaskRunner(
        address=address,
        init_kwargs={
            "runtime_env": {
                "working_dir": working_dir,
                "pip": extra_pip,
            },
        },
    )

@task
def mult_even(num: Num):
    """Return twice the even value held by ``num``.

    Args:
        num: object whose ``val`` attribute holds the number (``Num`` comes
            from ``hypo.flows.examples``; presumably an int — confirm).

    Returns:
        ``num.val * 2``.

    Raises:
        AssertionError: if ``num.val`` is odd (only while asserts are enabled).
    """
    logger = get_run_logger()
    logger.info(f"Returning num: {num.val}")
    # Sanity check only: `assert` is stripped under `python -O`.
    assert num.val % 2 == 0
    return num.val * 2

@task
def mult_odd(num: Num):
    """Return twice the odd value held by ``num``.

    Args:
        num: object whose ``val`` attribute holds the number (``Num`` comes
            from ``hypo.flows.examples``; presumably an int — confirm).

    Returns:
        ``num.val * 2``.

    Raises:
        AssertionError: if ``num.val`` is even (only while asserts are enabled).
    """
    logger = get_run_logger()
    logger.info(f"Returning num: {num.val}")
    # Sanity check only: `assert` is stripped under `python -O`.
    assert num.val % 2 == 1
    return num.val * 2

# The task runner is configured per flow: a subflow does not pick up the runner
# of the flow that calls it, so without this argument the tasks below would run
# on the default runner in the local process rather than on Ray.
# NOTE(review): this subflow runs concurrently with mult2_odd (and inside
# mult_add, which has its own Ray runner), so several Ray connections may be
# opened from one process — confirm the Ray client setup allows that.
@flow(task_runner=get_ray_runner())
async def mult2_even(f: int = 10):
    """Double every even number in ``2..f`` via tasks and return the results.

    Args:
        f: inclusive upper bound of the even range ``range(2, f + 1, 2)``.

    Returns:
        List of the values produced by ``mult_even`` for each mapped input.
    """
    logger = get_run_logger()
    # NOTE(review): in an async flow some Prefect 2 versions return awaitables
    # from `.map()` / `.result()`; confirm no `await` is needed here.
    nums = echo.map(range(2, f + 1, 2))
    result = [r.result() for r in mult_even.map(nums)]
    logger.info(f"even results: {result}")
    return result

# The task runner is configured per flow: a subflow does not pick up the runner
# of the flow that calls it, so without this argument the tasks below would run
# on the default runner in the local process rather than on Ray.
# NOTE(review): this subflow runs concurrently with mult2_even (and inside
# mult_add, which has its own Ray runner), so several Ray connections may be
# opened from one process — confirm the Ray client setup allows that.
@flow(task_runner=get_ray_runner())
async def mult2_odd(f: int = 10):
    """Double every odd number in ``1..f`` via tasks and return the results.

    Args:
        f: inclusive upper bound of the odd range ``range(1, f + 1, 2)``.

    Returns:
        List of the values produced by ``mult_odd`` for each mapped input.
    """
    logger = get_run_logger()
    # NOTE(review): in an async flow some Prefect 2 versions return awaitables
    # from `.map()` / `.result()`; confirm no `await` is needed here.
    nums = echo.map(range(1, f + 1, 2))
    result = [r.result() for r in mult_odd.map(nums)]
    logger.info(f"odd results: {result}")
    return result

@flow(
    task_runner=get_ray_runner(),
    result_storage="s3/s3-prefect-result-storage",
)
async def mult_add(f: int = 10):
    """Run the even and odd subflows concurrently and log the sum of their outputs.

    This flow submits no tasks itself — it only calls the subflows — so its
    own ``task_runner`` is not the one that executes the even/odd tasks; each
    subflow uses the runner declared on its own ``@flow`` decorator.

    Args:
        f: inclusive upper bound forwarded to both subflows.
    """
    logger = get_run_logger()
    even_results, odd_results = await asyncio.gather(
        mult2_even(f),
        mult2_odd(f),
    )
    # Concatenate into a new list instead of mutating the even subflow's
    # result in place; summation order matches even-then-odd.
    answer = sum(even_results + odd_results)
    logger.info(f"Answer: {answer}")

if __name__ == "__main__":
    # Script entry point: build the parent flow's coroutine and drive it to
    # completion on a fresh event loop.
    parent_run = mult_add()
    asyncio.run(parent_run)