Apache Spark has a very powerful built-in API for gathering data from a relational database. Effectiveness and efficiency, following the usual Spark approach, are managed in a transparent way.
The two basic concepts we have to know when dealing with such scenarios are:
- Dataset: a distributed collection of data
- DataFrame: a Dataset organised into named columns
Being conceptually similar to a table in a relational database, the Dataset is the structure that will hold our data:
val dataset = sparkSession.read.jdbc(...)
The jdbc(…) method documentation is a little cryptic, at least for me, especially about how partitions work. Here are the parameters:
- url: JDBC database URL of the form jdbc:subprotocol:subname.
- table: the name of the table in the external database.
- columnName: the name of a column (of the table above) of integral type that will be used for partitioning.
- lowerBound: the minimum value of columnName used to decide partition stride.
- upperBound: the maximum value of columnName used to decide partition stride.
- numPartitions: the number of partitions. This, along with lowerBound (inclusive) and upperBound (exclusive), forms the partition strides for the generated WHERE clause expressions used to split the column evenly. When the input is less than 1, the number is set to 1.
- connectionProperties: JDBC database connection arguments, a list of arbitrary key/value pairs. Normally at least a “user” and “password” property should be included. “fetchsize” can be used to control the number of rows per fetch.
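To make the parameter list concrete, here is a minimal sketch of how such a call could look; the JDBC URL, table name, credentials and bound values are made up and should be adapted to your own database:
import java.util.Properties
import org.apache.spark.sql.SparkSession

val sparkSession = SparkSession.builder()
  .appName("jdbc-partitioning-example")
  .master("local[*]") // assumption: local run, drop this when submitting to a cluster
  .getOrCreate()

// Hypothetical connection arguments: at least user and password, plus an optional fetch size
val connectionProperties = new Properties()
connectionProperties.put("user", "my_user")
connectionProperties.put("password", "my_password")
connectionProperties.put("fetchsize", "1000")

val dataset = sparkSession.read.jdbc(
  "jdbc:postgresql://localhost:5432/mydb", // url: jdbc:subprotocol:subname
  "my_table",                              // table
  "table_id",                              // columnName, of integral type
  0L,                                      // lowerBound
  100000L,                                 // upperBound
  10,                                      // numPartitions
  connectionProperties)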
From the list above, the relevant parameters that affect the partitioning behaviour are columnName, lowerBound, upperBound and numPartitions. While the first (columnName) and the last (numPartitions) should be intuitive, the description of lowerBound and upperBound is not so clear, even after reading the documentation.
What is that “stride”? How is it computed?
Well, one could say “Why don’t you try it yourself?”, relying on the usual “pragmatic” approach, but I generally prefer to dig deeper in order to have a deterministic view of what I’m doing.
Of course, I have nothing against the pragmatic approach: I’m a big fan of StackOverflow and main methods 🙂 but often there’s another, more precise way to obtain the required information (and it can end up quite close to that pragmatic approach, as you will see soon).
Assuming you’re familiar with Java and/or Scala, it’s pretty easy to find, within the Spark code, where the partition logic resides and where the WHERE conditions are built: that place is the JDBCRelation singleton, specifically the columnPartition(…) function:
package org.apache.spark.sql.execution.datasources.jdbc

private object JDBCRelation extends Logging {
  ...
  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
    ...
    val stride =
      upperBound / numPartitions - lowerBound / numPartitions
    ...
  }
}
There, you can see the stride is a kind of “step” used for determining the range of each partition. If we pass the following values:
- column = “table_id”
- lowerBound = 0
- upperBound = 100000
- numPartitions = 10
The stride will have a value of 10000.
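As a quick sanity check, the same integer arithmetic in plain Scala, using the values listed above:
val lowerBound = 0L
val upperBound = 100000L
val numPartitions = 10
// 100000 / 10 - 0 / 10 = 10000
val stride = upperBound / numPartitions - lowerBound / numPartitions
println(stride) // prints 10000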
How does that stride actually work? If I move the columnPartition code into a main class (here comes the pragmatic approach), after removing things like the logging and changing the return type, we end up with a simple method like this:
import scala.collection.mutable.ArrayBuffer

object MyObject {

  def columnPartition(
      lowerBound: Long,
      upperBound: Long,
      requestedPartitions: Int,
      column: String): Unit = {
    require(lowerBound <= upperBound)
    // Cap the number of partitions at the width of the bound range
    val numPartitions =
      if ((upperBound - lowerBound) >= requestedPartitions) {
        requestedPartitions
      } else {
        upperBound - lowerBound
      }
    val stride: Long =
      upperBound / numPartitions - lowerBound / numPartitions
    var i: Int = 0
    var currentValue: Long = lowerBound
    val ans = new ArrayBuffer[String]()
    while (i < numPartitions) {
      // The first partition has no lower bound, the last one no upper bound
      val lBound = if (i != 0) s"$column >= $currentValue" else null
      currentValue += stride
      val uBound =
        if (i != numPartitions - 1) s"$column < $currentValue" else null
      val whereClause =
        if (uBound == null) {
          lBound
        } else if (lBound == null) {
          s"$uBound or $column is null"
        } else {
          s"$lBound AND $uBound"
        }
      ans += whereClause
      i = i + 1
    }
    ans.foreach(println)
  }
}
Here, you can clearly see that the stride is a step used for determining the WHERE clause of each SELECT statement that will be executed (one per partition). Specifically, we can execute this method with the sample values above
def main(args: Array[String]): Unit = {
  MyObject.columnPartition(0, 100000, 10, "table_id")
}
and we will get this output:
table_id < 10000 or table_id is null
table_id >= 10000 AND table_id < 20000
table_id >= 20000 AND table_id < 30000
table_id >= 30000 AND table_id < 40000
table_id >= 40000 AND table_id < 50000
table_id >= 50000 AND table_id < 60000
table_id >= 60000 AND table_id < 70000
table_id >= 70000 AND table_id < 80000
table_id >= 80000 AND table_id < 90000
table_id >= 90000
The partition logic makes sure no data is left out, whatever the input parameters are; what you should experiment with and tune, and this strongly depends on your concrete context, is the number of partitions and their resulting size. One common mistake is a wrong size assigned to the last partition: if, in the previous example, the table cardinality is 10,000,000, the last partition:
table_id >= 90000
will have to fetch 9,910,000 rows, which is a heavily unbalanced load compared with the other partitions.
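If you want to verify how balanced the resulting partitions actually are, a simple (if brute-force) check is to count the rows held by each partition; this sketch assumes the dataset DataFrame obtained from the read.jdbc(…) call shown earlier and triggers a full read, so run it on test data:
// Count the rows in each partition to spot skew
val rowsPerPartition = dataset.rdd
  .mapPartitionsWithIndex { (index, rows) =>
    Iterator((index, rows.size))
  }
  .collect()

rowsPerPartition.foreach { case (index, count) =>
  println(s"partition $index -> $count rows")
}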