Есть ли способ передать предельный параметр functions.collect_set в Spark?
Я имею дело с столбцом чисел в большом искровом DataFrame, и я хотел бы создать новый столбец, который хранит объединенный список уникальных чисел, которые появляются в этом столбце.
В основном именно то, что выполняет функция .collect_set. Тем не менее, мне нужно всего до 1000 элементов в агрегированном списке. Можно ли каким-либо образом передать этот параметр функции functions.collect_set() или любым другим способом получить только до 1000 элементов в агрегированном списке без использования UDAF?
Поскольку столбец настолько велик, я хотел бы избежать сбора всех элементов и последующего обрезки списка.
Спасибо!
Ответы
Ответ 1
Мое решение очень похоже на ответ Локи с collect_set_limit
.
Я бы использовал UDF, который будет делать то, что вы хотите после collect_set
(или collect_list
) или гораздо более сложного UDAF.
Учитывая больший опыт работы с UDF, я бы пошел с этим первым. Хотя пользовательские функции не оптимизированы, для этого варианта использования это нормально.
val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) }
val sample = spark.range(50).withColumn("key", $"id" % 5)
scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false)
+---+--------------------------------------+
|key|all |
+---+--------------------------------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|
+---+--------------------------------------+
scala> sample.
groupBy("key").
agg(collect_set("id") as "all").
withColumn("limit(3)", limitUDF($"all", lit(3))).
show(false)
+---+--------------------------------------+------------+
|key|all |limit(3) |
+---+--------------------------------------+------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+
См функции объекта (для udf
функции Docs).
Ответ 2
Я использую измененную копию функций collect_set и collect_list; из-за областей кода измененные копии должны быть в том же пути пакета, что и оригиналы. Связанный код работает для Spark 2.1.0; если вы используете предыдущую версию, сигнатуры методов могут отличаться.
Бросьте этот файл (https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc) в свой проект как src/main/org/apache/spark/sql/catal/expression/collect_limit.scala
используйте его как:
import org.apache.spark.sql.catalyst.expression.collect_limit._
df.groupBy('set_col).agg(collect_set_limit('set_col,1000)
Ответ 3
scala> df.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 10| Name|2016| Country|
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.groupBy("C1").agg(sum("C0"))
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res36.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
| Name| 30|
+-----+-------+
scala> df.limit(2).groupBy("C1").agg(sum("C0"))
res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res33.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
| Name| 10|
|Name1| 11|
+-----+-------+
scala> df.groupBy("C1").agg(sum("C0")).limit(2)
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res2.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
+-----+-------+
scala> df.distinct
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res8.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.dropDuplicates(Array("c1"))
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res11.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 12|Name2|2017|Country2|
| 10| Name|2016| Country|
+---+-----+----+--------+
Ответ 4
использовать
val firstThousand = rdd.take(1000)
Вернет первый 1000.
Сбор также имеет функцию фильтра, которая может быть предоставлена. Это позволит вам быть более конкретным относительно того, что возвращается.
Ответ 5
Как уже упоминалось в других ответах, эффективный способ сделать это - написать UDAF. К сожалению, API UDAF на самом деле не так расширяемо, как агрегатные функции, которые поставляются с искрой. Однако вы можете использовать их внутренние API-интерфейсы, чтобы использовать внутренние функции и делать то, что вам нужно.
Вот реализация для collect_set_limit
которая в основном является копией прошлой внутренней Spark-функции CollectSet
. Я бы просто расширил его, но это класс дела. На самом деле все, что нужно, это переопределить методы update и merge для соблюдения переданного ограничения:
case class CollectSetLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] = {
if( buffer.size >= limit ) buffer
else buffer ++= other.take( limit - buffer.size )
}
override def prettyName: String = "collect_set_limit"
}
И чтобы фактически зарегистрировать это, мы можем сделать это через внутреннюю FunctionRegistry
Spark FunctionRegistry
которая берет имя и конструктор, который фактически является функцией, которая создает CollectSetLimit
используя предоставленные выражения:
val collectSetBuilder = (args: Seq[Expression]) => CollectSetLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_set_limit", collectSetBuilder )
Редактировать:
Оказывается, добавление его во встроенную систему работает только в том случае, если вы еще не создали SparkContext, поскольку при запуске он становится неизменным клоном. Если у вас есть существующий контекст, то это должно работать, чтобы добавить его с отражением:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_set_limit", collectSetBuilder )