一. 需求
1.1 需求简介
这里的热门商品是从点击量的维度来看的.
计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。
1.2 思路分析
使用 sql 来完成. 碰到复杂的需求, 可以使用 udf 或 udaf
查询出来所有的点击记录, 并与 city_info 表连接, 得到每个城市所在的地区. 与 Product_info 表连接得到产品名称
按照地区和商品 id 分组, 统计出每个商品在每个地区的总点击次数
每个地区内按照点击次数降序排列
只取前三名. 并把结果保存在数据库中
城市备注需要自定义 UDAF 函数
二. 实际操作
1. 准备数据
  我们这次 Spark-sql 操作中所有的数据均来自 Hive.
  首先在 Hive 中创建表, 并导入数据.
  一共有 3 张表: 1 张用户行为表, 1 张城市表, 1 张产品表
1. 打开Hive
2. 创建三个表
|
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
CREATE TABLE `user_visit_action`( `date` string, `user_id` bigint, `session_id` string, `page_id` bigint, `action_time` string, `search_keyword` string, `click_category_id` bigint, `click_product_id` bigint, `order_category_ids` string, `order_product_ids` string, `pay_category_ids` string, `pay_product_ids` string, `city_id` bigint)row format delimited fields terminated by '\t';CREATE TABLE `product_info`( `product_id` bigint, `product_name` string, `extend_info` string)row format delimited fields terminated by '\t';CREATE TABLE `city_info`( `city_id` bigint, `city_name` string, `area` string)row format delimited fields terminated by '\t'; |
3. 上传数据
|
1
2
3
|
load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info; |
4. 测试是否上传成功
|
1
|
hive> select * from city_info; |
2. 显示各区域热门商品 Top3
|
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
// user_visit_action product_info city_info1. 先把需要的字段查出来 t1select ci.*, pi.product_name, click_product_idfrom user_visit_action uvajoin product_info pi on uva.click_product_id=pi.product_idjoin city_info ci on uva.city_id=ci.city_id2. 按照地区和商品名称聚合select area, product_name, count(*) countfrom t1group by area , product_name3. 按照地区进行分组开窗 排序 开窗函数 t3 // (rank(1 2 2 4 5...) row_number(1 2 3 4...) dense_rank(1 2 2 3 4...))select area, product_name, count, rank() over(partition by area order by count desc)from t24. 过滤出来名次小于等于3的select area, product_name, countfrom t3where rk <=3 |
2. 运行结果
3. 定义udaf函数 得到需求结果
|
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
package com.buwenbuhuo.spark.sql.projectimport java.text.DecimalFormatimport org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._/** ** * * @author 不温卜火 * * * @create 2020-08-06 13:24 ** * MyCSDN : [url=https://buwenbuhuo.blog.csdn.net/]https://buwenbuhuo.blog.csdn.net/[/url] * */class CityRemarkUDAF extends UserDefinedAggregateFunction { // 输入数据的类型: 北京 String override def inputSchema: StructType = { StructType(Array(StructField("city", StringType))) } // 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000 Map, 总的点击量 1000/? override def bufferSchema: StructType = { StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType))) } // 输出的数据类型 "北京21.2%,天津13.2%,其他65.6%" String override def dataType: DataType = StringType // 相同的输入是否应用有相同的输出. override def deterministic: Boolean = true // 给存储数据初始化 override def initialize(buffer: MutableAggregationBuffer): Unit = { //初始化map缓存 buffer(0) = Map[String, Long]() // 初始化总的点击量 buffer(1) = 0L } // 分区内合并 Map[城市名, 点击量] override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { input match { case Row(cityName: String) => // 1. 总的点击量 + 1 buffer(1) = buffer.getLong(1) + 1L // 2. 给这个城市的点击量 +1 => 找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去 val map: collection.Map[String, Long] = buffer.getMap[String, Long](0) buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L)) case _ => } } // 分区间的合并 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { val map1 = buffer1.getAs[Map[String, Long]](0) val map2 = buffer2.getAs[Map[String, Long]](0) val total1: Long = buffer1.getLong(1) val total2: Long = buffer2.getLong(1) // 1. 总数的聚合 buffer1(1) = total1 + total2 // 2. map的聚合 buffer1(0) = map1.foldLeft(map2) { case (map, (cityName, count)) => map + (cityName -> (map.getOrElse(cityName, 0L) + count)) } } // 最终的输出结果 override def evaluate(buffer: Row): Any = { // "北京21.2%,天津13.2%,其他65.6%" val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0) val total: Long = buffer.getLong(1) val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2) var cityRemarks: List[CityRemark] = cityCountTop2.map { case (cityName, count) => CityRemark(cityName, count.toDouble / total) }// CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio)) cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio)) cityRemarks.mkString(",") }}case class CityRemark(cityName: String, cityRatio: Double) { val formatter = new DecimalFormat("0.00%") override def toString: String = s"$cityName:${formatter.format(cityRatio)}"} |
运行结果
4 .保存到Mysql
1. 源码
|
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
|
val props: Properties = new Properties()props.put("user","root")props.put("password","199712")spark.sql( """ |select | area, | product_name, | count, | remark |from t3 |where rk<=3 |""".stripMargin) .coalesce(1) .write .mode("overwrite") .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props) |
2.运行结果
三. 完整代码
1. udaf
|
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
package com.buwenbuhuo.spark.sql.projectimport java.text.DecimalFormatimport org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._/** ** * * @author 不温卜火 * * * @create 2020-08-06 13:24 ** * MyCSDN : [url=https://buwenbuhuo.blog.csdn.net/]https://buwenbuhuo.blog.csdn.net/[/url] * */class CityRemarkUDAF extends UserDefinedAggregateFunction { // 输入数据的类型: 北京 String override def inputSchema: StructType = { StructType(Array(StructField("city", StringType))) } // 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000 Map, 总的点击量 1000/? override def bufferSchema: StructType = { StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType))) } // 输出的数据类型 "北京21.2%,天津13.2%,其他65.6%" String override def dataType: DataType = StringType // 相同的输入是否应用有相同的输出. override def deterministic: Boolean = true // 给存储数据初始化 override def initialize(buffer: MutableAggregationBuffer): Unit = { //初始化map缓存 buffer(0) = Map[String, Long]() // 初始化总的点击量 buffer(1) = 0L } // 分区内合并 Map[城市名, 点击量] override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { input match { case Row(cityName: String) => // 1. 总的点击量 + 1 buffer(1) = buffer.getLong(1) + 1L // 2. 给这个城市的点击量 +1 => 找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去 val map: collection.Map[String, Long] = buffer.getMap[String, Long](0) buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L)) case _ => } } // 分区间的合并 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { val map1 = buffer1.getAs[Map[String, Long]](0) val map2 = buffer2.getAs[Map[String, Long]](0) val total1: Long = buffer1.getLong(1) val total2: Long = buffer2.getLong(1) // 1. 总数的聚合 buffer1(1) = total1 + total2 // 2. map的聚合 buffer1(0) = map1.foldLeft(map2) { case (map, (cityName, count)) => map + (cityName -> (map.getOrElse(cityName, 0L) + count)) } } // 最终的输出结果 override def evaluate(buffer: Row): Any = { // "北京21.2%,天津13.2%,其他65.6%" val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0) val total: Long = buffer.getLong(1) val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2) var cityRemarks: List[CityRemark] = cityCountTop2.map { case (cityName, count) => CityRemark(cityName, count.toDouble / total) }// CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio)) cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio)) cityRemarks.mkString(",") }}case class CityRemark(cityName: String, cityRatio: Double) { val formatter = new DecimalFormat("0.00%") override def toString: String = s"$cityName:${formatter.format(cityRatio)}"} |
2. 主程序(具体实现)
|
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
package com.buwenbuhuo.spark.sql.projectimport java.util.Propertiesimport org.apache.spark.sql.SparkSession/** ** * * @author 不温卜火 * * * @create 2020-08-05 19:01 ** * MyCSDN : [url=https://buwenbuhuo.blog.csdn.net/]https://buwenbuhuo.blog.csdn.net/[/url] * */object SqlApp { def main(args: Array[String]): Unit = { val spark: SparkSession = SparkSession .builder() .master("local") .appName("SqlApp") .enableHiveSupport() .getOrCreate() import spark.implicits._ spark.udf.register("remark",new CityRemarkUDAF) // 去执行sql,从hive查询数据 spark.sql("use spark0806") spark.sql( """ |select | ci.*, | pi.product_name, | uva.click_product_id |from user_visit_action uva |join product_info pi on uva.click_product_id=pi.product_id |join city_info ci on uva.city_id=ci.city_id | |""".stripMargin).createOrReplaceTempView("t1") spark.sql( """ |select | area, | product_name, | count(*) count, | remark(city_name) remark |from t1 |group by area, product_name |""".stripMargin).createOrReplaceTempView("t2") spark.sql( """ |select | area, | product_name, | count, | remark, | rank() over(partition by area order by count desc) rk |from t2 |""".stripMargin).createOrReplaceTempView("t3") val props: Properties = new Properties() props.put("user","root") props.put("password","199712") spark.sql( """ |select | area, | product_name, | count, | remark |from t3 |where rk<=3 |""".stripMargin) .coalesce(1) .write .mode("overwrite") .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props) // 把结果写入到mysql中 spark.close() }} |