Hey y'all, I can't seem to override the container ...
# ask-community
d
Hey y'all, I can't seem to override the container definition in a work pool. I'm trying to require a GPU on the ECS task definition. I can get it to work correctly if I manually create a task definition, but it looks like somehow this setting doesn't make it to task definitions created by Prefect. I traced through the task definition code on GitHub and it isn't obvious to me why this setting would not be respected. Any thoughts?
q
@Dev Dabke Hi Dev, did you manage to figure this out? I'm trying to do a similar set-up too
d
Yeah we ended up with a custom task definition entirely!
Much better for CPUs or GPUs ultimately
q
Interesting! Do you mind sharing with me your task definition?
d
Yeah are you set up in Prefect Cloud with a hybrid ECS queue? Do you already have an ASG setup?
q
Hey Dev, yup, I'm using Prefect Cloud and self-hosting a Prefect ECS worker. I managed to create a prefect.yaml and a deployment run that works using FARGATE and a CPU workload. Currently, I've set up the Capacity Provider with an ASG and a GPU instance. It works when I manually create a task definition and trigger the task run using this capacity provider in AWS (without Prefect). However, with Prefect, I'm still trying to find out how to define the capacity_provider_strategy and resource_requirement in the prefect.yaml. In your case, did you pass in the task_definition_arn instead of defining the resource_requirement in prefect.yaml? How did you configure capacity_provider_strategy in the prefect.yaml?
d
Ahhh okay, we're not using the yaml.
We have a custom function to create ECS Task Definitions. AwsEcsMachine is just a dataclass that holds some basic information about the machine, like memory, CPUs, etc.
"""ECS task definitions for CPU (Fargate) and GPU machines."""

from .aws_ecs_machine import AwsEcsMachine

MAX_MEMORY = 61440
MEMORY_OVERHEAD = 1024
ROLE_NAME = "PrefectEcsTaskExecutionRole"


def get_cpu_task_dict(
    flow_name: str, commit: str, machine: AwsEcsMachine, is_main_line: bool
):
    """
    Get the ECS task definition for CPU machines.

    Args:
        flow_name: the name of the flow.
        commit: the commit hash to use for the image.
        machine: the machine configuration to use.
        is_main_line: whether the flow is the main line or not.
    """
    main_display = "main" if is_main_line else "feat"
    log_prefix = f"prefect-cpu-{main_display}__{flow_name}__{commit}"
    task_dict: dict[
        str,
        str
        | list[
            dict[
                str,
                str
                | int
                | list[str | int]
                | list[dict[str, str] | dict[str, str | dict[str, str]]],
            ]
        ],
    ] = {
        "family": f"prefect-cpu-{main_display}__{flow_name}",
        "containerDefinitions": [
            {
                "name": "prefect",
                "image": f"{machine.image.to_ecr()}:{commit}",
                "cpu": 0,
                "portMappings": [],
                "essential": True,
                "environment": [],
                "mountPoints": [],
                "volumesFrom": [],
                "logConfiguration": {
                    "logDriver": "awslogs",
                    "options": {
                        "awslogs-group": "prefect",
                        "awslogs-create-group": "true",
                        "awslogs-region": "us-east-2",
                        "awslogs-stream-prefix": log_prefix,
                    },
                },
                "systemControls": [],
            }
        ],
        # NOTE: change xxx to your AWS account number
        "executionRoleArn": f"arn:aws:iam::xxx:role/{ROLE_NAME}",
        "networkMode": "awsvpc",
        "requiresCompatibilities": ["FARGATE"],
        # Task-level CPU and memory must be passed as strings.
        "cpu": str(machine.cpu_value),
        "memory": str(machine.memory_value),
        "ephemeralStorage": {"sizeInGiB": machine.storage},
        "tags": [
            {
                "key": "commit",
                "value": commit,
            },
            {
                "key": "is_main_line",
                "value": str(is_main_line),
            },
        ],
    }

    return task_dict


def get_gpu_task_dict(
    flow_name: str, commit: str, machine: AwsEcsMachine, is_main_line: bool
):
    """
    Get the ECS task definition for GPU machines.

    Args:
        flow_name: the name of the flow.
        commit: the commit hash to use for the image.
        machine: the machine configuration to use.
        is_main_line: whether the flow is the main line or not.
    """
    main_display = "main" if is_main_line else "feat"
    log_prefix = f"prefect-gpu-{main_display}__{flow_name}__{commit}"
    task_dict: dict[
        str,
        str
        | list[
            dict[
                str,
                str
                | int
                | list[str | int]
                | list[dict[str, str] | dict[str, str | dict[str, str]]],
            ]
        ],
    ] = {
        "family": f"prefect-gpu-{main_display}__{flow_name}",
        "containerDefinitions": [
            {
                "name": "prefect",
                "image": f"{machine.image.to_ecr()}:{commit}",
                "cpu": 8192,  # Hardcoded because of the instance type
                "memory": MAX_MEMORY - MEMORY_OVERHEAD,
                "portMappings": [],
                "essential": True,
                "environment": [],
                "mountPoints": [],
                "volumesFrom": [],
                "logConfiguration": {
                    "logDriver": "awslogs",
                    "options": {
                        "awslogs-group": "prefect",
                        "awslogs-create-group": "true",
                        "awslogs-region": "us-east-2",
                        "awslogs-stream-prefix": log_prefix,
                    },
                },
                "systemControls": [],
                "resourceRequirements": [{"value": "1", "type": "GPU"}],
            }
        ],
        # NOTE: change xxx to your AWS account number
        "executionRoleArn": f"arn:aws:iam::xxx:role/{ROLE_NAME}",
        "cpu": "8192",
        "memory": f"{MAX_MEMORY - MEMORY_OVERHEAD}",
        "ipcMode": "host",
        "tags": [
            {
                "key": "commit",
                "value": commit,
            },
            {
                "key": "is_main_line",
                "value": str(is_main_line),
            },
        ],
    }

    return task_dict
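The AwsEcsMachine class itself was not shared in the thread. Going only by the attributes the two functions above read (machine.image.to_ecr(), machine.cpu_value, machine.memory_value, machine.storage), a minimal sketch might look like the following. The EcrImage helper, its fields, and the example values are assumptions made up for illustration, not the original code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EcrImage:
    """Hypothetical helper: identifies an image repository in ECR."""

    account_id: str
    region: str
    repository: str

    def to_ecr(self) -> str:
        """Return the repository URI without a tag."""
        host = f"{self.account_id}.dkr.ecr.{self.region}.amazonaws.com"
        return f"{host}/{self.repository}"


@dataclass(frozen=True)
class AwsEcsMachine:
    """Hypothetical stand-in for the dataclass described above."""

    image: EcrImage
    cpu_value: int  # CPU units (1024 units = 1 vCPU)
    memory_value: int  # memory in MiB
    storage: int  # ephemeral storage in GiB


# Example values are placeholders, not taken from the thread.
machine = AwsEcsMachine(
    image=EcrImage("123456789012", "us-east-2", "my-flows"),
    cpu_value=1024,
    memory_value=4096,
    storage=21,
)
print(machine.image.to_ecr())
```

An instance like this could then be passed as the machine argument to get_cpu_task_dict or get_gpu_task_dict.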
We then call:
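import boto3

# task_dict comes from get_cpu_task_dict(...) or get_gpu_task_dict(...) above.
response = boto3.client("ecs").register_task_definition(**task_dict)
task_definition_arn = response["taskDefinition"]["taskDefinitionArn"]

# Point the Prefect ECS worker at the pre-registered task definition.
job_variables = {"task_definition_arn": task_definition_arn}
flow.deploy(
    ...,  # other deployment arguments elided
    job_variables=job_variables,
)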
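To try the register-then-deploy step without touching AWS, the registration call can be wrapped in a small function that takes the ECS client as an argument. This is only a sketch: job_variables_for and FakeEcsClient are hypothetical names, and the fake client mimics just the one field of the register_task_definition response that the snippet above reads.

```python
from typing import Any


def job_variables_for(ecs_client: Any, task_dict: dict[str, Any]) -> dict[str, str]:
    """Register task_dict and return job variables pointing at the new revision."""
    response = ecs_client.register_task_definition(**task_dict)
    arn = response["taskDefinition"]["taskDefinitionArn"]
    return {"task_definition_arn": arn}


class FakeEcsClient:
    """Stand-in for boto3.client("ecs") so the sketch runs without AWS access."""

    def __init__(self) -> None:
        self.registered: list[dict[str, Any]] = []

    def register_task_definition(self, **kwargs: Any) -> dict[str, Any]:
        # Record the request and hand back a made-up ARN with a revision number.
        self.registered.append(kwargs)
        revision = len(self.registered)
        arn = (
            "arn:aws:ecs:us-east-2:000000000000:task-definition/"
            f"{kwargs['family']}:{revision}"
        )
        return {"taskDefinition": {"taskDefinitionArn": arn}}


fake = FakeEcsClient()
job_variables = job_variables_for(
    fake, {"family": "prefect-gpu-main__demo", "containerDefinitions": []}
)
print(job_variables)
```

In real use, boto3.client("ecs") would be passed in place of the fake, and the returned dict handed to flow.deploy as job_variables, as in the call above.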