import java.io._

/**
 * Adapts a [[java.io.BufferedReader]] into an `Iterator[String]` that yields
 * one element per line, as returned by `readLine` (terminators stripped).
 *
 * At most one line is read ahead: `hasNext` pulls from the reader only when no
 * line is buffered, so calling `hasNext` repeatedly never skips input. The
 * reader is not closed here; whoever created it is responsible for closing it.
 *
 * @param r the reader to draw lines from
 */
class BufferedReaderExtensions(r: BufferedReader) extends Iterator[String] {
  // Line read ahead by `hasNext` but not yet handed out by `next()`.
  // `null` means "nothing buffered", mirroring `readLine`'s end-of-stream value.
  private var line: String = null

  /**
   * Returns the next line of the reader.
   *
   * @throws NoSuchElementException if the reader has no more lines
   */
  def next(): String = {
    if (!hasNext) throw new NoSuchElementException()
    val current = line
    line = null // consume the buffered line so the following call reads afresh
    current
  }

  /**
   * Whether another line is available, reading one line ahead if needed.
   * Declared without `()` to match `Iterator.hasNext`.
   */
  def hasNext: Boolean = {
    if (line == null) line = r.readLine()
    line != null
  }
}

/**
 * One-dimensional Gaussian mixture model fitted by expectation-maximization.
 *
 * Mixture parameters are passed around as `List[(mean, stddev, weight)]`,
 * one triple per component.
 */
object GMM {
  import scala.language.implicitConversions

  /** Lets a `BufferedReader` be used as an `Iterator[String]` over its lines. */
  implicit def breaderExts(r: BufferedReader): BufferedReaderExtensions =
    new BufferedReaderExtensions(r)

  /**
   * Transposes a list of rows into a list of columns, truncating to the
   * shortest row (a generalized `zip`).
   *
   * e.g. `starZip(List(List(1, 2, 3), List(4, 5)))` is `List(List(1, 4), List(2, 5))`.
   * An empty outer list, or any empty row, yields `Nil`.
   */
  def starZip[A](xs: List[List[A]]): List[List[A]] = {
    // `rows.isEmpty` guards the Nil input: `exists` on an empty list is false,
    // so without it the recursion would never stop. Columns are prepended and
    // reversed once at the end to keep the whole transpose linear.
    @scala.annotation.tailrec
    def loop(rows: List[List[A]], acc: List[List[A]]): List[List[A]] =
      if (rows.isEmpty || rows.exists(_.isEmpty)) acc.reverse
      else loop(rows.map(_.tail), rows.map(_.head) :: acc)

    loop(xs, Nil)
  }

  /** Left-to-right sum of `xs`; 0.0 for an empty list. */
  def sum(xs: List[Double]): Double = xs.foldLeft(0.0)(_ + _)

  /** Zips three lists into triples, truncating to the shortest. */
  def zip3[A, B, C](as: List[A], bs: List[B], cs: List[C]): List[(A, B, C)] =
    (as zip bs zip cs).map { case ((a, b), c) => (a, b, c) }

  /** Density of the normal distribution N(mean, stddev^2) at `point`. */
  def probNormDist(point: Double, mean: Double, stddev: Double): Double = {
    val z = (point - mean) / stddev
    1 / (math.sqrt(2 * math.Pi) * stddev) * math.exp(-0.5 * math.pow(z, 2.0))
  }

  /**
   * E-step: the responsibility of each component for each point.
   *
   * For every point, the weighted component densities are normalized to sum
   * to 1; the per-point lists are then transposed, so the result holds one
   * list per component, each with one entry per point.
   *
   * NOTE(review): if every weighted density underflows to 0 for some point,
   * its normalization divides by zero and yields NaN.
   */
  def computeAux(points: List[Double], params: List[(Double, Double, Double)]): List[List[Double]] = {
    val perPoint = points.map { p =>
      val weighted = params.map { case (mean, stddev, weight) =>
        weight * probNormDist(p, mean, stddev)
      }
      val total = sum(weighted) // hoisted: previously recomputed per element
      weighted.map(_ / total)
    }
    starZip(perPoint)
  }

  /** M-step mixing weights: each component's mean responsibility over `numPoints` points. */
  def newWeights(aux: List[List[Double]], numPoints: Int): List[Double] =
    aux.map(a => sum(a) / numPoints)

  /**
   * M-step means: responsibility-weighted average of `points` per component.
   * NOTE(review): a component with zero total responsibility divides by zero.
   */
  def newMeans(aux: List[List[Double]], points: List[Double]): List[Double] =
    aux.map { ys =>
      1 / sum(ys) * sum((ys zip points).map { case (y, x) => y * x })
    }

  /**
   * M-step standard deviations: square root of the responsibility-weighted
   * mean squared deviation of `points` from each component's entry in `means`.
   */
  def newStddevs(aux: List[List[Double]], points: List[Double], means: List[Double]): List[Double] =
    (aux zip means).map { case (ys, mean) =>
      math.sqrt(1 / sum(ys) * sum((ys zip points).map { case (y, x) => y * math.pow(x - mean, 2) }))
    }

  /**
   * Runs `steps` EM iterations starting from `params` and returns the final
   * `(mean, stddev, weight)` triples. A `steps` value below 1 still performs
   * one iteration.
   */
  @scala.annotation.tailrec
  def em(points: List[Double], params: List[(Double, Double, Double)], steps: Int): List[(Double, Double, Double)] = {
    val aux = computeAux(points, params)
    val means = newMeans(aux, points)
    val nextParams = zip3(
      means,
      // The variance update uses the freshly computed means, as in the
      // standard EM M-step (previously it used the means from `params`).
      newStddevs(aux, points, means),
      newWeights(aux, points.length))

    if (steps < 2) nextParams
    else em(points, nextParams, steps - 1)
  }

  /**
   * Reads one floating-point value per line from `file`, skipping blank lines.
   *
   * @throws NumberFormatException if a non-blank line is not a valid double
   */
  def readData(file: String): List[Double] = {
    val r = new BufferedReader(new FileReader(file))
    try {
      // Materialize with toList before `finally` closes the reader.
      breaderExts(r)
        .filter(_.trim.nonEmpty)
        .map(l => java.lang.Double.parseDouble(l))
        .toList
    } finally {
      r.close()
    }
  }

  /**
   * Entry point.
   *
   * Fits a three-component mixture to the points in the file named by the first
   * argument, or `points.dat` if none is given. It runs 1000 EM iterations and
   * prints the resulting `(mean, stddev, weight)` triples. Declared `: Unit` so
   * the JVM recognizes it as a `main` method.
   */
  def main(args: Array[String]): Unit = {
    val points = readData(args.headOption.getOrElse("points.dat"))
    // Initial guesses as (mean, stddev, weight); the weights sum to 1.
    val initial = List((10.0, 20.0, 0.1),
                       (20.0, 40.0, 0.1),
                       (30.0, 60.0, 0.8))
    println(em(points, initial, 1000))
  }
}

