Running multiple Athena queries in Airflow 2.0

Problem description

I am trying to create a DAG in which one task executes an Athena query using boto3. It works for a single query, but I run into problems when I try to run multiple Athena queries.

This problem can be approached as follows:

  1. If you go through this blog, you can see that Athena uses start_query_execution to trigger a query and get_query_execution to fetch its status, queryExecutionId, and other data about the query (Athena documentation); a rough sketch of that pattern is shown below.
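
A minimal sketch of that start-and-poll pattern with plain boto3 (the region, bucket, query and the 2-second polling interval are placeholders, not values from the question):

import time
import boto3

athena = boto3.client('athena', region_name='YOUR_REGION')

# start_query_execution returns a QueryExecutionId immediately
query_execution_id = athena.start_query_execution(
    QueryString='SELECT 1;',
    ResultConfiguration={'OutputLocation': 's3://YOUR_BUCKET/athena-results/'},
)['QueryExecutionId']

# poll get_query_execution until the query reaches a terminal state
while True:
    status = athena.get_query_execution(
        QueryExecutionId=query_execution_id
    )['QueryExecution']['Status']
    if status['State'] in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
        break
    time.sleep(2)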

Following the above pattern, I have the code below, which fails with an error when I run it:

import json
import time
import asyncio
import boto3
import logging
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator


def execute_query(client, query, database, output_location):
    response = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': output_location
        }
    )

    return response['QueryExecutionId']


async def get_ids(client_athena, query, database, output_location):
    query_responses = []
    for i in range(5):
        query_responses.append(execute_query(client_athena, query, database, output_location))

    res = await asyncio.gather(*query_responses, return_exceptions=True)

    return res

def run_athena_query(query, database, output_location, region_name, **context):
    BOTO_SESSION = boto3.Session(
        aws_access_key_id='YOUR_KEY', aws_secret_access_key='YOUR_ACCESS_KEY')
    client_athena = BOTO_SESSION.client('athena', region_name=region_name)

    loop = asyncio.get_event_loop()
    query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
    loop.close()

    repetitions = 900
    error_messages = []
    s3_uris = []

    while repetitions > 0 and len(query_execution_ids) > 0:
        repetitions = repetitions - 1
        
        query_response_list = client_athena.batch_get_query_execution(
            QueryExecutionIds=query_execution_ids)['QueryExecutions']
      
        for query_response in query_response_list:
            # batch_get_query_execution returns QueryExecution objects directly,
            # so Status/State live at the top level of each element
            if 'Status' in query_response and 'State' in query_response['Status']:
                state = query_response['Status']['State']

                if state in ['FAILED', 'CANCELLED']:
                    error_reason = query_response['Status'].get('StateChangeReason', '')
                    error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
                        state, query_response['QueryExecutionId'], error_reason
                    )
                    error_messages.append(error_message)
                    query_execution_ids.remove(query_response['QueryExecutionId'])

                elif state == 'SUCCEEDED':
                    result_location = query_response['ResultConfiguration']['OutputLocation']
                    s3_uris.append(result_location)
                    query_execution_ids.remove(query_response['QueryExecutionId'])
                 
                    
        time.sleep(2)
    
    logging.exception(error_messages)
    return s3_uris


DEFAULT_ARGS = {
    'owner': 'ubuntu',
    'depends_on_past': True,
    'start_date': datetime(2021, 6, 8),
    'retries': 0,
    'concurrency': 2
}

with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:

    ATHENA_QUERY = PythonOperator(
        task_id='athena_query',
        python_callable=run_athena_query,
        provide_context=True,
        op_kwargs={
            'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;',  # query from the Athena tutorial
            'database': 'sampledb',
            'output_location': 'YOUR_BUCKET',
            'region_name': 'YOUR_REGION'
        }
    )

    ATHENA_QUERY

I don't know where I am going wrong. I would appreciate some hints on this.

Solutions

I don't think what you are doing here is really necessary. Your requirements are:

  1. Execute multiple queries in parallel.
  2. Be able to retrieve the queryExecutionId of each query.

Both of these can be solved by using the AWSAthenaOperator. The operator already handles everything you mentioned for you.

Example:

from airflow.models import DAG
from airflow.utils.dates import days_ago
from airflow.operators.dummy import DummyOperator
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator


with DAG(
    dag_id="athena",
    schedule_interval='@daily',
    start_date=days_ago(1),
    catchup=False,
) as dag:

    start_op = DummyOperator(task_id="start_task")
    query_list = ["SELECT 1;", "SELECT 2;", "SELECT 3;"]

    for i, sql in enumerate(query_list):
        run_query = AWSAthenaOperator(
            task_id=f'run_query_{i}',
            query=sql,
            output_location='s3://my-bucket/my-path/',
            database='my_database'
        )
        start_op >> run_query

Simply add more queries to query_list and the Athena tasks will be created dynamically.


Note that the QueryExecutionId is pushed to XCom, so you can access it from downstream tasks whenever you need it.
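
For example, a downstream task could read one of those ids back from XCom roughly as follows; this is only a sketch, where the task_id run_query_0 refers to the first task created by the loop above, and process_execution_id / use_query_execution_id are placeholder names:

from airflow.operators.python import PythonOperator

def process_execution_id(**context):
    # AWSAthenaOperator returns the QueryExecutionId, so it is available
    # in XCom under the default 'return_value' key
    query_execution_id = context['ti'].xcom_pull(task_ids='run_query_0')
    print(f'Athena query execution id: {query_execution_id}')

# inside the same `with DAG(...)` block as the example above
use_id = PythonOperator(
    task_id='use_query_execution_id',
    python_callable=process_execution_id,
)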


The following also worked for me. I had simply over-complicated a simple problem with asyncio.

Since I eventually needed the S3 URI of each query's result, I wrote the script from scratch. With the current implementation of AWSAthenaOperator, one can get the queryExecutionId and then do the remaining processing (i.e. create another task) to obtain the S3 URI of the CSV result file. This adds some overhead in terms of latency between the two tasks (getting the queryExecutionId and retrieving the S3 URI), as well as increased resource usage.
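
For context, that follow-up lookup (turning a queryExecutionId into the S3 URI of the result file) is roughly a single boto3 call; this is only a sketch, and the region and id are placeholders:

import boto3

athena = boto3.client('athena', region_name='YOUR_REGION')

# given a queryExecutionId (e.g. pulled from XCom), look up where Athena wrote the CSV result
execution = athena.get_query_execution(QueryExecutionId='YOUR_QUERY_EXECUTION_ID')
s3_uri = execution['QueryExecution']['ResultConfiguration']['OutputLocation']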

So I perform the complete operation in a single operator, as follows:

Code:

import json
import time
import boto3
import logging
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator


def execute_query(client, query, database, output_location):
    response = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': output_location
        }
    )

    return response


def run_athena_query(query, database, output_location, region_name, **context):
    BOTO_SESSION = boto3.Session(
        aws_access_key_id='YOUR_KEY', aws_secret_access_key='YOUR_ACCESS_KEY')
    client_athena = BOTO_SESSION.client('athena', region_name=region_name)

    # message_list is expected to come from an upstream task (a Kafka consumer
    # in my case); how it is retrieved is not shown in this snippet
    query_execution_ids = []
    if message_list:
        for parameter in message_list:
            # one query is started per upstream message
            query_response = execute_query(client_athena, query, database, output_location)
            query_execution_ids.append(query_response['QueryExecutionId'])
    else:
        raise Exception(
            'Error in upstream value received from kafka consumer. Got message list as - {}, with type {}'
                .format(message_list, type(message_list))
        )


    repetitions = 900
    error_messages = []
    s3_uris = []

    while repetitions > 0 and len(query_execution_ids) > 0:
        repetitions = repetitions - 1
        
        query_response_list = client_athena.batch_get_query_execution(
            QueryExecutionIds=query_execution_ids)['QueryExecutions']
      
        for query_response in query_response_list:
            # batch_get_query_execution returns QueryExecution objects directly,
            # so Status/State live at the top level of each element
            if 'Status' in query_response and 'State' in query_response['Status']:
                state = query_response['Status']['State']

                if state in ['FAILED', 'CANCELLED']:
                    error_reason = query_response['Status'].get('StateChangeReason', '')
                    error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
                        state, query_response['QueryExecutionId'], error_reason
                    )
                    error_messages.append(error_message)
                    query_execution_ids.remove(query_response['QueryExecutionId'])

                elif state == 'SUCCEEDED':
                    result_location = query_response['ResultConfiguration']['OutputLocation']
                    s3_uris.append(result_location)
                    query_execution_ids.remove(query_response['QueryExecutionId'])
                 
                    
        time.sleep(2)
    
    logging.exception(error_messages)
    return s3_uris


DEFAULT_ARGS = {
    'owner': 'ubuntu',
    'depends_on_past': True,
    'start_date': datetime(2021, 6, 8),
    'retries': 0,
    'concurrency': 2
}

with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:

    ATHENA_QUERY = PythonOperator(
        task_id='athena_query',
        python_callable=run_athena_query,
        provide_context=True,
        op_kwargs={
            'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;',  # query from the Athena tutorial
            'database': 'sampledb',
            'output_location': 'YOUR_BUCKET',
            'region_name': 'YOUR_REGION'
        }
    )

    ATHENA_QUERY

However, if one wants to obtain the queryExecutionIds of all the queries, the approach shared by @Elad is cleaner and more to the point.
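
As a rough illustration of that, a downstream task in Elad's DAG could collect every id from XCom in one place; this is only a sketch, where collect_execution_ids is a placeholder name and the task ids assume the run_query_{i} pattern from the example above:

from airflow.operators.python import PythonOperator

def collect_execution_ids(n_queries, **context):
    # pull the QueryExecutionId that each AWSAthenaOperator task pushed to XCom
    ids = context['ti'].xcom_pull(task_ids=[f'run_query_{i}' for i in range(n_queries)])
    return ids

# inside the same `with DAG(...)` block, after the loop that creates the run_query tasks;
# each run_query task should also be wired upstream of this one (e.g. run_query >> collect_ids in the loop)
collect_ids = PythonOperator(
    task_id='collect_execution_ids',
    python_callable=collect_execution_ids,
    op_kwargs={'n_queries': len(query_list)},
)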