{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns #-}
module Main
where

import GHC.Prim (
  Double#, ByteArray#, MutableByteArray#, RealWorld,
  newByteArray#, unsafeFreezeByteArray#,
  readDoubleArray#, writeDoubleArray#, indexDoubleArray#)
import GHC.Base  ( Int(..), (+#) )
import GHC.Float ( Double(..) )
import GHC.ST    ( ST(..), runST )
import GHC.Conc  ( forkOnIO, numCapabilities )
import Data.Array.Base (dOUBLE_SCALE)

import Control.Concurrent.MVar ( MVar, newEmptyMVar, takeMVar, putMVar )
import Control.Monad           ( zipWithM_ )

import System.Environment      ( getArgs )

import System.CPUTime
import System.Time

-- Arrays
-- ------

-- | Immutable array of unboxed 'Double's, represented as a slice of a shared
-- byte buffer. Fields, in order: index of the slice's first element within
-- the buffer (counted in elements, see 'indexA'), number of elements in the
-- slice (see 'lengthA'), and the underlying buffer.
data Arr    = Arr  !Int !Int ByteArray#
-- | Mutable array of unboxed 'Double's living in state thread @s@. Fields, in
-- order: the element count requested at allocation (see 'newMA') and the
-- underlying mutable byte buffer.
data MArr s = MArr !Int (MutableByteArray# s)

-- | Number of elements in the array slice.
lengthA :: Arr -> Int
lengthA arr =
  case arr of
    Arr _ len _ -> len

-- | Read the element at the given position, counted from the start of the
-- slice. The slice offset is added to the position before the raw read; no
-- bounds check is performed.
indexA :: Arr -> Int -> Double
indexA arr ix =
  case arr of
    Arr (I# off#) _ buf# ->
      case ix of
        I# ix# -> D# (indexDoubleArray# buf# (off# +# ix#))

-- | View a sub-range of an array without copying: the result shares the
-- buffer, starts @start@ elements after the original slice's offset and
-- reports @len@ elements. The requested range is not validated.
sliceA :: Arr -> Int -> Int -> Arr
sliceA arr start len =
  case arr of
    Arr base _ buf# -> Arr (base + start) len buf#

-- | Allocate a fresh mutable array for @len@ 'Double's. The byte size handed
-- to the allocator is @dOUBLE_SCALE len@; the contents are not initialised
-- here.
newMA :: Int -> ST s (MArr s)
newMA len@(I# len#) = ST alloc
  where
    -- Thread the state token through the primitive allocation.
    alloc s0# =
      case newByteArray# (dOUBLE_SCALE len#) s0# of
        (# s1#, buf# #) -> (# s1#, MArr len buf# #)

-- | Turn a mutable array into an immutable one without copying. The result
-- is a slice starting at offset 0 with the mutable array's element count.
-- As with the underlying primitive, the mutable array must not be written
-- to afterwards.
unsafeFreezeMA :: MArr s -> ST s Arr
unsafeFreezeMA (MArr len mbuf#) = ST freeze
  where
    freeze s0# =
      case unsafeFreezeByteArray# mbuf# s0# of
        (# s1#, buf# #) -> (# s1#, Arr 0 len buf# #)

-- | Store a 'Double' at the given element index of a mutable array.
-- No bounds check is performed.
writeMA :: MArr s -> Int -> Double -> ST s ()
writeMA (MArr _ mbuf#) (I# ix#) (D# val#) = ST store
  where
    store s0# =
      case writeDoubleArray# mbuf# ix# val# s0# of
        s1# -> (# s1#, () #)

-- | Build an array of @n@ elements, every one equal to @d@.
--
-- A negative @n@ is treated as 0, mirroring Prelude's 'replicate'. Without
-- the clamp the negative count would be handed straight to the primitive
-- allocator, which does not validate its size argument.
replicateA :: Int -> Double -> Arr
replicateA n d = runST (
  do
    marr <- newMA len
    fill marr 0
    unsafeFreezeMA marr
  )
  where
    -- Effective element count: never negative.
    len = max 0 n

    -- Write @d@ into slots @i .. len - 1@.
    fill marr i
      | i < len   = writeMA marr i d >> fill marr (i + 1)
      | otherwise = return ()


-- | Dot product of two arrays: the sum of the element-wise products, taken
-- in index order starting from an accumulator of 0.
--
-- Only the first @min (lengthA xs) (lengthA ys)@ positions take part, so
-- arrays of different length behave like a 'zip'-style truncation. For
-- arrays of equal length the result is unchanged.
dotpA :: Arr -> Arr -> Double
dotpA !xs !ys = go 0 0
  where
    -- 'indexA' performs no bounds check, so the loop must not run past the
    -- end of the shorter array.
    n = min (lengthA xs) (lengthA ys)

    -- Both the index and the running sum are kept strict to avoid building
    -- thunks in the hot loop.
    go !i !r | i < n     = go (i+1) (r + indexA xs i * indexA ys i)
             | otherwise = r

-- Parallel arrays
-- ---------------

-- | Chunk lengths for dividing @n@ elements among @threads@ workers as evenly
-- as possible: the first @n `mod` threads@ chunks get one extra element.
-- For positive @threads@ the result has @threads@ entries summing to @n@.
splitLen :: Int -> Int -> [Int]
splitLen threads n = longer ++ shorter
  where
    base    = div n threads
    extra   = mod n threads
    longer  = replicate extra (base + 1)
    shorter = replicate (threads - extra) base

-- | Cut an array into consecutive, non-copying slices whose lengths are given
-- by 'splitLen'. Each slice starts where the previous one ended.
splitA :: Int -> Arr -> [Arr]
splitA threads arr = [ sliceA arr start len | (start, len) <- zip starts lens ]
  where
    lens   = splitLen threads (lengthA arr)
    -- Running totals of the lengths give each chunk's starting position;
    -- the extra final total is dropped by 'zip'.
    starts = scanl (+) 0 lens

-- Gangs
-- -----

-- | A group of worker threads. Fields, in order: the worker count passed to
-- 'forkGang', one argument 'MVar' per worker (a pair of arrays to multiply),
-- and one result 'MVar' per worker (that worker's dot product).
data Gang   = Gang Int [MVar (Arr, Arr)] [MVar Double]

-- | Body of a single gang member: block until a pair of arrays arrives on
-- @arg@, compute their dot product, and publish it on @res@. The product is
-- forced before it is published, so the work happens in this thread. Serves
-- exactly one request and then returns.
worker :: MVar (Arr, Arr) -> MVar Double -> IO ()
worker arg res =
  takeMVar arg >>= \(xs, ys) ->
    putMVar res $! dotpA xs ys

-- | Start a gang of @n@ workers. Creates @n@ empty argument variables and @n@
-- empty result variables, then forks one 'worker' per pair, passing 0, 1, 2,
-- ... as the first argument of 'forkOnIO'.
forkGang :: Int -> IO Gang
forkGang n =
  do
    argVars <- mapM (const newEmptyMVar) [1 .. n]
    resVars <- mapM (const newEmptyMVar) [1 .. n]
    sequence_ [ forkOnIO cpu (worker a r)
              | (cpu, a, r) <- zip3 [0 ..] argVars resVars ]
    return (Gang n argVars resVars)

-- Timing
-- ------

-- | A CPU-time / wall-clock-time pair. Values produced by 'getTime' are in
-- picoseconds; differences built with 'minus' keep the same unit.
data Time = Time { cpu_time  :: Integer  -- ^ CPU time as reported by 'getCPUTime'
                 , wall_time :: Integer  -- ^ wall-clock time derived from 'getClockTime'
                 }

-- | A unit conversion applied to a picosecond count (see 'picoseconds',
-- 'milliseconds' and 'seconds').
type TimeUnit = Integer -> Integer

-- | Identity unit: leave a picosecond count as it is.
picoseconds :: TimeUnit
picoseconds ps = ps

-- | Convert picoseconds to whole milliseconds (10^9 ps each), rounding
-- towards negative infinity.
milliseconds :: TimeUnit
milliseconds = (`div` 1000000000)

-- | Convert picoseconds to whole seconds (10^12 ps each), rounding towards
-- negative infinity.
seconds :: TimeUnit
seconds ps = div ps 1000000000000

-- | CPU component of a 'Time', converted with the given unit.
cpuTime :: TimeUnit -> Time -> Integer
cpuTime f t = f (cpu_time t)

-- | Wall-clock component of a 'Time', converted with the given unit.
wallTime :: TimeUnit -> Time -> Integer
wallTime f t = f (wall_time t)

-- | Sample the current CPU time and wall-clock time. The clock reading comes
-- back as whole seconds plus a picosecond remainder; the two are folded into
-- a single picosecond count.
getTime :: IO Time
getTime =
  do
    cpuPs <- getCPUTime
    clock <- getClockTime
    case clock of
      TOD sec pico -> return (Time cpuPs (pico + sec * 1000000000000))

-- | Combine two 'Time's component-wise: @f@ is applied to the two CPU times
-- and, separately, to the two wall-clock times.
zipT :: (Integer -> Integer -> Integer) -> Time -> Time -> Time
zipT f t1 t2 =
  case t1 of
    Time cpuA wallA ->
      case t2 of
        Time cpuB wallB -> Time (f cpuA cpuB) (f wallA wallB)

-- | Component-wise difference of two 'Time's (first minus second).
minus :: Time -> Time -> Time
minus later earlier = zipT (-) later earlier

-- | Unpack a 'Time' as @(wall-clock, CPU)@, both converted to milliseconds.
fromTime :: Time -> (Integer, Integer)
fromTime t = (wallMs, cpuMs)
  where
    wallMs = wallTime milliseconds t
    cpuMs  = cpuTime milliseconds t

-- | Renders as @wall/cpu@, both in milliseconds.
instance Show Time where
  showsPrec prec t = wallPart . showChar '/' . cpuPart
    where
      wallPart = showsPrec prec (wallTime milliseconds t)
      cpuPart  = showsPrec prec (cpuTime milliseconds t)

-- Benchmark
-- ---------

-- | Run one parallel dot product on a gang: hand the i-th pair of chunks to
-- the i-th worker's argument variable, then collect every worker's result in
-- order. Blocks until all result variables have been filled.
dotp :: Gang -> [Arr] -> [Arr] -> IO [Double]
dotp (Gang _ argVars resVars) xss yss =
  do
    sequence_ (zipWith3 hand argVars xss yss)
    mapM takeMVar resVars
  where
    -- Deliver one pair of chunks to one worker.
    hand var xs ys = putMVar var (xs, ys)

-- | Benchmark driver.
--
-- Expects two command-line arguments: the number of runs and the array
-- length, in that order. Builds two arrays of that length (filled with 5 and
-- 6), splits each into 'numCapabilities' chunks, and times the given number
-- of parallel dot products, each on a freshly forked gang. Prints the mean
-- wall-clock and CPU time per run in milliseconds as @wall/cpu@.
--
-- A wrong argument count, an unparsable number, a run count below 1 or a
-- negative length is rejected with a usage error. Otherwise these would
-- surface as a pattern-match failure, a @read@ parse error, a division by
-- zero while averaging, or a bad size reaching the array allocator.
main :: IO ()
main =
  do
    args <- getArgs
    (runs, n) <- case map parseCount args of
                   [Just r, Just len] | r > 0 && len >= 0 -> return (r, len)
                   _ -> ioError (userError usage)
    let xs   = replicateA n 5
        ys   = replicateA n 6
        xss  = splitA numCapabilities xs
        yss  = splitA numCapabilities ys
    -- Force every chunk (and thereby both source arrays) before timing
    -- starts, so array construction is not attributed to the first run.
    eval xss `seq` eval yss `seq` return ()
    let oneRun = do
                   -- Workers serve a single request, so each run needs its
                   -- own gang; forking is kept outside the timed region.
                   gang <- forkGang numCapabilities
                   t1 <- getTime
                   _ <- dotp gang xss yss
                   t2 <- getTime
                   return $ fromTime (t2 `minus` t1)
    times <- sequence (replicate runs oneRun)
    let (walls, cpus) = unzip times
    putStrLn $ show (sum walls `div` toInteger runs) ++ "/" ++
               show (sum cpus  `div` toInteger runs)
  where
    -- Total replacement for 'read': succeeds only on a single complete
    -- parse, optionally followed by spaces.
    parseCount :: String -> Maybe Int
    parseCount s = case reads s of
                     [(v, rest)] | all (== ' ') rest -> Just v
                     _                               -> Nothing

    usage = "usage: <program> RUNS LENGTH (RUNS >= 1, LENGTH >= 0)"

    -- Evaluate every list element to weak head normal form.
    eval (x:rest) = x `seq` eval rest
    eval []       = ()

