
Apache Spark Basics

Learn Spark's architecture, DataFrames, transformations, and actions for large-scale distributed data processing

What is Apache Spark?

Apache Spark is an open-source, distributed computing engine designed for large-scale data processing. It can process terabytes to petabytes of data across clusters of machines, providing APIs in Python (PySpark), Scala, Java, and R. Spark is the de facto standard for batch processing at scale and is widely used for ETL pipelines, data lake processing, machine learning, and even streaming.

Spark was developed at UC Berkeley in 2009 as a faster alternative to Hadoop MapReduce. Its key innovation is in-memory computing — rather than writing intermediate results to disk between steps (as MapReduce does), Spark keeps data in memory, making it up to 100x faster for iterative algorithms and interactive queries. Spark's unified engine supports batch processing, streaming, SQL, machine learning, and graph processing.

Spark Architecture

  • Driver: The main program that creates the SparkSession, defines transformations, and coordinates execution across the cluster
  • Executors: Worker processes on cluster nodes that run tasks and store cached data. Each executor has its own memory and CPU cores.
  • Cluster Manager: Allocates resources across the cluster — YARN, Kubernetes, Mesos, or Spark's standalone mode
  • Tasks: The smallest unit of work. A stage is split into tasks, one per partition, and distributed across executors.
  • DAG (Directed Acyclic Graph): Spark builds a DAG of transformations and optimizes execution before running any code (see the sketch below)
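
To see this plan-building in action, you can ask Spark to print the plans the driver constructs before any task runs. A standalone sketch, with an illustrative app name and toy in-memory data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("ArchitectureDemo").getOrCreate()

df = spark.createDataFrame(
    [(1, "US", 120.0), (2, "DE", 80.0)],
    ["order_id", "country", "amount"],
)

# Transformations only extend the DAG; nothing executes yet
high_value = df.filter(col("amount") > 100).select("order_id", "country")

# Print the parsed, analyzed, optimized, and physical plans the driver
# built from the DAG; at runtime these are split into stages and tasks
high_value.explain(True)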

DataFrames: Spark's Primary API

The DataFrame is Spark's primary abstraction — a distributed collection of data organized into named columns, similar to a pandas DataFrame or a database table, but distributed across a cluster. DataFrames provide a high-level API with built-in optimization through the Catalyst query optimizer.

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, sum as spark_sum, avg, max as spark_max,
    year, month, when, lit, concat, upper, trim,
    date_format, datediff, current_date
)
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    DoubleType, DateType, TimestampType
)

# Create a SparkSession
spark = SparkSession.builder \
    .appName("DataEngineeringPipeline") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .getOrCreate()

# Read data from various sources
# Parquet (columnar, compressed — most common for data lakes)
orders = spark.read.parquet("s3://data-lake/raw/orders/")

# CSV with explicit schema
schema = StructType([
    StructField("customer_id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("email", StringType(), True),
    StructField("country", StringType(), True),
    StructField("signup_date", DateType(), True),
])

customers = spark.read \
    .option("header", "true") \
    .schema(schema) \
    .csv("s3://data-lake/raw/customers/")

# JSON (semi-structured)
events = spark.read.json("s3://data-lake/raw/clickstream/2025/03/")

# Show schema and sample data
orders.printSchema()
orders.show(5, truncate=False)
print(f"Total orders: {orders.count()}")
print(f"Partitions: {orders.rdd.getNumPartitions()}")

Transformations and Actions

Spark operations are divided into transformations (lazy — define what to do) and actions (eager — trigger execution). Transformations build up a logical plan; actions execute it. This lazy evaluation allows Spark to optimize the entire pipeline before running anything.

# Transformations (lazy — nothing executes yet)

# Select and rename columns
clean_orders = orders.select(
    col("order_id"),
    col("customer_id"),
    col("product_id"),
    col("quantity").cast("int"),
    col("unit_price").cast("double"),
    (col("quantity") * col("unit_price")).alias("total_amount"),
    col("order_date").cast("date"),
    col("status"),
)

# Filter rows
active_orders = clean_orders.filter(
    (col("status") != "cancelled") &
    (col("total_amount") > 0) &
    (col("order_date") >= "2025-01-01")
)

# Add computed columns
enriched = active_orders.withColumn(
    "order_month", date_format("order_date", "yyyy-MM")
).withColumn(
    "amount_tier",
    when(col("total_amount") > 500, "premium")
    .when(col("total_amount") > 100, "standard")
    .otherwise("basic")
).withColumn(
    "days_since_order",
    datediff(current_date(), col("order_date"))
)

# Join with customers
order_details = enriched.join(
    customers.select("customer_id", "name", "country"),
    on="customer_id",
    how="left"
)

# Group and aggregate
monthly_revenue = order_details.groupBy("order_month", "country").agg(
    count("order_id").alias("total_orders"),
    spark_sum("total_amount").alias("revenue"),
    avg("total_amount").alias("avg_order_value"),
    spark_max("total_amount").alias("max_order"),
)

# Actions (trigger execution)
monthly_revenue.show(20)                          # Display results
monthly_revenue.count()                           # Count rows
monthly_revenue.collect()                         # Return all rows to driver (can exhaust driver memory)
monthly_revenue.write.parquet("s3://output/")     # Write to storage
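
Because collect() returns every row to the driver process, it can exhaust driver memory on large results. A minimal sketch of safer inspection patterns, reusing the monthly_revenue DataFrame from above (toPandas() assumes pandas is installed):

# Bounded alternatives to collect() for inspecting large DataFrames
first_rows = monthly_revenue.take(20)                # at most 20 Row objects sent to the driver
sample_pdf = monthly_revenue.limit(1000).toPandas()  # cap rows before converting to pandas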

Performance Optimization

Spark performance depends heavily on how you manage partitions, avoid shuffles, and leverage caching. Understanding these optimization techniques is critical for production pipelines:

Key Optimization Strategies

  • Partition Management: Control partition count with repartition() (full shuffle) or coalesce() (reduce without shuffle). Target 128MB-256MB per partition.
  • Avoid Shuffles: Operations like groupBy, join, and distinct cause expensive shuffles. Use broadcast joins for small tables.
  • Broadcast Joins: When joining a large table with a small table (< 10MB), broadcast the small table to all executors to avoid a shuffle.
  • Caching: Use .cache() or .persist() for DataFrames that are reused multiple times in the pipeline.
  • Predicate Pushdown: Spark pushes filters down to the data source — Parquet files only read relevant row groups and columns (a verification sketch follows the code below).
  • Adaptive Query Execution (AQE): Spark 3.x dynamically adjusts shuffle partitions and join strategies based on runtime statistics.

from pyspark.sql.functions import broadcast

# Broadcast join: small dimension table broadcast to all executors
products = spark.read.parquet("s3://data-lake/dim/products/")  # Small table
orders = spark.read.parquet("s3://data-lake/fact/orders/")     # Large table

# Without broadcast: expensive shuffle join
# enriched = orders.join(products, "product_id")

# With broadcast: no shuffle, much faster
enriched = orders.join(broadcast(products), "product_id")

# Repartition for optimal parallelism
# Too few partitions -> underutilization
# Too many partitions -> overhead from scheduling and small files
optimized = enriched.repartition(200, "order_date")  # Repartition by date

# Coalesce: reduce partitions without shuffle (for writing output)
enriched.coalesce(10).write \
    .partitionBy("order_date") \
    .mode("overwrite") \
    .parquet("s3://data-lake/silver/enriched_orders/")

# Cache for reuse
enriched.cache()
enriched.count()  # Triggers caching

# Now these queries read from memory, not disk
enriched.groupBy("category").count().show()
enriched.filter(col("amount") > 100).count()
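
The predicate pushdown behavior listed above can be verified the same way. A small sketch, reusing the orders Parquet path and columns from earlier (the exact plan output varies by Spark version):

# Filter and column selection on a Parquet read
pushed = spark.read.parquet("s3://data-lake/fact/orders/") \
    .filter(col("order_date") >= "2025-01-01") \
    .select("order_id", "quantity")

# The FileScan parquet node in the physical plan lists PushedFilters and
# the pruned column set, confirming only relevant row groups are read
pushed.explain()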

Writing Output

# Write to various formats and sinks

# Parquet with partitioning (most common for data lakes)
monthly_revenue.write \
    .mode("overwrite") \
    .partitionBy("order_month") \
    .option("compression", "snappy") \
    .parquet("s3://data-lake/gold/monthly_revenue/")

# Write to a database (JDBC)
# In production, load credentials from a secrets manager rather than hardcoding them
monthly_revenue.write \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://warehouse:5432/analytics") \
    .option("dbtable", "analytics.monthly_revenue") \
    .option("user", "spark_user") \
    .option("password", "secret") \
    .mode("overwrite") \
    .save()

# Delta Lake (ACID transactions on data lakes)
monthly_revenue.write \
    .format("delta") \
    .mode("overwrite") \
    .save("s3://data-lake/delta/monthly_revenue/")

Spark SQL

Spark also supports standard SQL, which is often preferred by analysts and for complex queries:

# Register DataFrames as temporary SQL views
orders.createOrReplaceTempView("orders")
customers.createOrReplaceTempView("customers")

# Run SQL queries
result = spark.sql("""
    SELECT
        c.country,
        DATE_TRUNC('month', o.order_date) AS month,
        COUNT(DISTINCT o.order_id) AS total_orders,
        SUM(o.quantity * o.unit_price) AS revenue,
        COUNT(DISTINCT o.customer_id) AS unique_customers
    FROM orders o
    JOIN customers c ON o.customer_id = c.customer_id
    WHERE o.status != 'cancelled'
      AND o.order_date >= '2025-01-01'
    GROUP BY c.country, DATE_TRUNC('month', o.order_date)
    ORDER BY month DESC, revenue DESC
""")

result.show(20)

Key Takeaways

  • Spark is the standard for large-scale distributed data processing, supporting batch, streaming, SQL, and ML
  • DataFrames provide a high-level API with built-in optimization through the Catalyst optimizer
  • Transformations are lazy; actions trigger execution — this enables whole-pipeline optimization
  • Use broadcast joins for small tables, manage partition sizes (128-256MB), and enable AQE for automatic tuning
  • Parquet is the preferred format for data lakes — columnar, compressed, with schema and predicate pushdown
  • Spark SQL allows you to write standard SQL against distributed datasets, making it accessible to SQL-fluent teams
