代码之家  ›  专栏  ›  技术社区  ›  Raphael Roth

如何使用Scala惰性集合实现takeUntil

  •  1
  • Raphael Roth  · 技术社区  · 6 年前

    我有一个昂贵的函数,我想运行尽可能少的次数,要求如下:

    • 如果函数返回的值低于给定的阈值,我不想尝试其他输入
    • 如果没有结果低于阈值,我希望以最小的输出获取结果

    我无法使用迭代器的takeWhile/dropWhile找到一个好的解决方案,因为我想包含第一个匹配的元素。最后得到了以下解决方案:

    val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)
    
    def expensiveFunc(s:String) : Double = {
      pseudoResult(s)
    }
    
    val inputsToTry = Seq("a","b","c")
    
    val inputIt = inputsToTry.iterator
    val results = mutable.ArrayBuffer.empty[(String, Double)]
    
    val earlyAbort = 0.5 // threshold
    
    breakable {
      while (inputIt.hasNext) {
        val name = inputIt.next()
        val res = expensiveFunc(name)
        results += Tuple2(name,res)
        if (res<earlyAbort) break()
      }
    }
    
    println(results) // ArrayBuffer((a,0.6), (b,0.2))
    
    val (name, bestResult) = results.minBy(_._2) // (b, 0.2)
    

    如果我设置 val earlyAbort = 0.1 ,结果应该仍然是 (b, 0.2) 没有重新评估所有的案例。

    5 回复  |  直到 6 年前
        1
  •  3
  •   Gonzalo Guglielmo    6 年前

    你可以利用 Stream 为了达到你想要的,记住 溪流 是一种懒惰的收集,根据需要评估操作。

    这是斯卡拉 Stream 文档。

    你只需要这样做:

    val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)
    val earlyAbort = 0.5
    
    def expensiveFunc(s: String): Double = {
      println(s"Evaluating for $s")
      pseudoResult(s)
    }
    
    val inputsToTry = Seq("a","b","c")
    
    val results = inputsToTry.toStream.map(input => input -> expensiveFunc(input))
    val finalResult = results.find { case (k, res) => res < earlyAbort }.getOrElse(results.minBy(_._2))
    

    如果 find 不获取任何值,您可以使用同一流查找min,并且不会再次计算该函数,这是由于备忘录:

    Stream类还使用memoization,这样以前计算的值就可以从Stream元素转换为类型A的具体值

    如果原始集合为空,则考虑此代码将失败,如果要支持空集合,则应替换 minBy 具有 sortBy(_._2).headOption getOrElse 通过 orElse :

    val finalResultOpt = results.find { case (k, res) => res < earlyAbort }.orElse(results.sortBy(_._2).headOption)
    

    结果是:

    评估

    评估b

    最终结果:(字符串,双精度)=(b,0.2)

    finalResultOpt:Option[(String,Double)]=一些((b,0.2))

        2
  •  1
  •   jwvh    6 年前

    最清楚、最简单的事情就是 fold 超过输入,只传递当前最佳结果。

    val inputIt :Iterator[String] = inputsToTry.iterator
    val earlyAbort = 0.5 // threshold
    
    inputIt.foldLeft(("",Double.MaxValue)){ case (low,name) =>
      if (low._2 < earlyAbort) low
      else Seq(low, (name, expensiveFunc(name))).minBy(_._2)
    }
    //res0: (String, Double) = (b,0.2)
    

    它召唤 expensiveFunc() 只是需要多少次,但它确实遍历了整个输入迭代器。如果这仍然是太繁重(大量的输入),那么我会用尾部递归方法。

    val inputIt :Iterator[String] = inputsToTry.iterator
    val earlyAbort = 0.5 // threshold
    
    def bestMin(low :(String,Double) = ("",Double.MaxValue)) :(String,Double) = {
      if (inputIt.hasNext) {
        val name = inputIt.next()
        val res = expensiveFunc(name)
        if (res < earlyAbort) (name, res)
        else if (res < low._2) bestMin((name,res))
        else bestMin(low)
      } else low
    }
    bestMin()  //res0: (String, Double) = (b,0.2)
    
        3
  •  0
  •   proximator    6 年前

    在输入列表中使用视图: 请尝试以下操作:

      val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
    
      def expensiveFunc(s: String): Double = {
        println(s"executed for ${s}")
        pseudoResult(s)
      }
    
      val inputsToTry = Seq("a", "b", "c")
      val earlyAbort = 0.5 // threshold
    
      def doIt(): List[(String, Double)] = {
    
        inputsToTry.foldLeft(List[(String, Double)]()) {
          case (n, name) =>
    
    
            val res = expensiveFunc(name)
            if(res < earlyAbort) {
              return n++List((name, res))
            }
            n++List((name, res))
        }
    
      }
    
      val (name, bestResult) = doIt().minBy(_._2)
      println(name)
      println(bestResult)
    

    输出:

    executed for a
    executed for b
    b
    0.2
    

    如您所见,只有a和b被计算,而不是c。

        4
  •  0
  •   curious    6 年前

    这是tail递归的一个用例:

      import scala.annotation.tailrec
      val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)
    
      def expensiveFunc(s:String) : Double = {
        pseudoResult(s)
      }
    
      val inputsToTry = Seq("a","b","c")
    
      val earlyAbort = 0.5 // threshold
    
      @tailrec
      def f(s: Seq[String], result: Map[String, Double] = Map()): Map[String, Double] = s match {
        case Nil => result
        case h::t =>
          val expensiveCalculation = expensiveFunc(h)
          val intermediateResult = result + (h -> expensiveCalculation)
          if(expensiveCalculation < earlyAbort) {
            intermediateResult
          } else {
            f(t, intermediateResult)
          }
      }
      val result = f(inputsToTry)
    
      println(result) // Map(a -> 0.6, b -> 0.2)
    
      val (name, bestResult) = f(inputsToTry).minBy(_._2) // ("b", 0.2)
    
        5
  •  0
  •   stefanobaghino    6 年前

    如果你实施 takeUntil 使用它,如果你找不到你想要的东西,你仍然需要再次浏览这个列表才能得到最低的一个。也许更好的方法是拥有一个结合了 find 具有 reduceOption ,如果找到某个项,则提前返回,或者返回将集合缩减为单个项的结果(在您的示例中,是查找最小的项)。

    结果与使用 Stream ,如前一个回复中所强调的,但避免利用备忘录,这对于非常大的收藏来说可能会很麻烦。

    可能的实施方式如下:

    import scala.annotation.tailrec
    
    def findOrElse[A](it: Iterator[A])(predicate: A => Boolean,
                                       orElse: (A, A) => A): Option[A] = {
      @tailrec
      def loop(elseValue: Option[A]): Option[A] = {
        if (!it.hasNext) elseValue
        else {
          val next = it.next()
          if (predicate(next)) Some(next)
          else loop(Option(elseValue.fold(next)(orElse(_, next))))
        }
      }
      loop(None)
    }
    

    让我们添加输入来测试:

    def f1(in: String): Double = {
      println("calling f1")
      Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0, "d" -> 0.8)(in)
    }
    
    def f2(in: String): Double = {
      println("calling f2")
      Map("a" -> 0.7, "b" -> 0.6, "c" -> 1.0, "d" -> 0.8)(in)
    }
    
    val inputs = Seq("a", "b", "c", "d")
    

    以及我们的助手:

    def apply[IN, OUT](in: IN, f: IN => OUT): (IN, OUT) =
      in -> f(in)
    
    def threshold[A](a: (A, Double)): Boolean =
      a._2 < 0.5
    
    def compare[A](a: (A, Double), b: (A, Double)): (A, Double) =
      if (a._2 < b._2) a else b
    

    我们现在可以运行它并查看其运行情况:

    val r1 = findOrElse(inputs.iterator.map(apply(_, f1)))(threshold, compare)
    val r2 = findOrElse(inputs.iterator.map(apply(_, f2)))(threshold, compare)
    val r3 = findOrElse(Map.empty[String, Double].iterator)(threshold, compare)
    

    r1 Some(b, 0.2) , r2 Some(b, 0.6) r3 是(合理的) None . 在第一种情况下,由于我们使用惰性迭代器并提前终止,因此我们只调用 f1 两次。

    您可以查看结果并使用此代码 here on Scastie .