규도자 개발 블로그
airflow에서 각종 operator로 분기처리하기 (feat. SimpleHttpOperator) 본문
airflow에서 각종 operator로 분기처리하기 (feat. SimpleHttpOperator)
규도자 (gyudoza) 2022. 5. 6. 10:23airflow에서 각종 operator로 분기처리하기 (feat. SimpleHttpOperator)
airflow의 기본적인 tutorial에서 제공되는 BranchPythonOperator
처럼 조건에 따라 여러개의 branch를 태우는 방법에 대해서 고민했다. 나는 당연히 SimpleHttpOperator도, BashOperator도 branch operator가 따로 존재할 줄 알았는데 오직 BranchPythonOperator
만 존재했다.
왜 branch operator가 중요하냐, 다른 operator는 True와 False만으로 분기처리를 해야하고, 조건에 따른 다음 task는 triger_rule을 이용해 컨트롤해야 하므로 보다 복잡한 분기처리가 불가능하다는 문제가 있었다.
예를 들어 전 base_task에서 1이라는 결과값을 받으면 task1을,
base_task에서 2라는 결과값을 받으면 task2를,
base_task에서 3이라는 결과값을 받으면 task3을...
만약에 base_task가 BranchPythonOperator
라면 저런 분기처리가 가능하지만 다른 Operator라면 불가능하다. BranchPythonOperator
만이 오직 return값을 task_id라는 문자열로 반환하여 해당 task를 다음에 실행시킬 수 있기 때문이다. 위에서도 말했듯이 다른 Operator는 True or False에 대한 반환만 가능하다.
그래서 결론은 뭐냐, 어떠한 Operator의 결과값에 따라 task 분기를 태우기 위해선 결과값을 기준으로 분기를 나눠주는 BranchPythonOperator
를 중간에 넣어줘야 한다. 그리고 결과값을 다른 task로 넘기기 위해서는 XComs를 써야 한다.
설명을 일일히 하기는 분량이 길어질 것 같고 샘플로 제작한 webserver와 dag를 통해 어떻게 동작하는지 확인해볼 수 있다. 원본 코드가 포함된 저장소는 https://github.com/jujumilk3/airflow-study 여기에 있다.
webserver는 일부러 그 어떤 framework도 사용하지 않고 apache-airflow를 설치할 때 같이 설치되는 requests를 이용해 작동하게 제작하였다.
# webserver.py
from http.server import BaseHTTPRequestHandler, HTTPServer
from time import sleep
import logging
import random
import json
class S(BaseHTTPRequestHandler):
def _set_response(self, status_code=200):
self.send_response(status_code)
self.send_header('Content-type', 'application/json')
self.end_headers()
def response(self, response_dict: dict):
self.wfile.write(json.dumps(response_dict).encode('utf-8'))
def do_GET(self):
logging.info("GET request,\nPath: %s\nHeaders:\n%s\n", str(self.path), str(self.headers))
if str(self.path) == '/airflow/base-task':
work_rand = random.randrange(0, 10)
success = True if work_rand else False # 0에서 9까지 나오므로 1/10 확률로 실패
if success:
self._set_response(200)
response = {
'status': 'success',
'msg': 'hi',
'next_task_number': (work_rand % 3) + 1
}
self.response(response)
else:
self._set_response(404)
response = {
'status': 'failed',
'msg': 'bye'
}
self.response(response)
elif str(self.path) == '/airflow/dummy-task1':
self._set_response(200)
response = airflow_dummy_task1()
self.response(response)
elif str(self.path) == '/airflow/dummy-task2':
self._set_response(200)
response = airflow_dummy_task2()
self.response(response)
elif str(self.path) == '/airflow/dummy-task3':
self._set_response(200)
response = airflow_dummy_task3()
self.response(response)
def do_POST(self):
content_length = int(self.headers['Content-Length']) # <--- Gets the size of data
post_data = self.rfile.read(content_length) # <--- Gets the data itself
logging.info("POST request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
str(self.path), str(self.headers), post_data.decode('utf-8'))
self._set_response()
self.wfile.write("POST request for {}".format(self.path).encode('utf-8'))
def airflow_dummy_task1():
count = random.randrange(1, 10) # 최대 9초가 걸리는 dummy task
while count:
count -= 1
sleep(1)
return {'msg': 'dummy_task1', 'status': 'success'}
def airflow_dummy_task2():
count = random.randrange(1, 10) # 최대 9초가 걸리는 dummy task
while count:
count -= 1
sleep(1)
return {'msg': 'dummy_task2', 'status': 'success'}
def airflow_dummy_task3():
count = random.randrange(1, 10) # 최대 9초가 걸리는 dummy task
while count:
count -= 1
sleep(1)
return {'msg': 'dummy_task3', 'status': 'success'}
def run(server_class=HTTPServer, handler_class=S, port=8000):
logging.basicConfig(level=logging.INFO)
server_address = ('', port)
httpd = server_class(server_address, handler_class)
logging.info('Starting httpd...\n')
try:
httpd.serve_forever()
except KeyboardInterrupt:
pass
httpd.server_close()
logging.info('Stopping httpd...\n')
if __name__ == '__main__':
from sys import argv
if len(argv) == 2:
run(port=int(argv[1]))
else:
run()
아래 dag를 참고하면 XComs와 BranchPythonOperator를 이용해 전 task의 결과값을 어떻게 다음 task로 넘기고 그 결과값으로 어떻게 분기처리를 하는지 확인할 수 있을 것이다.
# http_xcom_sample.py
import json
from datetime import datetime, timedelta
from airflow import DAG
from airflow.providers.http.operators.http import SimpleHttpOperator
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule
def handle_response(response, **context):
print(response)
print(response.__dict__)
print(response.content)
response_json_as_dict = json.loads(response.content)
print(response_json_as_dict)
if str(response.status_code).startswith('2'): # to catch 2XX http status code
context['task_instance'].xcom_push(key='base_task_xcom', value='success') # 이건 안됨이 아니라 잘됨.
context['task_instance'].xcom_push(key='second_task_number', value=response_json_as_dict.get('next_task_number', 1))
return True
else:
context['task_instance'].xcom_push(key='base_task_xcom', value='fail') # 애초에 다음으로 진행이 안되니 무의미
return False
def treat_as_branch(**context):
print("Here is treat_as_branch")
print(context)
base_task_result = context['task_instance'].xcom_pull(key='base_task_xcom')
next_task_number = context['task_instance'].xcom_pull(key='second_task_number')
print("This is base_task_result")
print(base_task_result)
return 'http_dummy_task' + str(next_task_number)
def complete(**context):
print(context)
with DAG(
dag_id='http_xcom_sample',
description='A simple http DAG',
schedule_interval=timedelta(hours=1),
start_date=datetime(2021, 1, 1),
catchup=False,
tags=['example'],
) as dag:
base_task = SimpleHttpOperator(
task_id='base_task',
method='GET',
endpoint='/airflow/base-task',
http_conn_id='localhost',
response_check=handle_response,
)
branch_task = BranchPythonOperator(
task_id='branch_task',
python_callable=treat_as_branch
)
http_dummy_task1 = SimpleHttpOperator(
task_id='http_dummy_task1',
method='GET',
endpoint='/airflow/dummy-task1',
http_conn_id='localhost',
)
http_dummy_task2 = SimpleHttpOperator(
task_id='http_dummy_task2',
method='GET',
endpoint='/airflow/dummy-task2',
http_conn_id='localhost',
)
http_dummy_task3 = SimpleHttpOperator(
task_id='http_dummy_task3',
method='GET',
endpoint='/airflow/dummy-task3',
http_conn_id='localhost',
)
complete_task = PythonOperator(
task_id='complete_task',
python_callable=complete,
trigger_rule=TriggerRule.ONE_SUCCESS
)
base_task >> branch_task >> [http_dummy_task1, http_dummy_task2, http_dummy_task3] >> complete_task