Spark SQL ships with a rich set of built-in functions, but real-world business logic can be complex and the built-in functions do not always cover it, so Spark SQL also provides an extension mechanism for user-defined functions.
UDF: an ordinary row-level function that takes one or more arguments and returns a single value, e.g. length(), isnull().
UDAF: an aggregate function that takes a group of values and returns a single aggregated result, e.g. max(), avg(), sum().
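This article focuses on UDFs, but for reference, here is a minimal UDAF sketch, assuming the Spark 2.x UserDefinedAggregateFunction API in org.apache.spark.sql.expressions (replaced by an Aggregator-based API in Spark 3.0); the class name MyAverage and the column names are illustrative only. It computes an average, much like the built-in avg():

package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Illustrative UDAF: averages a double column.
public class MyAverage extends UserDefinedAggregateFunction {
    // Input: one double column.
    @Override
    public StructType inputSchema() {
        return new StructType().add("value", DataTypes.DoubleType);
    }

    // Intermediate buffer: running sum and count.
    @Override
    public StructType bufferSchema() {
        return new StructType().add("sum", DataTypes.DoubleType).add("count", DataTypes.LongType);
    }

    // Final result type.
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0.0);
        buffer.update(1, 0L);
    }

    // Called once per input row within a partition.
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            buffer.update(0, buffer.getDouble(0) + input.getDouble(0));
            buffer.update(1, buffer.getLong(1) + 1L);
        }
    }

    // Merges partial results coming from different partitions.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0));
        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
    }

    @Override
    public Object evaluate(Row buffer) {
        return buffer.getLong(1) == 0 ? null : buffer.getDouble(0) / buffer.getLong(1);
    }
}

Once registered with sparkSession.udf().register("myAverage", new MyAverage()), it can be used in SQL just like a built-in aggregate function.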
Writing a UDF in Spark
The following example uses the pre-Spark 2.0 style (SQLContext) and shows a UDF with a single input parameter and a single output value.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF1 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("spark udf test");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        @SuppressWarnings("deprecation")
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        JavaRDD<String> javaRDD = javaSparkContext.parallelize(
                Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
        ds.createOrReplaceTempView("user");

        // Implement UDF1, UDF2, ... UDF22 depending on the number of input parameters.
        sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
        rows.show();

        javaSparkContext.stop();
    }
}
Output:
+---+--------+------+
| id| name|length|
+---+--------+------+
| 1|zhangsan| 8|
| 2| lisi| 4|
| 3| wangwu| 6|
| 4| zhaoliu| 7|
+---+--------+------+
The example above showed a UDF with a single input and a single output. The next example uses the Spark 2.0 SparkSession API to implement a UDF with three inputs and one output.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .appName("spark udf test")
                .master("local[2]")
                .getOrCreate();

        Dataset<String> row = sparkSession.createDataset(
                Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

        // Implement UDF1, UDF2, ... UDF22 depending on the number of input parameters.
        sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public String call(String combChar, String t1, String t2) throws Exception {
                return t1 + combChar + t2;
            }
        }, DataTypes.StringType);

        showByStruct(sparkSession, row);
        System.out.println("==========================================");
        showBySchema(sparkSession, row);

        sparkSession.stop();
    }

    private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql(
                "select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
        rows.show();
    }

    private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
        Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
        persons.show();
        persons.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql(
                "select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
        rows.show();
    }
}
Person.java
package com.dx.streaming.producer;

import java.io.Serializable;

public class Person implements Serializable {
    private String id;
    private String name;

    public Person(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public static Person parsePerson(String line) {
        String[] fields = line.split(",");
        Person person = new Person(fields[0], fields[1]);
        return person;
    }
}
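Besides referencing registered UDFs inside a SQL string, the same functions can be called from the DataFrame API. Below is a minimal sketch, assuming the persons Dataset and the strLength/strConcat registrations from the example above (callUDF, col and lit come from org.apache.spark.sql.functions):

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;

// Equivalent to: select id,name,strLength(name) as length,strConcat('-',id,name) as str from user
Dataset<Row> rows = persons.select(
        col("id"),
        col("name"),
        callUDF("strLength", col("name")).as("length"),
        callUDF("strConcat", lit("-"), col("id"), col("name")).as("str"));
rows.show();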