What is a good way of reusing function result in Scala

Let me illustrate my question with an example. This is a standard exponentiation-by-squaring algorithm written recursively in Scala:

def power(x: Double, y: Int): Double = {
  def sqr(z: Double): Double = z * z
  // Assumes y >= 0; loop is always called with xx = 1.0.
  def loop(xx: Double, yy: Int): Double =
    if (yy == 0) xx
    else if (yy % 2 == 0) sqr(loop(xx, yy / 2)) // x^(2n) = (x^n)^2
    else x * loop(xx, yy - 1)                   // x^n = x * x^(n-1)

  loop(1.0, y)
}

Here the sqr method is used to square loop's result. Defining a dedicated function for such a simple operation doesn't look like a good idea. But we can't simply write loop(..) * loop(..) instead, since that evaluates the same recursive call twice at every even step, which turns the logarithmic number of calls into a linear one.
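To make the cost concrete, here is a small instrumented sketch. The object CallCountSketch, its calls counter and the functions withSquare, withTwoCalls and countCalls are hypothetical names used only for this illustration; both variants compute x^n, one by squaring a single recursive result and one by multiplying two identical recursive calls:

```scala
object CallCountSketch {
  // Hypothetical counter: incremented on every call of either variant.
  var calls = 0

  // Squares the result of ONE recursive call: O(log n) calls.
  def withSquare(x: Double, n: Int): Double = {
    calls += 1
    if (n == 0) 1.0
    else if (n % 2 == 0) { val s = withSquare(x, n / 2); s * s }
    else x * withSquare(x, n - 1)
  }

  // Multiplies TWO identical recursive calls: the work doubles at every even level, O(n) calls.
  def withTwoCalls(x: Double, n: Int): Double = {
    calls += 1
    if (n == 0) 1.0
    else if (n % 2 == 0) withTwoCalls(x, n / 2) * withTwoCalls(x, n / 2)
    else x * withTwoCalls(x, n - 1)
  }

  // Resets the counter, evaluates the by-name argument, and returns (result, number of calls).
  def countCalls(f: => Double): (Double, Int) = {
    calls = 0
    val result = f
    (result, calls)
  }
}
```

For x = 2.0 and n = 16 both variants return 65536.0, but the first makes 6 calls while the second makes 47.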

We can also write it with a val and without the sqr function:

def power(x: Double, y: Int): Double = {
  // Assumes y >= 0; loop is always called with xx = 1.0.
  def loop(xx: Double, yy: Int): Double =
    if (yy == 0) xx
    else if (yy % 2 == 0) { val s = loop(xx, yy / 2); s * s } // compute x^n once, then square it
    else x * loop(xx, yy - 1)

  loop(1.0, y)
}

I can't say that it looks better than the variant with sqr, since it introduces an intermediate local value. The first variant is more functional; the second is more Scala-friendly.

Anyway, my question is: how should one deal with cases where a function's result needs to be post-processed? Does Scala offer other ways to achieve that?

Vladimir Kostyukov asked Oct 24 '25 19:10


2 Answers

You are using the law that

x^(2n) = x^n * x^n

But this is the same as

x^n * x^n = (x*x)^n

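Hence, to avoid squaring after the recursive call returns, the even case should square the base before recursing, i.e. compute (x*x)^(n/2) rather than (x^(n/2))^2. The factors picked up in the odd case are collected in an accumulator, as in the code listing below.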
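A short derivation (our own sketch, not from the answer) of why the accumulator version is correct, stated as an invariant over the parameters xx, acc and yy of the Scala loop, which starts as loop(x, 1.0, y):

```latex
\begin{aligned}
&\text{invariant:} && \mathit{acc} \cdot \mathit{xx}^{\mathit{yy}} = x^{y} \\
&\text{start } (\mathit{xx} = x,\ \mathit{acc} = 1,\ \mathit{yy} = y): && 1 \cdot x^{y} = x^{y} \\
&\mathit{yy} \text{ even}: && \mathit{acc} \cdot (\mathit{xx} \cdot \mathit{xx})^{\mathit{yy}/2} = \mathit{acc} \cdot \mathit{xx}^{\mathit{yy}} \\
&\mathit{yy} \text{ odd}: && (\mathit{acc} \cdot \mathit{xx}) \cdot \mathit{xx}^{\mathit{yy}-1} = \mathit{acc} \cdot \mathit{xx}^{\mathit{yy}} \\
&\mathit{yy} = 0: && \mathit{acc} \cdot \mathit{xx}^{0} = \mathit{acc} = x^{y}
\end{aligned}
```

Each step preserves the invariant, so when yy reaches 0 the returned acc is exactly x^y, and no work remains to be done after any recursive call.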

This way, every recursive call is in tail position. Here is the full code (not knowing Scala, I hope I got the syntax right by analogy):

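def power(x: Double, y: Int): Double = {
  // Invariant: acc * xx^yy == x^y (assumes y >= 0).
  def loop(xx: Double, acc: Double, yy: Int): Double =
    if (yy == 0) acc
    else if (yy % 2 == 0) loop(xx * xx, acc, yy / 2) // square the base, halve the exponent
    else loop(xx, acc * xx, yy - 1)                   // move one factor into the accumulator

  loop(x, 1.0, y)
}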
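Scala can check the tail-call claim for you. The sketch below wraps the same loop in a hypothetical object TailRecPower and annotates it with scala.annotation.tailrec, which makes compilation fail if a recursive call is not in tail position. The require guard is our addition: with a negative exponent, none of the versions in this thread ever reach yy == 0.

```scala
import scala.annotation.tailrec

object TailRecPower {
  def power(x: Double, y: Int): Double = {
    // Added assumption: negative exponents are rejected instead of looping forever.
    require(y >= 0, "negative exponents are not handled")

    // @tailrec: the compiler reports an error if any recursive call is not a tail call.
    @tailrec
    def loop(xx: Double, acc: Double, yy: Int): Double =
      if (yy == 0) acc
      else if (yy % 2 == 0) loop(xx * xx, acc, yy / 2) // square the base first
      else loop(xx, acc * xx, yy - 1)                   // move one factor into acc

    loop(x, 1.0, y)
  }
}
```

Putting the same annotation on loop in the question's versions would be rejected, because in sqr(loop(xx, yy / 2)) the recursive call is not the last operation performed.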

Here it is in a Haskell-like language:

power2 x n = loop x 1 n 
    where 
        loop x a 0 = a 
        loop x a n = if odd n then loop x    (a*x) (n-1) 
                              else loop (x*x) a    (n `quot` 2)
Ingo answered Oct 27 '25 11:10


You could use a "forward pipe" operator. I got this idea from the question "Cache an intermediate variable in an one-liner".

So

val s = loop(xx, yy / 2); s * s

could be rewritten to

loop(xx, yy / 2) |> (s => s * s)

using an implicit class (which must be defined inside an object, class or trait) like this

implicit class PipedObject[A](value: A) {
  def |>[B](f: A => B): B = f(value)
}
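Put together, a minimal runnable sketch of this approach might look as follows. The object PipeSketch is a hypothetical container (implicit classes cannot be top-level definitions), and loop is simplified to take only the exponent:

```scala
object PipeSketch {
  // The answer's forward-pipe operator: x |> f is just f(x).
  implicit class PipedObject[A](value: A) {
    def |>[B](f: A => B): B = f(value)
  }

  def power(x: Double, y: Int): Double = {
    // Assumes y >= 0.
    def loop(yy: Int): Double =
      if (yy == 0) 1.0
      else if (yy % 2 == 0) loop(yy / 2) |> (s => s * s) // evaluate once, bind to s, square
      else x * loop(yy - 1)

    loop(y)
  }
}
```

The lambda parameter s plays the role of the val in the question's second version, but the whole branch stays a single expression.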

As Petr has pointed out, using an implicit value class

object PipedObjectContainer {
  implicit class PipedObject[A](val value: A) extends AnyVal {
    def |>[B](f: A => B): B = f(value)
  }
}

to be used like this

import PipedObjectContainer._
loop(xx, yy / 2) |> (s => s * s)

is better, since it avoids allocating a temporary wrapper instance (requires Scala >= 2.10).
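On Scala 2.13 or later, the standard library already provides an equivalent operator: importing scala.util.chaining._ adds a pipe method to every value. A sketch of the same power function using it (the object name ChainingSketch is hypothetical):

```scala
import scala.util.chaining._

object ChainingSketch {
  // Requires Scala 2.13+: `pipe` comes from scala.util.chaining.
  def power(x: Double, y: Int): Double = {
    // Assumes y >= 0.
    def loop(yy: Int): Double =
      if (yy == 0) 1.0
      else if (yy % 2 == 0) loop(yy / 2).pipe(s => s * s) // square the single recursive result
      else x * loop(yy - 1)

    loop(y)
  }
}
```

This avoids defining a custom implicit class; the trade-off is the named method pipe instead of the symbolic |>.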

Beryllium answered Oct 27 '25 11:10