Lo más parecido que pude encontrar es zip, que hace más o menos lo que quiero, pero usa el archivo index. Quiero especificar un campo en el que se unen las dos listas si tienen el mismo valor.
En SQL, uno usaría "table1 INNER JOIN table2 WHERE table1.field = table2.field". ¿Hay algo similar en Kotlin?
Quizás este ejemplo lo aclare más:
class Something(val id : Int, val value : Int) val list1 = listOf(Something(0, 1), Something(1, 6), Something(2, 8)) val list2 = listOf(Something(1, 2), Something(5, 3), Something(9, 6)) val result = list1.innerJoin(list2, on=id).map { element1, element2 -> element1.value * element2.value} //should return [12] (6*2)
list1 y list2 tienen un elemento con id=1, por lo que sus valores (6 y 2) se multiplican en este ejemplo y el resultado debería ser 12.
Actualmente uso este fragmento de código, que funciona, pero me preguntaba si hay una manera más fácil y eficiente de hacerlo.
val result = list1.map { element1 -> val element2 = list2.find { element2 -> element2.id == element1.id } ?: return@map null element1.value * element2.value }.filterNotNull()
Gracias.
Podría usar mapNotNull
para eliminar el paso de filtrado secundario, pero eso es lo más conciso posible. También puede convertir list2
a Map primero para cambiar esto de O(n^2) a O(n) .
val list2ById = list2.associateBy(Something::id) val result = list1.mapNotNull { element1 -> list2ById[element1.id]?.value?.times(element1.value) }
Kotlin no proporciona el método stdlib para esto, pero puedes definir el tuyo propio:
fun <T> Collection<T>.innerJoin(other: Collection<T>, on: T.() -> Any): Collection<Pair<T, T>> { val otherByKey = other.associateBy(on) return this.mapNotNull { val otherMapped = otherByKey[on(it)] if (otherMapped == null) null else it to otherMapped } }
Uso:
fun main() { val list1 = listOf(Something(0, 1), Something(1, 6), Something(2, 8)) val list2 = listOf(Something(1, 2), Something(5, 3), Something(9, 6)) val result = list1.innerJoin(list2, on = { id }).map { (element1, element2) -> element1.value * element2.value } println(result) //12 }