From 9ed0df9acf159530ef8808f68bc55bdfbbca60f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joa=CC=83o=20Costa?= Date: Fri, 20 Sep 2024 17:15:20 +0100 Subject: [PATCH] Support primitives in Flow#collectType --- .../pekko/stream/scaladsl/FlowCollectTypeSpec.scala | 13 ++++++++++++- .../org/apache/pekko/stream/scaladsl/Flow.scala | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowCollectTypeSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowCollectTypeSpec.scala index 956bf9b58b0..8f0a82cf3a2 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowCollectTypeSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowCollectTypeSpec.scala @@ -25,7 +25,7 @@ class FlowCollectTypeSpec extends StreamSpec { "A CollectType" must { - "collectType" in { + "collectType with references" in { val fruit = Source(List(Orange, Apple, Apple, Orange)) val apples = fruit.collectType[Apple].runWith(Sink.seq).futureValue @@ -36,6 +36,17 @@ class FlowCollectTypeSpec extends StreamSpec { all should equal(List(Orange, Apple, Apple, Orange)) } + "collectType with primitives" in { + val numbers = Source(List[Int](1, 2, 3) ++ List[Double](1.5)) + + val integers = numbers.collectType[Int].runWith(Sink.seq).futureValue + integers should equal(List(1, 2, 3)) + val doubles = numbers.collectType[Double].runWith(Sink.seq).futureValue + doubles should equal(List(1.5)) + val all = numbers.collectType[Any].runWith(Sink.seq).futureValue + all should equal(List(1, 2, 3, 1.5)) + } + } } diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala index 372378d42b9..3bf8e436ba6 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala @@ -1686,7 +1686,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels */ def collectType[T](implicit tag: ClassTag[T]): Repr[T] = - collect { case c if tag.runtimeClass.isInstance(c) => c.asInstanceOf[T] } + collect { case tag(c) => c } /** * Chunk up this stream into groups of the given size, with the last group