代码之家  ›  专栏  ›  技术社区  ›  Shkarik

为什么我在Scala中的二进制搜索实现如此缓慢?

  •  4
  • Shkarik  · 技术社区  · 7 年前

    最近,我实现了这个二进制搜索,Scala应该在6秒内运行,但它在检查分配的机器上运行12-13秒。

    阅读代码前请注意:输入由两行组成:第一行是要搜索的数字列表,第二行是要在数字列表中搜索的“搜索词”列表。预期输出只列出数字列表中每个术语的索引。每个输入的最大长度为10^5,每个数字的最大大小为10^9。

    例如:

    Input:
    5 1 5 8 12 13 //note, that the first number 5 indicates the length of the 
    following sequence
    
    5 8 1 23 1 11 //note, that the first number 5 indicates the length of the 
    following sequence
    
    Output:
    2 0 -1 0 -1 // index of each term in the input array
    

    我的解决方案:

    object BinarySearch extends App {
      val n_items = readLine().split(" ").map(BigInt(_))
      val n = n_items(0)
      val items = n_items.drop(1)
    
      val k :: terms = readLine().split(" ").map(BigInt(_)).toList
    
      println(search(terms, items).mkString(" "))
    
      def search(terms: List[BigInt], items:Array[BigInt]): Array[BigInt] = {
        @tailrec
        def go(terms: List[BigInt], results: Array[BigInt]): Array[BigInt] = terms match {
          case List() => results
          case head :: tail => go(tail, results :+ find(head))
        }
    
        def find(term: BigInt): BigInt = {
          @tailrec
          def go(left: BigInt, right: BigInt): BigInt = {
            if (left > right) { -1 }
            else {
              val middle = left + (right - left) / 2
              val middle_val = items(middle.toInt)
    
              middle_val match {
                case m if m == term => middle
                case m if m <= term => go(middle + 1, right)
                case m if m > term => go(left, middle - 1)
              }
            }
          }
    
          go(0, n - 1)
        }
    
        go(terms, Array())
      }
    }
    

    这段代码为什么这么慢?非常感谢。

    1 回复  |  直到 7 年前
        1
  •  4
  •   Peter de Rivaz    7 年前

    我担心

    results :+ find(head)
    

    将项目附加到长度为L的列表中是O(L)(请参见 here ),因此,如果要计算n个结果,则复杂性将为O(n*n)。

    尝试使用可变的ArrayBuffer而不是数组来累积结果,或者简单地通过find函数映射输入项。

    换言之,替换

    go(terms, Array())
    

    具有

    terms.map( x => find(x) ).toArray
    

    顺便说一句,这个问题的限制太小了,使用BigInt太过分了,可能会使代码速度大大降低。正常Int应足够大,以解决此问题。