将任意Java对象RDD转换成DataFrame,rdddataframe
分享于 点击 42855 次 点评:34
将任意Java对象RDD转换成DataFrame,rdddataframe
需求
将任意Java对象RDD转换成DataFrame。
要做到这一点,主要需要如下两步:
- 从Java类中获取StructType
- 将Java对象转换成Row
Spark版本: 1.6.1
准备
- 研究SparkSQL内置的数据类型,做成Java类与SparkSQL类型的映射表
推荐阅读spark源码org.apache.spark.sql.catalyst.ScalaReflection类,其中列举了大部分基础类型与SparkSQL类型的映射。
但我还是重新写了这部分功能,最重要的原因是源码只支持基本类型,对于复杂或嵌套Java类无能为力。
其次,我想支持更多的类型,且我想做到对某些类型的对象进行自定义转换。
比如我遇到的Java类中有个属性为Map<String, Object> parameters; 其中的泛型Object无法映射到任何SparkSQL类型中,
导致StructType无法构建完整,造成不得不放弃一部分数据。
但我的做法是,对泛型未指定或指定为Object的,直接调用toString方法转换为String,可以挽回一部分数据丢失。
还有一些常见的,比如需要将java.util.Date转换为java.sql.Date,将char[]转换为String的。
- 研究 java.lang.reflect.Type
Type接口有一个子类和四个子接口,一个子类为java.lang.Class(最为大众所知),四个子接口为 GenericArrayType, ParameterizedType,
TypeVariable, WildcardType。
开发
开发时间大概两周,运行较为稳定。下面分享代码,发现问题欢迎指正。
外部调用的主要是两个方法
def getStructType(clazz: Class[_]): Option[StructType]
def getRow(clazz: Class[_], obj: Any): Option[Row]
完整代码
import java.lang.reflect.{ GenericArrayType, Modifier, ParameterizedType, Field }
import java.lang.{ Iterable => JIterable }
import java.util.{ Map => JMap }
import scala.collection.JavaConversions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DataType, StructField, StructType, DecimalType, DataTypes }
import org.apache.spark.sql.types.DataTypes._
/**
* @author yizhu.sun 2016年7月21日
*/
object DataFrameReflectUtil {
/** 成员变量的类型和sparkSQL类型的映射 */
val predefinedDataType: collection.mutable.Map[Class[_], DataType] =
collection.mutable.Map(
(classOf[Boolean], BooleanType),
(classOf[java.lang.Boolean], BooleanType),
(classOf[Byte], ByteType),
(classOf[java.lang.Byte], ByteType),
(classOf[Array[Byte]], BinaryType),
(classOf[Array[java.lang.Byte]], BinaryType),
(classOf[Short], ShortType),
(classOf[java.lang.Short], ShortType),
(classOf[Int], IntegerType),
(classOf[java.lang.Integer], IntegerType),
(classOf[Long], LongType),
(classOf[java.lang.Long], LongType),
(classOf[Float], FloatType),
(classOf[java.lang.Float], FloatType),
(classOf[Double], DoubleType),
(classOf[java.lang.Double], DoubleType),
(classOf[Char], StringType),
(classOf[java.lang.Character], StringType),
(classOf[Array[Char]], StringType),
(classOf[Array[java.lang.Character]], StringType),
(classOf[String], StringType),
(classOf[java.math.BigDecimal], DecimalType.SYSTEM_DEFAULT),
(classOf[java.util.Date], DateType),
(classOf[java.sql.Date], DateType),
(classOf[java.security.Timestamp], TimestampType),
(classOf[java.util.Calendar], CalendarIntervalType),
// 成员为Object类型的,都转为String
(classOf[Any], StringType))
/** 类之间的转换。比如将java.util.Date转换为java.sql.Date */
private val classConverter: Map[Class[_], (Any) => _ <: Any] =
Map(
classOf[java.util.Date] ->
((o: Any) => new java.sql.Date(o.asInstanceOf[java.util.Date].getTime)),
classOf[Char] ->
((o: Any) => o.asInstanceOf[Char].toString),
classOf[java.lang.Character] ->
((o: Any) => o.asInstanceOf[java.lang.Character].toString),
classOf[Array[Char]] ->
((o: Any) => new String(o.asInstanceOf[Array[Char]])),
classOf[Array[java.lang.Character]] ->
((o: Any) => new String(o.asInstanceOf[Array[java.lang.Character]].map(_.charValue))),
classOf[Any] ->
((o: Any) => o.toString))
/** cache of Class -> Option[StructType] */
private val structTypeCache = new org.apache.commons.collections.map.LRUMap(100)
/** cache of java.lang.reflect.Type -> Option[DataType] */
private val dataTypeCache = new org.apache.commons.collections.map.LRUMap(1000)
/** cache of Class -> Array[Field] */
private val classFieldsCache = collection.mutable.Map[Class[_], Array[Field]]()
/** scala.collection.Map 类型的Class的cache */
private val scalaMapClassCache = collection.mutable.Set[Class[_]]()
/** scala.collection.Iterable 类型的Class的cache */
private val scalaIterableClassCache = collection.mutable.Set[Class[_]]()
/** java.util.Map 类型的Class的cache */
private val javaMapClassCache = collection.mutable.Set[Class[_]]()
/** java.lang.Iterable 类型的Class的cache */
private val javaIterableClassCache = collection.mutable.Set[Class[_]]()
// 注意在Scala中Map是Iterable的子类
def isScalaMapClass(clazz: Class[_]) = {
if (scalaMapClassCache.contains(clazz)) true
else if (classOf[Map[_, _]].isAssignableFrom(clazz)) {
scalaMapClassCache += clazz
true
} else false
}
def isScalaIterableClass(clazz: Class[_]) = {
if (scalaIterableClassCache.contains(clazz)) true
else if (classOf[Iterable[_]].isAssignableFrom(clazz)) {
scalaIterableClassCache += clazz
true
} else false
}
def isJavaMapClass(clazz: Class[_]) = {
if (javaMapClassCache.contains(clazz)) true
else if (classOf[JMap[_, _]].isAssignableFrom(clazz)) {
javaMapClassCache += clazz
true
} else false
}
def isJavaIterableClass(clazz: Class[_]) = {
if (javaIterableClassCache.contains(clazz)) true
else if (classOf[JIterable[_]].isAssignableFrom(clazz)) {
javaIterableClassCache += clazz
true
} else false
}
def getFields(clazz: Class[_]) =
classFieldsCache.getOrElseUpdate(clazz, {
val fields = clazz.getDeclaredFields
.filterNot(f => Modifier.isTransient(f.getModifiers))
.flatMap(f =>
getDataType(f.getGenericType) match {
case Some(_) => Some(f)
case None => None
})
fields.foreach(_.setAccessible(true))
fields
})
/**
* 根据Class对象,生成StructType对象。
*/
def getStructType(clazz: Class[_]): Option[StructType] = {
val cachedStructType = structTypeCache.get(clazz)
if (cachedStructType == null) {
val fields = getFields(clazz)
val newStructType =
if (fields.isEmpty) None
else {
val types = fields.map(f => {
val dataType = getDataType(f.getGenericType).get
StructField(f.getName, dataType, true) // 默认所有的字段都可能为空
})
if (types.isEmpty) None else Some(StructType(types))
}
structTypeCache.put(clazz, newStructType)
newStructType
} else cachedStructType.asInstanceOf[Option[StructType]]
}
/**
* 根据java.lang.reflect.Type获取org.apache.spark.sql.types.DataType
* 递归处理嵌套类型
*/
private def getDataType(tp: java.lang.reflect.Type): Option[DataType] = {
val cachedDataType = dataTypeCache.get(tp)
if (cachedDataType == null) {
val newDataType = tp match {
case ptp: ParameterizedType => // 带有泛型的数据类型,e.g. List[String]
val clazz = ptp.getRawType.asInstanceOf[Class[_]]
val rowTypes = ptp.getActualTypeArguments
if (isScalaMapClass(clazz) || isJavaMapClass(clazz)) {
(getDataType(rowTypes(0)), getDataType(rowTypes(1))) match {
case (Some(keyType), Some(valueType)) =>
Some(DataTypes.createMapType(keyType, valueType, true))
case _ => None
}
} else if (isScalaIterableClass(clazz) || isJavaIterableClass(clazz)) {
getDataType(rowTypes(0)) match {
case Some(dataType) => Some(DataTypes.createArrayType(dataType, true))
case None => None
}
} else {
getStructType(clazz)
}
case gatp: GenericArrayType => // 泛型数据类型的数组,e.g. Array[List[String]]
getDataType(gatp.getGenericComponentType) match {
case Some(dataType) => Some(DataTypes.createArrayType(dataType, true))
case None => None
}
case clazz: Class[_] => // 没有泛型的类型(包括没有指定泛型的Map和Collection)
predefinedDataType.get(clazz) match {
case Some(tp) => Some(tp)
case None =>
if (clazz.isArray) { // 非泛型对象的数组
getDataType(clazz.getComponentType) match {
case Some(dataType) => Some(DataTypes.createArrayType(dataType, true))
case None => None
}
} else if (isScalaMapClass(clazz) || isJavaMapClass(clazz)) {
Some(DataTypes.createMapType(StringType, StringType, true))
} else if (isScalaIterableClass(clazz) || isJavaIterableClass(clazz)) {
Some(DataTypes.createArrayType(StringType, true))
} else { // 一般Object类型,转换为嵌套类型
getStructType(clazz)
}
}
case _ =>
throw new IllegalArgumentException("不支持 WildcardType 和 TypeVariable")
}
dataTypeCache.put(tp, newDataType)
newDataType
} else cachedDataType.asInstanceOf[Option[DataType]]
}
/**
* 读取一行数据
*/
def getRow(clazz: Class[_], obj: Any): Option[Row] =
getStructType(clazz) match {
case Some(_) =>
if (obj == null) Some(null)
else Some(Row(getFields(clazz).flatMap(f => getCell(f.getGenericType, f.get(obj))): _*))
case None => None
}
/**
* 读取单个数据
*/
private def getCell(tp: java.lang.reflect.Type, value: Any): Option[Any] =
tp match {
case ptp: ParameterizedType => // 带有泛型的数据类型,e.g. List[String]
val clazz = ptp.getRawType.asInstanceOf[Class[_]]
val rowTypes = ptp.getActualTypeArguments
if (isScalaMapClass(clazz)) {
(getDataType(rowTypes(0)), getDataType(rowTypes(1))) match {
case (Some(keyType), Some(valueType)) =>
if (value == null) Some(null)
else Some(value.asInstanceOf[Map[Any, Any]].filterKeys(_ != null)
.map { case (k, v) => getCell(rowTypes(0), k).get -> getCell(rowTypes(1), v).get })
case _ => None
}
} else if (isScalaIterableClass(clazz)) {
getDataType(rowTypes(0)) match {
case Some(_) =>
if (value == null) Some(null)
else Some(value.asInstanceOf[Iterable[Any]].filter(_ != null).map(v => getCell(rowTypes(0), v).get).toSeq)
case None => None
}
} else if (isJavaIterableClass(clazz)) {
getDataType(rowTypes(0)) match {
case Some(_) =>
if (value == null) Some(null)
else Some(value.asInstanceOf[JIterable[Any]].filter(_ != null).map(v => getCell(rowTypes(0), v).get).toSeq)
case None => None
}
} else if (isJavaMapClass(clazz)) {
(getDataType(rowTypes(0)), getDataType(rowTypes(1))) match {
case (Some(keyType), Some(valueType)) =>
if (value == null) Some(null)
else Some(value.asInstanceOf[JMap[Any, Any]].filterKeys(_ != null)
.map { case (k, v) => getCell(rowTypes(0), k).get -> getCell(rowTypes(1), v).get })
case _ => None
}
} else {
getCell(clazz, value)
}
case gatp: GenericArrayType => // 泛型数据类型的数组,e.g. Array[List[String]]
getDataType(gatp.getGenericComponentType) match {
case Some(dataType) => Some(value.asInstanceOf[Array[Any]].map(v => getCell(gatp.getGenericComponentType, v).get).toSeq)
case None => None
}
case clazz: Class[_] => // 没有泛型的类型(包括没有指定泛型的Map和Collection)
predefinedDataType.get(clazz) match {
case Some(_) =>
classConverter.get(clazz) match {
case Some(converter) => Some(if (value == null) null else converter(value))
case None => Some(value)
}
case None =>
if (clazz.isArray) { // 非泛型对象的数组
getDataType(clazz.getComponentType) match {
case Some(dataType) =>
if (value == null) Some(null)
else Some(value.asInstanceOf[Array[_]].filter(_ != null).flatMap(v => getCell(clazz.getComponentType, v)).toSeq)
case None => None
}
} else if (isScalaMapClass(clazz)) {
Some(value.asInstanceOf[Map[Any, Any]].filterKeys(_ != null)
.map { case (k, v) => getCell(classOf[Any], k).get -> getCell(classOf[Any], v).get })
} else if (isScalaIterableClass(clazz)) {
Some(value.asInstanceOf[Iterable[Any]].filter(_ != null)
.map(v => getCell(classOf[Any], v).get).toSeq)
} else if (isJavaIterableClass(clazz)) {
Some(value.asInstanceOf[JIterable[Any]].filter(_ != null)
.map(v => getCell(classOf[Any], v).get).toSeq)
} else if (isJavaMapClass(clazz)) {
Some(value.asInstanceOf[JMap[Any, Any]].filterKeys(_ != null)
.map { case (k, v) => getCell(classOf[Any], k).get -> getCell(classOf[Any], v).get })
} else { // 一般Object类型,转换为嵌套类型
getRow(clazz, value)
}
}
case _ =>
throw new IllegalArgumentException("不支持 WildcardType 和 TypeVariable")
}
}
构建两个测试类
class TClass(
val list1: List[Array[Char]],
val map1: Map[String, Array[Int]],
val obj1: TInnerClass) extends Serializable
class TInnerClass(
val date1: java.util.Date) extends Serializable
测试代码
// sc: SparkContext
// ssc: SQLContext
val obj1 = new TClass(
List(Array('1', '2', '3'), null),
Map("123" -> Array(1, 2, 3),
"nil" -> null),
new TInnerClass(new java.util.Date))
val obj2 = new TClass(
List(Array('1', '2', '3'), null),
Map("empty" -> Array(),
"90" -> Array(9, 0)),
new TInnerClass(null))
val tClazz = classOf[TClass]
val rdd = sc.makeRDD(Seq(obj1, obj2))
val rowRDD = rdd.flatMap(DataFrameReflectUtil.getRow(tClazz, _))
DataFrameReflectUtil.getStructType(tClazz) match {
case Some(scheme) =>
val df = ssc.createDataFrame(rowRDD, scheme)
df.registerTempTable("df")
df.printSchema
ssc.sql("select list1, map1, obj1 from df").show(false)
ssc.sql("select map1['90'], map1['90'][0], date_add(obj1.date1, 1) from df").show(false)
case None =>
println("getStructType failed")
}
root
|-- list1: array (nullable = true)
| |-- element: string (containsNull = true)
|-- map1: map (nullable = true)
| |-- key: string
| |-- value: array (valueContainsNull = true)
| | |-- element: integer (containsNull = true)
|-- obj1: struct (nullable = true)
| |-- date1: date (nullable = true)
+-----+------------------------------------------------------+------------+
|list1|map1 |obj1 |
+-----+------------------------------------------------------+------------+
|[123]|Map(123 -> WrappedArray(1, 2, 3), nil -> null) |[2016-09-01]|
|[123]|Map(empty -> WrappedArray(), 90 -> WrappedArray(9, 0))|[null] |
+-----+------------------------------------------------------+------------+
+------+----+----------+
|_c0 |_c1 |_c2 |
+------+----+----------+
|null |null|2016-09-02|
|[9, 0]|9 |null |
+------+----+----------+
相关文章
- 暂无相关文章
用户点评