Apache Airflow - Learning resources and a real-world DAG

airflow
Author

red

Published

August 31, 2022

Used airflow docker image: apache/airflow:2.3.2

These are my two cents on getting started learning apache airflow with my own dockerised airflow instance. The docker setup is described in my previous post; including it here would have made this post even longer.

I will cover:

  • learning resources that actually helped me get started
  • a real-world example DAG that scrapes reddit posts, including API authentication

This is what my quick glance at the framework produced. Maybe it can help you, but there are no guarantees this will work as-is with other airflow versions; that should be clear in any (open source) software environment.

Learning airflow

Airflow’s documentation may be useful for people who already know what they are doing. Things may have changed since mid 2022, but I found it insufficient when starting from scratch, especially since the search results within the documentation are less than helpful¹.

Web resources were severely lacking, and I found that the usual medium posts and youtube videos were thinly disguised adaptations of the official documentation. There is a datacamp course, but it does more harm than good since it’s horrendously out of date: they “teach” the old, pre-2.0 syntax, which is deprecated and useful only for getting error messages. What really helped was the book Data Pipelines with Apache Airflow by Harenslak and de Ruiter. With it you can get to a level from which you can learn by yourself.

A DAG to scrape reddit posts with API authentication

It’s a small example of a real-world application featuring interaction with the host file system, safe credential storage and basic airflow controls. It should be sufficient to get you started building small (proof-of-concept) real-world applications.

I’ll try not to go into details you can easily find in the official documentation or deduce by reading the code. Instead I’ll point out the issues that weren’t that easy to find but were necessary to build a functioning DAG.

Stepping through bit by bit

DAG definition

local_tz = pendulum.timezone("Europe/Vienna")

with DAG(dag_id="redditscrape_v1", #name of the dag
         default_args={    # set default arguments for all operators
             "retries": 2   
         },
         description="Scrape r/austria subreddit",
         schedule_interval= "0 0/6 * * *", 
         start_date=pendulum.datetime(2022, 7, 17, tz=local_tz),
         catchup=False,
         tags=["redditscrape"],
         ) as dag:
  • schedule_interval is in cron notation, which represents MIN HOUR DOM MON DOW.
    • schedule_interval= "0 0/6 * * *" means the DAG is run at minute 0 every 6th hour starting from midnight, in plain English at 0, 6, 12 and 18 o’clock.
    • there are a few other ways of doing that (a small sketch follows this list), but see the documentation.
  • The start_date=pendulum.datetime(2022, 7, 17, tz=local_tz) sets the first run of the DAG anchored in my timezone. The reason for not just putting in an arbitrary date like 2022-01-01 is the potential interaction with catchup. If set to True, airflow will attempt to run the DAG for every scheduled interval since that date.
    • This can potentially be very useful when you have a job like computing statistics or running transformations to build up a results table².
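
The commented-out line in the complete code below hints at one alternative: passing a timedelta instead of a cron string. A minimal sketch (same imports as the full DAG plus timedelta; not the variant I ended up using):

from datetime import timedelta

import pendulum
from airflow import DAG

local_tz = pendulum.timezone("Europe/Vienna")

# Interval-based scheduling: run every 6 hours measured from the previous run,
# instead of pinning runs to fixed wall-clock hours like the cron string does.
with DAG(dag_id="redditscrape_timedelta_sketch",
         schedule_interval=timedelta(hours=6),
         start_date=pendulum.datetime(2022, 7, 17, tz=local_tz),
         # catchup=True would instead backfill one run per missed interval since start_date
         catchup=False,
         ) as dag:
    ...

Cron presets such as "@daily" or "@hourly" are also accepted by schedule_interval.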

First task: Check for existence of the results file

scraped_data_present = FileSensor(task_id="sense_scraped_file",
                                  timeout=4*60, # seconds
                                  filepath="scraped_data.csv",
                                  fs_conn_id="fs_data",
                                  poke_interval=60,
                                  )

scraped_data_present.doc_md = dedent(
"""
Checks whether the file with the already scraped posts is present in the `/opt/airflow/data` folder
""")
  • The FileSensor checks for the existence of a file, here my results container, every poke_interval seconds for at most timeout seconds.
    • If no container file is present by then, the task and with it the pipeline fails (a sketch with softer alternatives follows this list).
  • The important part is fs_conn_id, which tells airflow to use this predefined connection. In this case it holds the path to check on the host machine.
    • The reason for using a connection, instead of pointing directly at a docker bind mount path, is that I can easily manage it from the airflow GUI. All connections are in one place instead of being defined separately in docker.
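
If you would rather have the DAG skip instead of fail when the file never shows up, the sensor base class has a couple of knobs for that. A minimal sketch (soft_fail and mode are BaseSensorOperator parameters; I did not use them in this DAG):

from airflow.sensors.filesystem import FileSensor

# Sketch only: soft_fail=True marks the sensor as skipped instead of failed on timeout,
# and mode="reschedule" frees the worker slot between pokes instead of occupying it.
scraped_data_present_soft = FileSensor(task_id="sense_scraped_file_soft",
                                       filepath="scraped_data.csv",
                                       fs_conn_id="fs_data",
                                       poke_interval=60,
                                       timeout=4*60,
                                       soft_fail=True,
                                       mode="reschedule",
                                       )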

File System Connection
  • With scraped_data_present.doc_md you can provide a docstring in markdown format which will be rendered in the task’s details view in the airflow GUI.

Second Task: Scraping the data

Before stepping through the function that scrapes reddit, let’s look at the PythonOperator that defines the scrape task:

scrape_data = PythonOperator(task_id="scrape_data",
                             python_callable=_redditscrape_job,
                             op_kwargs={"creds": Variable.get("reddit_pw", deserialize_json=True),
                                        "subreddit": "austria",
                                        "limit": 100})
  • The function to call is passed to the operator via python_callable and its arguments are provided via op_kwargs as a dict of keyword arguments.
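
Not what I used here, but worth knowing: airflow 2.x also offers the TaskFlow API, where a decorated function replaces the explicit PythonOperator. A rough sketch with the same arguments as above (the function body would be _redditscrape_job’s code):

from airflow.decorators import task
from airflow.models import Variable

@task(task_id="scrape_data")
def scrape_data_taskflow(creds: dict, subreddit: str, limit: int = 100):
    # the body of _redditscrape_job would go here
    ...

# calling the decorated function inside the `with DAG(...)` block creates the task
scrape_task = scrape_data_taskflow(Variable.get("reddit_pw", deserialize_json=True),
                                   "austria", 100)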

Passing credentials safely

  • "creds": variable.get("reddit_pw", deserialize_json=True) loads the password and token which is stored by airflow after defining it in the GUI (as a json). deserialize_json=True converts the json into a python dict.

    • The advantages are safe central storage, only airflow dags and privileged users can access the secrets.
  • The remaining kwargs just set the subreddit to scrape and the number of posts to retrieve via limit.

  • To store values in the GUI go to Admin/Variables and store the json under Val.

    • If the element contains either password or secret the values are masked in the GUIs rendered template section.
{
  "api_id": "my_login",
  "api_secret": "my_api_password",
  "pw": "my_super_secret_password"
}
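
If you prefer not to click through the GUI, the same variable can also be set programmatically with Variable.set. A minimal sketch (run once from a script with access to the airflow metadata DB; not something I used in this setup):

from airflow.models import Variable

# One-off sketch: store the credentials json under the key "reddit_pw".
# serialize_json=True stores the dict as json, matching deserialize_json=True on read.
Variable.set("reddit_pw",
             {"api_id": "my_login",
              "api_secret": "my_api_password",
              "pw": "my_super_secret_password"},
             serialize_json=True)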

Define operator precedence

# create an empty task after scrape_data so there is something to skip.
dummy_task = EmptyOperator(task_id="dummy_task")

# This sets operator precedence, what comes when in the dag.
scraped_data_present >> scrape_data >> Label("To see skipped vals") >> dummy_task
  • The dummy_task is just there so I can see the cases when the scrape_data task was skipped because no new posts were available.
  • In airflow, operator/task precedence, i.e. task dependencies, must be specified explicitly by chaining tasks like task_1 >> task_2 >> task_3 (a few equivalent ways are sketched below).
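
A small sketch of equivalent ways to declare the same dependencies (the Label is optional and only annotates the edge in the graph view):

# the same dependencies as the bitshift chain above, written as explicit method calls (no label)
scraped_data_present.set_downstream(scrape_data)
scrape_data.set_downstream(dummy_task)

# fanning out to several downstream tasks would look like this (some_other_task is hypothetical)
# scrape_data >> [dummy_task, some_other_task]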

The scraping function

auxiliary functions

# aux functions
def _get_date(created_utc):
    return datetime.fromtimestamp(created_utc)

# get api object
def _bot_login(username, pw, client_id, client_secret):
    r = praw.Reddit(username=username,
                    password=pw,
                    client_id=client_id,
                    client_secret=client_secret,
                    user_agent="<user_name_here> tests reddit crawler v0.1")
    return r
  • The _redditscrape_job function relies on three helper functions: _get_date to convert utc timestamps to human readable time, a generic _bot_login to access the reddit api, and the scraping implementation _scrape_data itself.

scraping function


def _scrape_data(reddit_con, scraped_ids, limit=1000):
    """Takes a praw.Reddit connector and an scrapes submission text and metadata.
    Parameters
    ----------
    reddit_con : praw.Reddit
        PRAW reddit connector
    scraped_ids : existing post ids
        Existing post ids to load only new posts (by id)
    limit : int, optional
        max number of new submissions to pull, by default 1000
    Returns
    -------
    pd.DataFrame
        DataFrame with title, author, ups, downs, score, url, comms_num, created_utc and body.
    """

    internal_con = reddit_con.new(limit=limit) 
    
    # setting up results container
    topics_dict = {"id": [],
                   "title": [],
                   "author": [],
                   "ups": [],
                   "downs": [],
                   "score": [],
                   "url": [],
                   "comms_num": [],
                   "created_utc": [],
                   "body": []}
  • Firstly, write a (useful) docstring.
  • reddit_con.new(limit=limit) returns an iterator over the newest submissions of the subreddit (up to limit new posts).
  • The connector object itself is passed in with the function call.
  • We set up a results dict that has all the fields we want to retain from what the api yields.
 print("Scraping Data")
  # scraped ids are the unique reddit-post ids that are already present in our results file.
    # each new unique id gets its data added the results dict.

     for submission in internal_con:
        unique_id = submission.id not in tuple(scraped_ids)
        if unique_id:
            topics_dict["id"].append(submission.id)
            topics_dict["title"].append(submission.title)
            topics_dict["author"].append(submission.author)
            topics_dict["ups"].append(submission.ups)
            topics_dict["downs"].append(submission.downs)
            topics_dict["score"].append(submission.score)
            topics_dict["url"].append(submission.url)
            topics_dict["comms_num"].append(submission.num_comments)
            topics_dict["created_utc"].append(submission.created_utc)
            topics_dict["body"].append(submission.selftext)

    # converting to a pandas dataframe
    topics_data = pd.DataFrame(topics_dict)

    if topics_data.empty:
        raise AirflowSkipException("There are no new posts")
    print(f"Found {topics_data.shape[0]} new posts!")

    # convert posted time
    topics_data["timestamp"] = topics_data["created_utc"].apply(_get_date)

    # add timestamp
    topics_data["scraped_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S")

    return topics_data
  • The only noteworthy thing is the AirflowSkipException: if it is raised, the task and all its downstream tasks in the pipeline are skipped. This prevents wasting resources when there is nothing to do; here the exception is raised when there aren’t any new posts.
    • There are more of these exceptions that allow control over pipeline execution (a short sketch follows).
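
A minimal sketch of how these exceptions behave inside a python_callable (the checks and names are made up for illustration; AirflowFailException fails the task immediately without using the configured retries):

from airflow.exceptions import AirflowFailException, AirflowSkipException

def _some_callable(new_rows, credentials_ok):
    # hypothetical checks, just to show the control flow
    if not credentials_ok:
        # fail right away and do not retry
        raise AirflowFailException("Credentials rejected, retrying will not help")
    if new_rows == 0:
        # mark this task (and, with default trigger rules, its downstream tasks) as skipped
        raise AirflowSkipException("There are no new posts")
    # any other uncaught exception fails the task run and triggers the configured retries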

Finally the job definition

def _redditscrape_job(creds, subreddit, limit=1000):

    # creating an api connector to reddit
    reddit = _bot_login(username="<my_bot_username>",
                        pw=creds["pw"],
                        client_id=creds["api_id"],
                        client_secret=creds["api_secret"]
                        )

    subreddit = reddit.subreddit(subreddit)
    print(f"Connected to r/{subreddit}")

    # loading data
    print("loading scraped data")
    scraped_data = pd.read_csv("/opt/airflow/data/scraped_data.csv")
    print("loading scraped data succesful")

    new_subs = _scrape_data(subreddit, scraped_data["id"], limit=limit)

    # check for duplicates
    print("Checking for duplicates")

    if new_subs["id"].isin(scraped_data["id"]).any():
        num_dupes = scraped_data["id"].isin(new_subs["id"]).sum()
        print(f"Duplicates found! There are {num_dupes} duplicated lines!")
        raise AssertionError("Duplicated ids found!")
    
    # Display new data
    print("New data rows:", new_subs.shape[0])
    print("Old data rows:", scraped_data.shape[0])
    print(f"Combined data rows: {new_subs.shape[0] + scraped_data.shape[0]}")

    # save new data
    new_subs.to_csv("/opt/airflow/data/new_scraped_subs.csv",
                    header=False, index=False)

    # Append new lines to file
    new_subs.to_csv("/opt/airflow/data/scraped_data.csv", header=False, index=False, mode="a")
    print("Scraping finished")
  • If duplicates are found we simply fail the task by raising an AssertionError. With retries set to 2 in default_args, the failed task will be retried before airflow finally marks it as failed.

Complete Code

from datetime import datetime
from textwrap import dedent

import pandas as pd
import praw

import pendulum
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.empty import EmptyOperator
from airflow.exceptions import AirflowSkipException
from airflow.sensors.filesystem import FileSensor
from airflow.models import Variable
from airflow.utils.edgemodifier import Label



local_tz = pendulum.timezone("Europe/Vienna")

with DAG(dag_id="redditscrape_v1",
         default_args={
             "retries": 2
         },
         description="Scrape r/austria subreddit",
         schedule_interval="0 0/6 * * *",
         # schedule_interval=timedelta(hours=6),
         start_date=pendulum.datetime(2022, 7, 17, tz=local_tz),
         catchup=False,
         tags=["redditscrape"],
         ) as dag:

    scraped_data_present = FileSensor(task_id="sense_scraped_file",
                                      timeout=4*60,
                                      filepath="scraped_data.csv",
                                      fs_conn_id="fs_data",
                                      poke_interval=60,
                                      )

    scraped_data_present.doc_md = dedent(
        """
        Checks whether the file with the already scraped posts is present in the `/opt/airflow/data` folder
        """)

    # aux functions
    def _get_date(created_utc):
        return datetime.fromtimestamp(created_utc)

    # get api object

    def _bot_login(username, pw, client_id, client_secret):
        r = praw.Reddit(username=username,
                        password=pw,
                        client_id=client_id,
                        client_secret=client_secret,
                        user_agent="whocares451 tests reddit crawler v0.1")
        return r

    # function to read in data from subreddit

    def _scrape_data(reddit_con, scraped_ids, limit=1000):
        """Takes a praw.Reddit connector and an scrapes submission text and metadata.

        Parameters
        ----------
        reddit_con : praw.Reddit
            PRAW reddit connector
        scraped_ids : existing post ids
            Existing post ids to load only new posts (by id)
        limit : int, optional
            max number of new submissions to pull, by default 1000

        Returns
        -------
        pd.DataFrame
            DataFrame with title, author, ups, downs, score, url, comms_num, created_utc and body.
        """

        internal_con = reddit_con.new(limit=limit)

        # setting up results container
        topics_dict = {"id": [],
                       "title": [],
                       "author": [],
                       "ups": [],
                       "downs": [],
                       "score": [],
                       "url": [],
                       "comms_num": [],
                       "created_utc": [],
                       "body": []}

        print("Scraping Data")

        for submission in internal_con:
            unique_id = submission.id not in tuple(scraped_ids)

            if unique_id:
                topics_dict["id"].append(submission.id)
                topics_dict["title"].append(submission.title)
                topics_dict["author"].append(submission.author)
                topics_dict["ups"].append(submission.ups)
                topics_dict["downs"].append(submission.downs)
                topics_dict["score"].append(submission.score)
                topics_dict["url"].append(submission.url)
                topics_dict["comms_num"].append(submission.num_comments)
                topics_dict["created_utc"].append(submission.created_utc)
                topics_dict["body"].append(submission.selftext)

        # converting to a pandas dataframe
        topics_data = pd.DataFrame(topics_dict)

        if topics_data.empty:
            raise AirflowSkipException("There are no new posts")

        print(f"Found {topics_data.shape[0]} new posts!")

        # convert posted time
        topics_data["timestamp"] = topics_data["created_utc"].apply(_get_date)

        # add timestamp
        topics_data["scraped_time"] = datetime.now().strftime(
            "%Y-%m-%d %H:%m:%S")

        return topics_data

    def _redditscrape_job(creds, subreddit, limit=1000):

        reddit = _bot_login(username="searchingforsram",
                            pw=creds["pw"],
                            client_id=creds["api_id"],
                            client_secret=creds["api_secret"]
                            )
        subreddit = reddit.subreddit(subreddit)
        print(f"Connected to r/{subreddit}")

        # loading data

        print("loading scraped data")
        scraped_data = pd.read_csv("/opt/airflow/data/scraped_data.csv")
        print("loading scraped data succesful")

        new_subs = _scrape_data(subreddit, scraped_data["id"], limit=limit)

        # check for duplicates

        print("Checking for duplicates")

        if new_subs["id"].isin(scraped_data["id"]).any():
            num_dupes = scraped_data["id"].isin(new_subs["id"]).sum()
            print(f"Duplicates found! There are {num_dupes} duplicated lines!")
            raise AssertionError("Duplicated ids found!")

        # Display new data
        print("New data rows:", new_subs.shape[0])
        print("Old data rows:", scraped_data.shape[0])
        print(
            f"Combined data rows: {new_subs.shape[0] + scraped_data.shape[0]}")

        # save new data
        new_subs.to_csv("/opt/airflow/data/new_scraped_subs.csv",
                        header=False, index=False)

        # Append new lines to file
        new_subs.to_csv("/opt/airflow/data/scraped_data.csv",
                        header=False, index=False, mode="a")
        print("Scraping finished")

    scrape_data = PythonOperator(task_id="scrape_data",
                                 python_callable=_redditscrape_job,
                                 op_kwargs={"creds": Variable.get("reddit_pw", deserialize_json=True),
                                            "subreddit": "austria",
                                            "limit": 100})

    scrape_data.doc_md = dedent(
        """Scrapes reddit posts and appends to file if unique
        Establishes a connection using `praw.Reddit` implemented in `_bot_login`, checks for duplicates
        against a already existing file (with data) and appends newly scraped posts.
        Newly scraped posts are saved in an intermediate file `new_scraped_subs.csv`.
        
        The reason that login-scraping-checks-append is in one task:
        - Present data is already loaded, if duplicate checks where in another fun we would need to reload the data.
        - There is little sense in scraping and saving data (to an extra file) and then checking for dupes,
        something is wrong!
        """
    )

    dummy_task = EmptyOperator(task_id="dummy_task")


scraped_data_present >> scrape_data >> Label("To see skipped vals") >> dummy_task

Footnotes

  1. I’m used to working with complex frameworks that sometimes have less than ideal documentation, but airflow was still challenging.↩︎

  2. An example is computing KPIs for each year-quarter combination that would take too long to query live for a dashboard. With airflow you can build a table for all defined periods in time, which of course keeps being updated in the future. Say you need to change xyz or made an error: just re-run the entire pipeline with catchup and the table is rebuilt for all defined points in time.↩︎