Custom Window Function in Spark to create Session IDs

If you’ve worked with Spark, you have probably written some custom UDFs or UDAFs.
UDFs are ‘User Defined Functions’: they let you introduce complex logic in your queries/jobs, for instance to calculate a digest for a string, or to use a Java/Scala library in your queries.
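As an illustration of the digest case, here is a minimal sketch of the plain Scala function such a UDF could wrap. The object name `DigestExample` and the choice of SHA-256 are assumptions made for illustration. In Spark you would then wrap the function with `org.apache.spark.sql.functions.udf`, roughly `udf(DigestExample.sha256Hex _)`.

```scala
import java.security.MessageDigest

// Hypothetical helper: the kind of plain Scala logic a UDF typically wraps.
object DigestExample {
  // Hex-encoded SHA-256 digest of a string, computed over its UTF-8 bytes.
  def sha256Hex(s: String): String = {
    val md = MessageDigest.getInstance("SHA-256")
    md.digest(s.getBytes("UTF-8"))
      .map(b => "%02x".format(b & 0xff)) // unsigned byte -> two hex digits
      .mkString
  }
}
```

Keeping the logic in an ordinary function like this also makes it easy to unit test without a Spark session.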

UDAF stands for ‘User Defined Aggregate Function’. It works on groups of rows, so you can implement functions that are used with a GROUP BY clause, similar to AVG.
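To make the aggregate lifecycle concrete, here is a plain-Scala sketch of the buffer an AVG-like aggregate keeps. It is a conceptual model, not Spark code. The steps loosely correspond to the `initialize`, `update`, `merge` and `evaluate` methods of Spark’s `UserDefinedAggregateFunction`, and all names below are hypothetical.

```scala
// Conceptual model of an aggregate function's lifecycle (not the Spark API).
object AvgSketch {
  // The aggregation buffer: running sum and number of rows seen
  final case class AvgBuffer(sum: Double, count: Long)

  val initial: AvgBuffer = AvgBuffer(0.0, 0L)

  // Called once per input row of the group
  def update(buf: AvgBuffer, value: Double): AvgBuffer =
    AvgBuffer(buf.sum + value, buf.count + 1)

  // Combines partial buffers computed on different partitions
  def merge(a: AvgBuffer, b: AvgBuffer): AvgBuffer =
    AvgBuffer(a.sum + b.sum, a.count + b.count)

  // Produces the final value for the group (None for an empty group)
  def evaluate(buf: AvgBuffer): Option[Double] =
    if (buf.count == 0) None else Some(buf.sum / buf.count)

  // Average of one group whose rows are split across two partitions
  def avg(part1: Seq[Double], part2: Seq[Double]): Option[Double] =
    evaluate(merge(part1.foldLeft(initial)(update), part2.foldLeft(initial)(update)))
}
```

The key point is that an aggregate yields one value per group. A window function, by contrast, yields one value per row.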

You may not be familiar with window functions. They are similar to aggregate functions but add a layer of complexity, since they are applied within a PARTITION BY clause. An example of a window function is RANK(). You can read more about window functions here.

While aggregate functions work over a group, window functions work over a logical window of records. They let you produce new columns from the combination of a record and one or more other records in the window.
Describing window functions in full is beyond the scope of this article; for that, refer to the previously mentioned article from Databricks. Here we are specifically interested in the ‘previous event in time for a user’, which is what we need to figure out sessions.
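The ‘previous event in time for a user’ is what a window partitioned by user and ordered by time exposes for every row; Spark’s built-in `lag` window function returns it. The plain-Scala sketch below uses hypothetical names and no Spark at all. It mimics `LAG(ts) OVER (PARTITION BY user ORDER BY ts)` to show the difference from an aggregate: every input row gets its own output, instead of one value per group.

```scala
object PreviousEventSketch {
  final case class Event(user: String, ts: Long)

  // For every event, the timestamp of the same user's previous event
  // (None for that user's first event), with output grouped by user and sorted by ts.
  def withPrevious(events: Seq[Event]): Seq[(Event, Option[Long])] =
    events
      .groupBy(_.user)          // PARTITION BY user
      .toSeq
      .sortBy(_._1)             // deterministic order of partitions
      .flatMap { case (_, userEvents) =>
        val sorted = userEvents.sortBy(_.ts) // ORDER BY ts
        val previous: Seq[Option[Long]] = None +: sorted.init.map(e => Some(e.ts))
        sorted.zip(previous)
      }
}
```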

There is plenty of documentation on how to write UDFs and UDAFs; see, for instance, this link for UDFs or this link for UDAFs.

I was surprised to find that there’s not much information on how to build a custom window function, so I dug into the Spark source code and started looking at how window functions are implemented. That opened up a whole new world to me. Window functions are conceptually similar to UDAFs, but they use a lower-level Spark API: they are written using Catalyst expressions.

Sessionization basics

Now, for what kind of problem do we need window functions in the first place?
A common problem when working on any kind of website is to determine ‘user sessions’, that is, periods of user activity. If a user is inactive for a certain time T, their next activity is considered a new ‘session’. Statistics over sessions are used, for instance, to determine whether the user is a bot or to find out which pages have the most activity.

Let’s say that we consider a session over if we don’t see any activity for one hour (sixty minutes). Here is an example of user activity, where ‘event’ is the name of the page the user visited and ‘time’ is the time of the event. I simplified the example: in practice the event would be a URL, the time would be a full timestamp, and the session ID would be generated as a random UUID. Simpler names and times are enough to illustrate the logic.

user    event   time    session
user1   page1   10:12   session1 (new session)
user1   page2   10:20   session1 (same session, 8 minutes after the last event)
user1   page1   11:13   session1 (same session, 53 minutes after the last event)
user1   page3   14:12   session2 (new session, about 3 hours after the last event)

Note that this is the activity for one user. We do have many users, and in fact partitioning by user is the job of the window function.
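You can check the arithmetic behind the table with a small plain-Scala sketch. The helper names are hypothetical, and the times are the simplified ‘HH:mm’ values from the table. As assumed throughout, a gap strictly longer than the threshold starts a new session.

```scala
object GapSketch {
  // "HH:mm" -> minutes since midnight
  def minutes(hhmm: String): Int = {
    val parts = hhmm.split(":")
    parts(0).toInt * 60 + parts(1).toInt
  }

  // Minutes elapsed between each pair of consecutive events
  def gaps(times: Seq[String]): Seq[Int] = {
    val mins = times.map(minutes)
    mins.zip(mins.drop(1)).map { case (a, b) => b - a }
  }

  // Session number per event: starts at 1 and increases whenever
  // the gap since the previous event exceeds thresholdMin
  def sessionNumbers(times: Seq[String], thresholdMin: Int): Seq[Int] =
    if (times.isEmpty) Seq.empty
    else gaps(times).scanLeft(1)((session, gap) => if (gap > thresholdMin) session + 1 else session)
}
```

For the four events above, the gaps are 8, 53 and 179 minutes, so with a 60-minute threshold only the last event starts a new session.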

Digging in

An example is the best way to illustrate how the function works with respect to the window definition.
Let’s assume we have very simple user activity data: a user ID called user, a numeric timestamp ts, and a session ID session, which may already be present. We may start with no sessions at all, but in most practical cases we process data hourly, so at hour n + 1 we want to continue the sessions we calculated at hour n.

Let’s create some test data and show what we want to achieve.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => f}
import MyUDWF._ // calculateSession, defined below

// our Data Definition
case class UserActivityData(user:String, ts:Long, session:String)

// starting timestamp and one minute, both in milliseconds
// (st is the starting timestamp that appears in the output below)
val st = 1508863564166L
val one_minute = 60 * 1000L

// our sample data
val d = Array[UserActivityData](
    UserActivityData("user1",  st, "ss1"),
    UserActivityData("user2",  st +   5*one_minute, null),
    UserActivityData("user1",  st +  10*one_minute, null),
    UserActivityData("user1",  st +  15*one_minute, null),
    UserActivityData("user2",  st +  15*one_minute, null),
    UserActivityData("user1",  st + 140*one_minute, null),
    UserActivityData("user1",  st + 160*one_minute, null))

// creating the DataFrame (sc is an existing SparkContext)
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(sc.parallelize(d))

// Window specification
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)
// create the session
val res = df.withColumn("newsession",
   calculateSession(f.col("ts"), f.col("session")) over specs)

First, the window specification. Sessions are created per user, and the ordering is, of course, by timestamp.
Hence, we partition by user and order by timestamp.

We want to write a calculateSession function that will use the following logic:

IF (event already has a session) THEN keep it
ELSE IF (no previous event) THEN create new session
ELSE IF (current event is past the session window) THEN create new session
ELSE use current session

and will produce something like this:

user   ts             session        newsession
user1  1508863564166  f237e656-1e..  f237e656-1e..
user1  1508864164166  null           f237e656-1e..
user1  1508864464166  null           f237e656-1e..
user1  1508871964166  null           51c05c35-6f..
user1  1508873164166  null           51c05c35-6f..
user2  1508863864166  null           2c16b61a-6c..
user2  1508864464166  null           2c16b61a-6c..


Note that we are using random UUIDs as it’s pretty much the standard, and we’re shortening them for typographical reasons.

As you can see, for each user the function creates a new session whenever the gap between two consecutive events is greater than the session threshold.
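Before going into Catalyst, the same rule can be sketched in plain Scala by emulating the window: group by user (PARTITION BY), sort by timestamp (ORDER BY), and thread a small state through the rows. This is only a model of the intended behaviour, not the Spark implementation. It makes three assumptions. First, a row that already carries a session ID keeps it, which is how sessions from a previous batch are continued. Second, `Option` stands for "no previous event". Third, the ID generator is a parameter so that tests can use predictable IDs. All names other than `UserActivityData` are hypothetical.

```scala
import java.util.UUID

object SessionRuleSketch {
  final case class UserActivityData(user: String, ts: Long, session: String)

  // Per-user state: current session ID and previous event's timestamp
  // (None means there is no previous event yet)
  final case class State(currentSession: String, previousTs: Option[Long])

  val oneMinute: Long = 60 * 1000L
  val defaultSessionWindowMs: Long = 60 * oneMinute

  // Session for one event, given the state left by the same user's previous event
  def step(state: State, row: UserActivityData, windowMs: Long, newId: () => String): State = {
    val assigned =
      if (row.session != null) row.session // already has a session: keep it
      else state.previousTs match {
        case Some(prev) if row.ts - prev <= windowMs => state.currentSession // same session
        case _ => newId() // first event, or gap longer than the window
      }
    State(assigned, Some(row.ts))
  }

  // Emulates PARTITION BY user ORDER BY ts; returns (row, newsession) pairs
  def sessionize(rows: Seq[UserActivityData],
                 windowMs: Long = defaultSessionWindowMs,
                 newId: () => String = () => UUID.randomUUID().toString): Seq[(UserActivityData, String)] =
    rows.groupBy(_.user).toSeq.sortBy(_._1).flatMap { case (_, userRows) =>
      val sorted = userRows.sortBy(_.ts)
      val states = sorted.scanLeft(State(null, None))((st, row) => step(st, row, windowMs, newId))
      sorted.zip(states.tail.map(_.currentSession))
    }
}
```

With the sample data and a one-hour window, this sketch groups the rows the same way as the output table. user1 keeps the existing session for its first three events and gets a new one at st + 140 minutes. Both of user2’s events share one session.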

Internally, for every record, we want to keep track of:

  • The current session ID
  • The timestamp of the previous event

This is the state that we must maintain; Spark takes care of initializing it for us.
The function’s inputs mirror this state: the timestamp of each record and its session ID, if one is already assigned.

Let’s see the skeleton of the function:

// imports needed by the snippets below
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, LongType, StringType}

// object to collect my UDWFs
object MyUDWF {
  // longer than this (in ms), and it's a new session; a Long, like the timestamps
  val defaultMaxSessionLengthms = 3600 * 1000L

  case class SessionUDWF(timestamp:Expression, session:Expression,
           sessionWindow:Expression = Literal(defaultMaxSessionLengthms))
      extends AggregateWindowFunction {
    self: Product =>

    override def children: Seq[Expression] = Seq(timestamp, session)
    override def dataType: DataType = StringType

    protected val zero = Literal( 0L )
    // a typed null, so the initial value matches the StringType buffer slot
    protected val nullString = Literal.create(null, StringType)

    protected val currentSession = AttributeReference("currentSession",
                   StringType, nullable = true)()
    protected val previousTs =    AttributeReference("previousTs",
                   LongType, nullable = false)()

    override val aggBufferAttributes: Seq[AttributeReference] =
                    currentSession  :: previousTs :: Nil

    override val initialValues: Seq[Expression] =  nullString :: zero :: Nil
    override def prettyName: String = "makeSession"

    // we have to write these ones (??? is a placeholder until then)
    override val updateExpressions: Seq[Expression] = ???
    override val evaluateExpression: Expression = ???
  }
  // the calculateSession helpers also live in this object
}

A few notes here:

  • Our ‘state’ is going to be a Seq[AttributeReference]
  • Each AttributeReference must be declared with its type. As we said, we keep the current session and the timestamp of the previous event.
  • We initialize it by overriding initialValues
  • For every record in the window, Spark first applies updateExpressions, then produces the output value by calling evaluateExpression
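Here is a rough mental model of that per-record cycle, written as plain Scala rather than Spark internals; Spark’s actual frame processing is more involved. The buffer starts from the initial values once per partition. Then, for each row in ORDER BY order, the buffer is updated and evaluated, so a row’s output depends on every row from the start of the partition up to and including that row. The helper names are hypothetical, and the ROW_NUMBER-style counter and running sum are just examples.

```scala
object RunningFrameSketch {
  // Initialize once per partition; then, for each row: update the buffer, evaluate it, emit.
  def runWindow[Row, Buf, Out](partition: Seq[Row])(initial: Buf)(
      update: (Buf, Row) => Buf)(evaluate: Buf => Out): Seq[Out] =
    partition.scanLeft(initial)(update).tail.map(evaluate)

  // ROW_NUMBER()-like: the buffer is a row counter
  def rowNumbers[Row](partition: Seq[Row]): Seq[Long] =
    runWindow(partition)(0L)((n, _) => n + 1)(n => n)

  // Running sum over the frame from the partition start to the current row
  def runningSums(partition: Seq[Int]): Seq[Int] =
    runWindow(partition)(0)(_ + _)(s => s)
}
```

The session function follows the same shape. Its buffer is (currentSession, previousTs), updateExpressions plays the role of `update`, and evaluateExpression plays the role of `evaluate`.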

Now it’s time to implement updateExpressions and evaluateExpression.

    // this is invoked whenever we need to create a new session ID. You can use your own logic; here we create UUIDs
    protected val  createNewSession = () => org.apache.spark.unsafe.types.

    // initialize with no session, zero previous timestamp (same as in the skeleton)
    override val initialValues: Seq[Expression] =  nullString :: zero :: Nil

    // assign session: if the gap since the previous event is within the
    // session window, keep the current session, otherwise create a new one.
    // Declared before updateExpressions: vals are initialized in order,
    // so a forward reference would still be null at that point.
    protected val assignSession =  If(LessThanOrEqual(
          Subtract(timestamp, aggBufferAttributes(1)), sessionWindow),
      aggBufferAttributes(0),
      ScalaUDF( createNewSession, StringType, children = Nil))

    // if a session is already assigned, keep it, otherwise, assign one;
    // then remember this record's timestamp for the next one
    override val updateExpressions: Seq[Expression] =
      If(IsNotNull(session), session, assignSession) ::
        timestamp :: Nil

    // just return the current session in the buffer
    override val evaluateExpression: Expression = aggBufferAttributes(0)

Notice how we use Catalyst expressions here, whereas in ordinary UDAFs we write plain Scala code.

Last, we need to declare a method on the object that we can invoke from the query to instantiate the function. Notice how I created two: one that lets the user specify the maximum duration of a session, and one that uses the default:

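  // wraps a Catalyst expression in a Column; Spark's own withExpr in
  // org.apache.spark.sql.functions is private, so we assume a local copy
  private def withExpr(expr: Expression): Column = new Column(expr)
  def calculateSession(ts:Column,sess:Column): Column =
         withExpr {
           SessionUDWF(ts.expr,sess.expr, Literal(defaultMaxSessionLengthms))
         }
  def calculateSession(ts:Column,sess:Column, sessionWindow:Column): Column =
         withExpr {
            SessionUDWF(ts.expr,sess.expr, sessionWindow.expr)
         }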

Now creating session IDs is as easy as:

// Window specification
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)
// create the session with a 10-second window. Duration is in ms, passed as a
// Long to match the type of the timestamp difference it is compared to.
val res = df.withColumn("newsession",
   calculateSession(f.col("ts"), f.col("session"), f.lit(10 * 1000L)) over specs)

Notice that here we specified a 10-second session window.

There’s a little more plumbing involved, which was omitted for clarity, but you can find the complete code, including unit tests, in my GitHub project.



