# Abhinav Chordia, 05/08/2023, 4:11 PM
import asyncio
from typing import List, Optional

from prefect import flow, task, get_run_logger
from prefect_ray.task_runners import RayTaskRunner

from hypo.flows.examples import Num, echo
from hypo.platform.prefect.task_runners import get_ray_runner
def get_ray_runner(
    address: str = "ray://ray.mls.gh.st:10001",
    working_dir: str = ".",
    extra_pip: Optional[List[str]] = None,
) -> "RayTaskRunner":
    """Build a ``RayTaskRunner`` configured with a Ray ``runtime_env``.

    This local definition shadows the ``get_ray_runner`` imported from
    ``hypo.platform.prefect.task_runners`` at the top of the file.

    Args:
        address: Ray cluster address handed to ``RayTaskRunner``. The
            ``ray://`` scheme is Ray Client's remote-cluster address form.
        working_dir: Value for the ``working_dir`` key of the Ray
            ``runtime_env``.
        extra_pip: Pip requirements for the ``pip`` key of the Ray
            ``runtime_env``. ``None`` (the default) means an empty list.

    Returns:
        A ``RayTaskRunner`` whose ``init_kwargs`` carry the ``runtime_env``
        (presumably forwarded to ``ray.init`` by prefect-ray; confirm).
    """
    # Build a fresh list on every call. A ``[]`` default would be a single
    # list object shared by every runner this function builds. Copying the
    # caller's list also keeps later mutations on their side out of the env.
    pip_packages = [] if extra_pip is None else list(extra_pip)
    return RayTaskRunner(
        address=address,
        init_kwargs={
            "runtime_env": {
                "working_dir": working_dir,
                "pip": pip_packages,
            }
        },
    )
@task
def mult_even(num: Num):
    """Prefect task that doubles ``num.val``, which must be even.

    Args:
        num: Object whose ``val`` attribute is the number to double
            (presumably an int; confirm against ``hypo.flows.examples.Num``).

    Returns:
        ``num.val * 2``.

    Raises:
        AssertionError: If ``num.val`` is not even. This is a bare
            ``assert``, so the check is skipped under ``python -O``.
    """
    logger = get_run_logger()
    # The message logs the input value; the task returns twice that value.
    logger.info(f"Returning num: {num.val}")
    assert num.val % 2 == 0, f"mult_even expected an even value, got {num.val}"
    return num.val * 2
@task
def mult_odd(num: Num):
    """Prefect task that doubles ``num.val``, which must be odd.

    Args:
        num: Object whose ``val`` attribute is the number to double
            (presumably an int; confirm against ``hypo.flows.examples.Num``).

    Returns:
        ``num.val * 2``.

    Raises:
        AssertionError: If ``num.val % 2`` is not 1. This is a bare
            ``assert``, so the check is skipped under ``python -O``.
    """
    logger = get_run_logger()
    # The message logs the input value; the task returns twice that value.
    logger.info(f"Returning num: {num.val}")
    assert num.val % 2 == 1, f"mult_odd expected an odd value, got {num.val}"
    return num.val * 2
@flow()
async def mult2_even(f: int = 10):
    """Subflow that runs ``mult_even`` on the even numbers from 2 to ``f``.

    ``echo`` is mapped over ``range(2, f + 1, 2)``. Its outputs (presumably
    ``Num`` wrappers; confirm in ``hypo.flows.examples``) are then mapped
    through ``mult_even``.

    Args:
        f: Inclusive upper bound of the even range.

    Returns:
        List of ``mult_even`` results, in the order of the mapped inputs.
    """
    logger = get_run_logger()
    # NOTE(review): this @flow sets no task_runner, so these mapped tasks
    # presumably run on Prefect's default runner, not the Ray runner set on
    # mult_add. Confirm this is intended.
    nums = echo.map(range(2, f + 1, 2))
    result = [future.result() for future in mult_even.map(nums)]
    logger.info(f"even results: {result}")
    return result
@flow()
async def mult2_odd(f: int = 10):
    """Subflow that runs ``mult_odd`` on the odd numbers from 1 to ``f``.

    ``echo`` is mapped over ``range(1, f + 1, 2)``. Its outputs (presumably
    ``Num`` wrappers; confirm in ``hypo.flows.examples``) are then mapped
    through ``mult_odd``.

    Args:
        f: Inclusive upper bound of the odd range.

    Returns:
        List of ``mult_odd`` results, in the order of the mapped inputs.
    """
    logger = get_run_logger()
    # NOTE(review): this @flow sets no task_runner, so these mapped tasks
    # presumably run on Prefect's default runner, not the Ray runner set on
    # mult_add. Confirm this is intended.
    nums = echo.map(range(1, f + 1, 2))
    result = [future.result() for future in mult_odd.map(nums)]
    logger.info(f"odd results: {result}")
    return result
@flow(
    task_runner=get_ray_runner(),
    result_storage="s3/s3-prefect-result-storage",
)
async def mult_add(f: int = 10):
    """Parent flow: run the even and odd subflows concurrently, then sum.

    Args:
        f: Upper bound forwarded to both ``mult2_even`` and ``mult2_odd``.

    Returns:
        The sum of every value returned by both subflows.
    """
    logger = get_run_logger()
    # NOTE(review): this flow submits no tasks itself. The Ray task runner in
    # the decorator is therefore presumably not used by the tasks mapped inside
    # the subflows. Confirm whether the subflows need their own Ray runner.
    even_results, odd_results = await asyncio.gather(
        mult2_even(f),
        mult2_odd(f),
    )
    # Sum each list separately instead of extending one with the other, so
    # neither subflow's returned list is mutated.
    answer = sum(even_results) + sum(odd_results)
    logger.info(f"Answer: {answer}")
    return answer
if __name__ == "__main__":
    # Script entry point: build the parent flow's coroutine with the default
    # f=10 and drive it on a fresh event loop.
    parent_run = mult_add()
    asyncio.run(parent_run)