How can I retrieve the alias for a DataFrame in Spark

I'm using Spark 2.0.2. I have a DataFrame that has an alias on it, and I'd like to be able to retrieve that. A simplified example of why I'd want that is below.

def check(ds: DataFrame) = {
  assert(ds.count > 0, s"${ds.getAlias} has zero rows!")
}

The above code of course fails because DataFrame has no getAlias function. Is there a way to do this?

asked Oct 29 '25 by Dave DeCaprio

2 Answers

You can try something like this, but I wouldn't go so far as to claim it is supported:

  • Spark < 2.1:

    import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
    import org.apache.spark.sql.Dataset
    
    def getAlias(ds: Dataset[_]) = ds.queryExecution.analyzed match {
      case SubqueryAlias(alias, _) => Some(alias)
      case _ => None
    }
    
  • Spark 2.1+:

    def getAlias(ds: Dataset[_]) = ds.queryExecution.analyzed match {
      case SubqueryAlias(alias, _, _) => Some(alias)
      case _ => None
    }
    

Example usage:

val plain = Seq((1, "foo")).toDF
getAlias(plain)
// Option[String] = None

val aliased = plain.alias("a dataset")
getAlias(aliased)
// Option[String] = Some(a dataset)
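
To make the check helper from the question read naturally, you could wrap this in an implicit class. This is only a sketch: the DatasetAliasOps name is mine, and it uses the Spark < 2.1 shape of SubqueryAlias to match the asker's 2.0.2.

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias

// hypothetical wrapper (name is mine): exposes getAlias on any Dataset/DataFrame
implicit class DatasetAliasOps(ds: Dataset[_]) {
  def getAlias: Option[String] = ds.queryExecution.analyzed match {
    case SubqueryAlias(alias, _) => Some(alias)  // Spark < 2.1; for 2.1+ use the three-argument pattern above
    case _ => None
  }
}

def check(ds: Dataset[_]): Unit =
  assert(ds.count > 0, s"${ds.getAlias.getOrElse("unaliased Dataset")} has zero rows!")

Since this still goes through queryExecution.analyzed, the same caveat about relying on internal APIs applies.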
answered Oct 31 '25 by zero323

Disclaimer: as stated above, this code relies on undocumented APIs subject to change. It works as of Spark 2.3.

After much digging into mostly undocumented Spark methods, here is the full code to pull the list of fields, along with the table alias, for a DataFrame in PySpark:

def schema_from_plan(df):
    plan = df._jdf.queryExecution().analyzed()
    all_fields = _schema_from_plan(plan)

    iterator = plan.output().iterator()
    output_fields = {}
    while iterator.hasNext():
        field = iterator.next()
        queryfield = all_fields.get(field.exprId().id(),{})
        if queryfield:
            tablealias = queryfield["tablealias"]
        else:
            tablealias = ""
        output_fields[field.exprId().id()] = {
            "tablealias": tablealias,
            "dataType": field.dataType().typeName(),
            "name": field.name()
        }
    return list(output_fields.values())

def _schema_from_plan(root, tablealias=None, fields=None):
    # default to None to avoid sharing one mutable dict across calls
    if fields is None:
        fields = {}
    iterator = root.children().iterator()
    while iterator.hasNext():
        node = iterator.next()
        nodeClass = node.getClass().getSimpleName()
        if nodeClass == "SubqueryAlias":
            # get the alias and process the subnodes with this alias
            _schema_from_plan(node,node.alias(),fields)
        else:
            if tablealias:
                # add all the fields, along with the unique IDs, and a new tablealias field;
                # use a separate name here so the outer iterator over children is not clobbered
                field_iterator = node.output().iterator()
                while field_iterator.hasNext():
                    field = field_iterator.next()
                    fields[field.exprId().id()] = {
                        "tablealias": tablealias,
                        "dataType": field.dataType().typeName(),
                        "name": field.name()
                    }
            _schema_from_plan(node,tablealias,fields)
    return fields

# example: fields = schema_from_plan(df)
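
A minimal usage sketch, assuming a running SparkSession and the two helpers above; the column names, aliases, and expected output shape here are illustrative, not part of the original answer:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# two small DataFrames with explicit aliases (illustrative data)
orders = spark.createDataFrame([(1, 100)], ["order_id", "amount"]).alias("o")
customers = spark.createDataFrame([(1, "Acme")], ["order_id", "customer"]).alias("c")

fields = schema_from_plan(orders.join(customers, "order_id"))
# each entry should look roughly like:
# {"tablealias": "o", "dataType": "long", "name": "order_id"}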
answered Oct 31 '25 by zalmane


