규도자 개발 블로그

airflow에서 각종 operator로 분기처리하기 (feat. SimpleHttpOperator) 본문

Python/Airflow

airflow에서 각종 operator로 분기처리하기 (feat. SimpleHttpOperator)

규도자 (gyudoza) 2022. 5. 6. 10:23

airflow에서 각종 operator로 분기처리하기 (feat. SimpleHttpOperator)

airflow의 기본적인 tutorial에서 제공되는 BranchPythonOperator처럼 조건에 따라 여러개의 branch를 태우는 방법에 대해서 고민했다. 나는 당연히 SimpleHttpOperator도, BashOperator도 branch operator가 따로 존재할 줄 알았는데 오직 BranchPythonOperator만 존재했다.

왜 branch operator가 중요하냐, 다른 operator는 True와 False만으로 분기처리를 해야하고, 조건에 따른 다음 task는 triger_rule을 이용해 컨트롤해야 하므로 보다 복잡한 분기처리가 불가능하다는 문제가 있었다.

 

예를 들어 전 base_task에서 1이라는 결과값을 받으면 task1을,
base_task에서 2라는 결과값을 받으면 task2를,
base_task에서 3이라는 결과값을 받으면 task3을...

 

만약에 base_task가 BranchPythonOperator라면 저런 분기처리가 가능하지만 다른 Operator라면 불가능하다. BranchPythonOperator만이 오직 return값을 task_id라는 문자열로 반환하여 해당 task를 다음에 실행시킬 수 있기 때문이다. 위에서도 말했듯이 다른 Operator는 True or False에 대한 반환만 가능하다.

 

그래서 결론은 뭐냐, 어떠한 Operator의 결과값에 따라 task 분기를 태우기 위해선 결과값을 기준으로 분기를 나눠주는 BranchPythonOperator를 중간에 넣어줘야 한다. 그리고 결과값을 다른 task로 넘기기 위해서는 XComs를 써야 한다.

 

설명을 일일히 하기는 분량이 길어질 것 같고 샘플로 제작한 webserver와 dag를 통해 어떻게 동작하는지 확인해볼 수 있다. 원본 코드가 포함된 저장소는 https://github.com/jujumilk3/airflow-study 여기에 있다.

 

webserver는 일부러 그 어떤 framework도 사용하지 않고 apache-airflow를 설치할 때 같이 설치되는 requests를 이용해 작동하게 제작하였다.

# webserver.py
from http.server import BaseHTTPRequestHandler, HTTPServer
from time import sleep
import logging
import random
import json


class S(BaseHTTPRequestHandler):
    def _set_response(self, status_code=200):
        self.send_response(status_code)
        self.send_header('Content-type', 'application/json')
        self.end_headers()

    def response(self, response_dict: dict):
        self.wfile.write(json.dumps(response_dict).encode('utf-8'))

    def do_GET(self):
        logging.info("GET request,\nPath: %s\nHeaders:\n%s\n", str(self.path), str(self.headers))
        if str(self.path) == '/airflow/base-task':
            work_rand = random.randrange(0, 10)
            success = True if work_rand else False  # 0에서 9까지 나오므로 1/10 확률로 실패
            if success:
                self._set_response(200)
                response = {
                    'status': 'success',
                    'msg': 'hi',
                    'next_task_number': (work_rand % 3) + 1
                }
                self.response(response)

            else:
                self._set_response(404)
                response = {
                    'status': 'failed',
                    'msg': 'bye'
                }
                self.response(response)

        elif str(self.path) == '/airflow/dummy-task1':
            self._set_response(200)
            response = airflow_dummy_task1()
            self.response(response)

        elif str(self.path) == '/airflow/dummy-task2':
            self._set_response(200)
            response = airflow_dummy_task2()
            self.response(response)

        elif str(self.path) == '/airflow/dummy-task3':
            self._set_response(200)
            response = airflow_dummy_task3()
            self.response(response)

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])  # <--- Gets the size of data
        post_data = self.rfile.read(content_length)  # <--- Gets the data itself
        logging.info("POST request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
                     str(self.path), str(self.headers), post_data.decode('utf-8'))

        self._set_response()
        self.wfile.write("POST request for {}".format(self.path).encode('utf-8'))


