Как получить элемент по индексу в Spark RDD (Java)
Я знаю метод rdd.first(), который дает мне первый элемент в RDD.
Также существует метод rdd.take(num), который дает мне первые "num" элементы.
Но разве нет возможности получить элемент по индексу?
Спасибо.
Ответы
Ответ 1
Это должно быть возможно, сначала индексируя RDD. Преобразование zipWithIndex
обеспечивает стабильную индексацию, нумерацию каждого элемента в исходном порядке.
Учитывая: rdd = (a,b,c)
val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))
Чтобы найти элемент по индексу, эта форма не полезна. Сначала нам нужно использовать индекс как ключ:
val indexKey = withIndex.map{case (k,v) => (v,k)} //((0,a),(1,b),(2,c))
Теперь можно использовать действие lookup
в PairRDD для поиска элемента по ключу:
val b = indexKey.lookup(1) // Array(b)
Если вы ожидаете часто использовать lookup
на одном и том же RDD, я бы рекомендовал кэшировать RDD indexKey
для повышения производительности.
Как это сделать с помощью API Java - это упражнение, оставшееся для читателя.
Ответ 2
Я попробовал этот класс для выбора элемента по индексу. Во-первых, при построении new IndexedFetcher(rdd, itemClass)
он подсчитывает количество элементов в каждом разделе RDD. Затем, когда вы вызываете indexedFetcher.get(n)
, он запускает задание только для раздела, содержащего этот индекс.
Обратите внимание, что мне нужно было скомпилировать это с использованием Java 1.7 вместо 1.8; с Spark 1.1.0 связанный org.objectweb.asm в com.esotericsoftware.reflectasm еще не может читать классы Java 1.8 (бросает IllegalStateException при попытке запустить функцию Java 1.8).
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;
import scala.reflect.ClassTag;
public static class IndexedFetcher<E> implements Serializable {
private static final long serialVersionUID = 1L;
public final RDD<E> rdd;
public Integer[] elementsPerPartitions;
private Class<?> clazz;
public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
this.rdd = rdd;
this.clazz = clazz;
SparkContext context = this.rdd.context();
ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
}
public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
private static final long serialVersionUID = 1L;
@Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
int count = 0;
while (iterator.hasNext()) {
count++;
iterator.next();
}
return count;
}
}
static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
return function;
}
public E get(long index) {
long remaining = index;
long totalCount = 0;
for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
if (remaining < elementsPerPartitions[partition]) {
return getWithinPartition(partition, remaining);
}
remaining -= elementsPerPartitions[partition];
totalCount += elementsPerPartitions[partition];
}
throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
}
public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
private static final long serialVersionUID = 1L;
private final long indexWithinPartition;
public FetchWithinPartitionFunction(long indexWithinPartition) {
this.indexWithinPartition = indexWithinPartition;
}
@Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
int count = 0;
while (iterator.hasNext()) {
E element = iterator.next();
if (count == indexWithinPartition)
return element;
count++;
}
throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
}
}
public E getWithinPartition(int partition, long indexWithinPartition) {
System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
SparkContext context = rdd.context();
scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
return result[0];
}
}
Ответ 3
Я тоже застрял на этом, поэтому, чтобы расширить ответ Maasg, но отвечая на поиск диапазона значений по индексу для Java (вам нужно будет определить 4 переменные вверху):
DataFrame df;
SQLContext sqlContext;
Long start;
Long end;
JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
JavaRDD filteredRDD = indexedRDD.filter((Tuple2<Row,Long> v1) -> v1._2 >= start && v1._2 < end);
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());
Помните, что при запуске этого кода в вашем кластере должен быть Java 8 (поскольку используется выражение лямбда).
Кроме того, zipWithIndex, вероятно, дорогой!