Как получить другие столбцы при использовании Spark DataFrame groupby?
когда я использую группу DataFrame следующим образом:
df.groupBy(df("age")).agg(Map("id"->"count"))
Я получаю только DataFrame с столбцами "age" и "count (id)", но в df есть много других столбцов, таких как "name".
В целом, я хочу получить результат как в MySQL,
"выберите имя, возраст, счетчик (id) из группы df по возрасту"
Что делать, если вы используете groupby в Spark?
Ответы
Ответ 1
Короче говоря, вам нужно объединить результаты с исходной таблицей. Spark SQL следует тому же пред-SQL: 1999 как большинство основных баз данных (PostgreSQL, Oracle, MS SQL Server), который не позволяет добавлять дополнительные столбцы в запросы агрегации.
Так как для агрегаций, таких как результаты подсчета, недостаточно четко определены, и поведение, как правило, меняется в системах, поддерживающих этот тип запросов, вы можете просто добавить дополнительные столбцы, используя произвольный агрегат, например first
или last
.
В некоторых случаях вы можете заменить agg
на select
на функции окна и последующие where
, но в зависимости от контекста это может быть довольно дорого.
Ответ 2
Один из способов получить все столбцы после выполнения groupBy - использовать функцию соединения.
feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)
data_joined теперь будет иметь все столбцы, включая значения count.
Ответ 3
Может быть, это решение будет полезным.
from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window
name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
(105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
(109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]
age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])
name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()
Выход:
+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26| 1|
|113| tuv| 29| 1|
|110| tuv| 27| 1|
|106| klm| 19| 1|
|103| efg| 22| 1|
|104| ghi| 21| 1|
|105| ijk| 20| 1|
|112| rst| 28| 1|
|101| abc| 24| 2|
|102| cde| 24| 2|
|107| mno| 18| 3|
|111| pqr| 18| 3|
|108| pqr| 18| 3|
+---+----+---+-----+
Ответ 4
Агрегатные функции уменьшают значения строк для указанных столбцов в группе. Если вы хотите сохранить другие значения строк, вам необходимо реализовать логику сокращения, которая определяет строку, из которой происходит каждое значение. Например, сохраните все значения первого ряда с максимальным значением возраста. Для этого вы можете использовать UDAF (пользовательскую статистическую функцию), чтобы уменьшить количество строк в группе.
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
object AggregateKeepingRowJob {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
sc.setLogLevel("ERROR")
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
(1L, "Moe", "Slap", 2.0, 18),
(2L, "Larry", "Spank", 3.0, 15),
(3L, "Curly", "Twist", 5.0, 15),
(4L, "Laurel", "Whimper", 3.0, 15),
(5L, "Hardy", "Laugh", 6.0, 15),
(6L, "Charley", "Ignore", 5.0, 5)
).toDF("id", "name", "requisite", "money", "age")
rawDf.show(false)
rawDf.printSchema
val maxAgeUdaf = new KeepRowWithMaxAge
val aggDf = rawDf
.groupBy("age")
.agg(
count("id"),
max(col("money")),
maxAgeUdaf(
col("id"),
col("name"),
col("requisite"),
col("money"),
col("age")).as("KeepRowWithMaxAge")
)
aggDf.printSchema
aggDf.show(false)
}
}
UDAF:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
// This is the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType =
StructType(
StructField("store", StringType) ::
StructField("prod", StringType) ::
StructField("amt", DoubleType) ::
StructField("units", IntegerType) :: Nil
)
// This is the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
StructField("store", StringType) ::
StructField("prod", StringType) ::
StructField("amt", DoubleType) ::
StructField("units", IntegerType) :: Nil
)
// This is the output type of your aggregation function.
override def dataType: DataType =
StructType((Array(
StructField("store", StringType),
StructField("prod", StringType),
StructField("amt", DoubleType),
StructField("units", IntegerType)
)))
override def deterministic: Boolean = true
// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = ""
buffer(1) = ""
buffer(2) = 0.0
buffer(3) = 0
}
// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val amt = buffer.getAs[Double](2)
val candidateAmt = input.getAs[Double](2)
amt match {
case a if a < candidateAmt =>
buffer(0) = input.getAs[String](0)
buffer(1) = input.getAs[String](1)
buffer(2) = input.getAs[Double](2)
buffer(3) = input.getAs[Int](3)
case _ =>
}
}
// This is how to merge two objects with the bufferSchema type.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer2.getAs[String](0)
buffer1(1) = buffer2.getAs[String](1)
buffer1(2) = buffer2.getAs[Double](2)
buffer1(3) = buffer2.getAs[Int](3)
}
// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
buffer
}
}
Ответ 5
Вот пример, с которым я столкнулся в
искрового мастерской
val populationDF = spark.read
.option("infer-schema", "true")
.option("header", "true")
.format("csv").load("file:///databricks/driver/population.csv")
.select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))
val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))
Чтобы получить другие столбцы, я делаю простое соединение между исходным DF и агрегированным
populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()
Ответ 6
Вы должны помнить, что агрегатные функции сокращают строки, и поэтому вам нужно указать, какое из названий строк вы хотите использовать с помощью сокращающей функции. Если вы хотите сохранить все строки группы (предупреждение! Это может привести к взрывам или перекосам), вы можете собрать их в виде списка. Затем вы можете использовать UDF (пользовательскую функцию), чтобы уменьшить их по вашим критериям, в моем примере деньги. А затем разверните столбцы из одной уменьшенной строки с помощью другой пользовательской функции. Для целей этого ответа я предполагаю, что вы хотите сохранить имя человека, у которого больше всего денег.
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
import scala.collection.mutable
object TestJob3 {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
(1, "Moe", "Slap", 2.0, 18),
(2, "Larry", "Spank", 3.0, 15),
(3, "Curly", "Twist", 5.0, 15),
(4, "Laurel", "Whimper", 3.0, 9),
(5, "Hardy", "Laugh", 6.0, 18),
(6, "Charley", "Ignore", 5.0, 5)
).toDF("id", "name", "requisite", "money", "age")
rawDf.show(false)
rawDf.printSchema
val rawSchema = rawDf.schema
val fUdf = udf(reduceByMoney, rawSchema)
val nameUdf = udf(extractName, StringType)
val aggDf = rawDf
.groupBy("age")
.agg(
count(struct("*")).as("count"),
max(col("money")),
collect_list(struct("*")).as("horizontal")
)
.withColumn("short", fUdf($"horizontal"))
.withColumn("name", nameUdf($"short"))
.drop("horizontal")
aggDf.printSchema
aggDf.show(false)
}
def reduceByMoney= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val red = d.reduce((r1, r2) => {
val money1 = r1.getAs[Double]("money")
val money2 = r2.getAs[Double]("money")
val r3 = money1 match {
case a if a >= money2 =>
r1
case _ =>
r2
}
r3
})
red
}
def extractName = (x: Any) => {
val d = x.asInstanceOf[GenericRowWithSchema]
d.getAs[String]("name")
}
}
вот вывод
+---+-----+----------+----------------------------+-------+
|age|count|max(money)|short |name |
+---+-----+----------+----------------------------+-------+
|5 |1 |5.0 |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2 |5.0 |[3, Curly, Twist, 5.0, 15] |Curly |
|9 |1 |3.0 |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2 |6.0 |[5, Hardy, Laugh, 6.0, 18] |Hardy |
+---+-----+----------+----------------------------+-------+
Ответ 7
Вы можете сделать следующее:
Пример данных:
name age id
abc 24 1001
cde 24 1002
efg 22 1003
ghi 21 1004
ijk 20 1005
klm 19 1006
mno 18 1007
pqr 18 1008
rst 26 1009
tuv 27 1010
pqr 18 1012
rst 28 1013
tuv 29 1011
df.select("name","age","id").groupBy("name","age").count().show();
Вывод:
+----+---+-----+
|name|age|count|
+----+---+-----+
| efg| 22| 1|
| tuv| 29| 1|
| rst| 28| 1|
| klm| 19| 1|
| pqr| 18| 2|
| cde| 24| 1|
| tuv| 27| 1|
| ijk| 20| 1|
| abc| 24| 1|
| mno| 18| 1|
| ghi| 21| 1|
| rst| 26| 1|
+----+---+-----+