def airflow_dummy_task1():
    count = random.randrange(1, 10)  # 최대 9초가 걸리는 dummy task
    while count:
        count -= 1
        sleep(1)
    return {'msg': 'dummy_task1', 'status': 'success'}


def airflow_dummy_task2():
    count = random.randrange(1, 10)  # 최대 9초가 걸리는 dummy task
    while count:
        count -= 1
        sleep(1)
    return {'msg': 'dummy_task2', 'status': 'success'}


def airflow_dummy_task3():
    count = random.randrange(1, 10)  # 최대 9초가 걸리는 dummy task
    while count:
        count -= 1
        sleep(1)
    return {'msg': 'dummy_task3', 'status': 'success'}


def run(server_class=HTTPServer, handler_class=S, port=8000):
    logging.basicConfig(level=logging.INFO)
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    logging.info('Starting httpd...\n')
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    httpd.server_close()
    logging.info('Stopping httpd...\n')


if __name__ == '__main__':
    from sys import argv

    if len(argv) == 2:
        run(port=int(argv[1]))
    else:
        run()

아래 dag를 참고하면 XComs와 BranchPythonOperator를 이용해 전 task의 결과값을 어떻게 다음 task로 넘기고 그 결과값으로 어떻게 분기처리를 하는지 확인할 수 있을 것이다.

# http_xcom_sample.py
import json
from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.http.operators.http import SimpleHttpOperator
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule


def handle_response(response, **context):
    print(response)
    print(response.__dict__)
    print(response.content)
    response_json_as_dict = json.loads(response.content)
    print(response_json_as_dict)
    if str(response.status_code).startswith('2'):  # to catch 2XX http status code
        context['task_instance'].xcom_push(key='base_task_xcom', value='success')  # 이건 안됨이 아니라 잘됨.
        context['task_instance'].xcom_push(key='second_task_number', value=response_json_as_dict.get('next_task_number', 1))
        return True
    else:
        context['task_instance'].xcom_push(key='base_task_xcom', value='fail')  # 애초에 다음으로 진행이 안되니 무의미
        return False


def treat_as_branch(**context):
    print("Here is treat_as_branch")
    print(context)
    base_task_result = context['task_instance'].xcom_pull(key='base_task_xcom')
    next_task_number = context['task_instance'].xcom_pull(key='second_task_number')
    print("This is base_task_result")
    print(base_task_result)
    return 'http_dummy_task' + str(next_task_number)


def complete(**context):
    print(context)


with DAG(
    dag_id='http_xcom_sample',
    description='A simple http DAG',
    schedule_interval=timedelta(hours=1),
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['example'],
) as dag:
    base_task = SimpleHttpOperator(
        task_id='base_task',
        method='GET',
        endpoint='/airflow/base-task',
        http_conn_id='localhost',
        response_check=handle_response,
    )

    branch_task = BranchPythonOperator(
        task_id='branch_task',
        python_callable=treat_as_branch
    )

    http_dummy_task1 = SimpleHttpOperator(
        task_id='http_dummy_task1',
        method='GET',
        endpoint='/airflow/dummy-task1',
        http_conn_id='localhost',
    )

    http_dummy_task2 = SimpleHttpOperator(
        task_id='http_dummy_task2',
        method='GET',
        endpoint='/airflow/dummy-task2',
        http_conn_id='localhost',
    )

    http_dummy_task3 = SimpleHttpOperator(
        task_id='http_dummy_task3',
        method='GET',
        endpoint='/airflow/dummy-task3',
        http_conn_id='localhost',
    )

    complete_task = PythonOperator(
        task_id='complete_task',
        python_callable=complete,
        trigger_rule=TriggerRule.ONE_SUCCESS
    )

    base_task >> branch_task >> [http_dummy_task1, http_dummy_task2, http_dummy_task3] >> complete_task

 

 

0 Comments
댓글쓰기 폼