import Text.Printf
import Data.List
import Data.Maybe

-- | A finite discrete distribution: outcomes paired with their weights.
-- Entries need not be distinct nor sum to 1; 'norm' merges duplicate
-- outcomes and 'scale' rescales the weights to total 1.
data Dist a = Dist [(a, Double)]
  deriving (Show, Eq)

-- Part 1 --

-- | The degenerate distribution: the given outcome with weight 1.
certainly :: a -> Dist a
certainly x = Dist [(x, 1.0)]

-- | Rescale a distribution so its weights sum to 1, keeping entry order.
--
-- The total weight is computed once in the @where@ clause and shared by
-- every entry, rather than being re-summed inside the per-element function
-- (which is quadratic unless the optimiser happens to float it out).
-- No guard is applied: a zero total yields NaN (for zero weights) or
-- Infinity (for non-zero weights).  An empty distribution stays empty.
scale :: Dist a -> Dist a
scale (Dist lst) = Dist [(x, w / total) | (x, w) <- lst]
  where
    total = sum (map snd lst)

-- Part 2 --

-- | Equal weight for every listed outcome (duplicates are not merged).
uniform :: [a] -> Dist a
uniform xs = scale (Dist [(x, 1) | x <- xs])

-- | A fair six-sided die: faces 1 through 6, each with weight 1/6.
die :: Dist Int
die = uniform (enumFromTo 1 6)

-- | A fair coin with outcomes 'H' and 'T'.
coin :: Dist Char
coin = uniform ['H', 'T']

-- Part 3 --

-- | Merge duplicate outcomes by summing their weights.
--
-- The result holds one entry per distinct outcome, sorted ascending.
-- Within a group, weights are summed in their original input order
-- (the sort is stable).
norm :: Ord a => Dist a -> Dist a
norm (Dist lst) =
  Dist [ (key, sum (map snd grp))
       | grp@((key, _) : _) <- groupBy sameKey (sortOn fst lst) ]
  where
    sameKey (k1, _) (k2, _) = k1 == k2

-- Part 4 --

-- | Total weight of the outcomes satisfying the predicate.
(??) :: (a->Bool) -> Dist a -> Double
p ?? Dist outcomes = sum [w | (x, w) <- outcomes, p x]

-- Part 5 --

-- | Map over outcomes and keep each weight.  Outcomes that become equal
-- are not merged; use 'norm' for that.
instance Functor Dist where
  fmap f (Dist lst) = Dist [(f x, p) | (x, p) <- lst]

-- Part 6 --

-- | 'pure' is 'certainly'.  '<*>' pairs every function entry with every
-- argument entry (function-major order) and multiplies their weights,
-- treating the two distributions as independent.
instance Applicative Dist where
  pure = certainly
  Dist fs <*> Dist xs =
    Dist (concatMap (\(g, p) -> [(g x, p * q) | (x, q) <- xs]) fs)

-- Part 7 --

-- | Distribution of ordered length-@n@ samples, each element drawn
-- independently from @d@ (i.e. with replacement).  The result has
-- @|d|^n@ entries and is not normalised.
--
-- Any @n <= 0@ gives the certain empty sample.  The guard covers negative
-- counts too; matching only the literal 0 would recurse forever on them.
selectI :: Int -> Dist a -> Dist [a]
selectI n d
  | n <= 0    = certainly []
  | otherwise = (:) <$> d <*> selectI (n - 1) d

-- Part 8 --

-- | Arithmetic on distributions.
--
-- Binary operators combine every pair of outcomes (operands treated as
-- independent) and merge equal results with 'norm'.  'abs' and 'signum'
-- also merge.  'negate' keeps the entry list as-is, without 'norm'.
instance (Ord e, Num e) => Num (Dist e) where
  fromInteger n = certainly (fromInteger n)
  negate d      = fmap negate d
  abs d         = norm (fmap abs d)
  signum d      = norm (fmap signum d)
  x + y         = norm (fmap (+) x <*> y)
  x * y         = norm (fmap (*) x <*> y)

-- Part 9 --

-- | Pair each outcome with the distribution of what remains once it is
-- removed.
--
-- Every entry equal to the outcome is dropped and the rest are rescaled.
-- If nothing remains, the paired distribution is empty.  Outer weights are
-- the original ones.
condDist :: Eq a => Dist a -> Dist (a,Dist a)
condDist whole@(Dist entries) = fmap withRest whole
  where
    withRest x = (x, scale (Dist [e | e@(y, _) <- entries, y /= x]))

-- Part 10 --

-- | Draw @n@ values without replacement, returning the drawn values (in
-- draw order) together with the distribution left after the draws.
--
-- A branch whose pool runs out before @n@ draws contributes no entries.
selectD' :: Ord a => Int -> Dist a -> Dist ([a],Dist a)
selectD' 0 remaining = pure ([], remaining)
selectD' k remaining =
  condDist remaining >>= \(picked, rest) ->
    selectD' (k - 1) rest >>= \(others, leftover) ->
      pure (picked : others, leftover)

-- | Normalised distribution of the ordered @n@-element draws without
-- replacement from @c@.  The leftover pool is discarded.
selectD :: Ord a => Int -> Dist a -> Dist [a]
selectD n c = norm (fmap fst (selectD' n c))

-- | Monadic bind: each outcome @x@ of @m@ (weight @p@) is expanded into
-- @k x@, with that distribution's weights multiplied by @p@.
--
-- Expanded entries are concatenated in order and NOT merged; apply 'norm'
-- afterwards to combine duplicates.
instance Monad Dist where
  m >>= k = Dist (concatMap expand outer)
    where
      -- lazy pattern binding: @m@ is only forced when the entries are demanded
      Dist outer = m
      expand (x, p) = case k x of
        Dist inner -> [(y, p * q) | (y, q) <- inner]

-- Part 11 --

-- | A stateful probabilistic step.  Given a distribution @Dist a@ used as
-- state (in this file, the pool of values still available), it yields a
-- distribution over pairs of (result, updated state).
data DistF a b = DistF (Dist a -> Dist (b, Dist a))

-- | 'DistF' version of 'selectD'': draw @n@ values without replacement
-- from the threaded state, returning them in draw order.
selectD'' :: Ord a => Int -> DistF a [a]
selectD'' 0 = pure []
selectD'' k =
  DistF condDist >>= \first ->
    selectD'' (k - 1) >>= \rest ->
      pure (first : rest)

-- | Run a 'DistF' step from the given starting state and discard the final
-- state, keeping only the distribution of results.
applyDistF :: DistF a b -> Dist a -> Dist b
applyDistF (DistF step) start = fmap fst (step start)

-- Part 12 --

-- | Map over the result of a step, leaving the threaded state unchanged.
-- Previously this was 'undefined', so any '<$>' on a 'DistF' crashed.
instance Functor (DistF a) where
  fmap g (DistF run) = DistF (\s -> fmap (\(x, s') -> (g x, s')) (run s))

-- | 'pure' returns its value with certainty and passes the state through.
--
-- '<*>' runs the function step, then runs the argument step on the state
-- it produced.  No @Ord a@ is available here, so unlike '>>=' it cannot
-- 'norm' the intermediate and final states.  Results therefore match
-- 'Control.Monad.ap' only up to applying 'norm' to the carried state.
instance Applicative (DistF a) where
  pure x = DistF (\y->certainly (x,y))
  DistF runF <*> DistF runX = DistF (\s -> do
    (g, s1) <- runF s
    (x, s2) <- runX s1
    return (g x, s2))

-- | Sequencing.  Run the first step, feed its remaining state (normalised
-- with 'norm') into the continuation's step, and normalise the final state.
instance Ord a => Monad (DistF a) where
  DistF d1 >>= f = DistF (\y-> do
    (x,y') <- d1 y
    let DistF d2 = f x
    (d3,y'') <- d2 (norm y')
    return (d3,norm y''))

-- Part 13 --

-- | Normalised distribution of the winner of @pickNumber' n@ when the
-- starting pool is uniform over 1..10.
pickNumber :: Int -> Dist Int
pickNumber n = norm (applyDistF (pickNumber' n) (uniform [1 .. 10]))

-- | Three successive draws without replacement from the shared pool: one
-- value, then two, then three.
--
-- Each participant scores the sum of its draws, or 0 if that sum exceeds
-- @n@.  The result is the 1-based position of the highest score; on ties
-- the earliest position wins because 'elemIndex' finds the first match.
-- 'fromJust' is safe: the maximum of the non-empty score list is one of
-- its elements.
pickNumber' :: Int -> DistF Int Int
pickNumber' n = do
  single <- DistF condDist
  pair <- selectD'' 2
  triple <- selectD'' 3
  let capped total = if total > n then 0 else total
      scores = map capped [single, sum pair, sum triple]
      winner = fromJust (elemIndex (maximum scores) scores)
  pure (winner + 1)

--TESTS--

-- | A biased die used by the tests (weights already sum to 1).
die' :: Dist Int
die' = Dist
  [ (1, 0.1), (2, 0.1), (3, 0.15)
  , (4, 0.15), (5, 0.2), (6, 0.3) ]

-- | Render each weight with two decimals so test comparisons ignore
-- floating-point noise.
t' :: Show a => Dist a -> [(a,String)]
t' (Dist lst) = [(x, printf "%.2f" p) | (x, p) <- lst]

-- | 'show' of the two-decimal rendering produced by 't''.
t :: Show a => Dist a -> String
t = show . t'

-- | Like 't', but also renders the nested distribution carried with each
-- outcome (e.g. the output of 'condDist').
t2 :: Show a => Dist (a, Dist a) -> String
t2 (Dist lst) = t (Dist [((x, t' inner), w) | ((x, inner), w) <- lst])

-- | Example-based checks, one inner list per part.  'main' numbers the
-- inner lists 1, 2, ... in this order.
--
-- Each pair is (actual, expected), both rendered as 'String' so that
-- 'test2str' can compare them with '=='.  Weights are rounded to two
-- decimals via 't' / 't2'.
tests :: [[(String,String)]]
tests = [
    -- Part 1: certainly, scale
    [(t $ certainly "Abc", show [("Abc","1.00")]),
     (t $ certainly (2::Integer), show [(2::Integer,"1.00")]),
     (t $ scale (Dist [("Abc",1),("Def",1)]), show [("Abc","0.50"),("Def","0.50")]),
     (t $ scale (Dist [("Abc",100),("Def",300)]), show [("Abc","0.25"),("Def","0.75")]),
     (t $ scale (Dist [(30,100),(20,200),(10::Integer,100)]), show [(30,"0.25"),(20,"0.50"),(10,"0.25")])]

  -- Part 2: uniform, die, coin
  , [(t $ uniform ["Abc", "Def"], show [("Abc","0.50"),("Def","0.50")]),
     (t $ uniform ["Ab", "De", "Abc", "Def"], show [("Ab","0.25"),("De","0.25"),("Abc","0.25"),("Def","0.25")]),
     (t $ die, show [((1::Integer),"0.17"),(2,"0.17"),(3,"0.17"),(4,"0.17"),(5,"0.17"),(6,"0.17")]),
    (t $ coin, show [('H',"0.50"),('T',"0.50")])]

  -- Part 3: norm (merges duplicates, sorts by outcome)
  , [(t $ norm (Dist [("Abc",0.25), ("Def",0.25), ("Abc",0.50)]), show [("Abc", "0.75"), ("Def","0.25")]),
    (t $ norm (Dist [("Abc",0.25), ("Abc",0.25), ("Abc",0.50)]), show [("Abc", "1.00")]),
    (t $ norm (Dist [("Abc",0.25), ("Def",0.25), ("Abc",0.25), ("Abc",0.25)]), show [("Abc", "0.75"), ("Def","0.25")])]

  -- Part 4: (??) event probability
  , [(show $ (=='H') ?? coin, "0.5"),
    (show $ (<=3) ?? die, "0.5")]

  -- Part 5: Functor Dist
  , [(t $ (+2) <$> die, show [((3::Integer),"0.17"),(4,"0.17"),(5,"0.17"),(6,"0.17"),(7,"0.17"),(8,"0.17")]),
    (t $ (+2) <$> die', show [((3::Integer),"0.10"),(4,"0.10"),(5,"0.15"),(6,"0.15"),(7,"0.20"),(8,"0.30")])]

  -- Part 6: Applicative Dist
  , [(t $ norm $ (+) <$> die <*> die, show [((2::Integer),"0.03"),(3,"0.06"),(4,"0.08"),(5,"0.11"),(6,"0.14"),(7,"0.17"),(8,"0.14"),(9,"0.11"),(10,"0.08"),(11,"0.06"),(12,"0.03")]),
    (t $ norm $ (+) <$> die' <*> die', show [((2::Integer),"0.01"),(3,"0.02"),(4,"0.04"),(5,"0.06"),(6,"0.09"),(7,"0.15"),(8,"0.14"),(9,"0.15"),(10,"0.13"),(11,"0.12"),(12,"0.09")])]

  -- Part 7: selectI (independent draws)
  , [(t $ norm $ sum <$> selectI 2 die, show [((2::Integer),"0.03"),(3,"0.06"),(4,"0.08"),(5,"0.11"),(6,"0.14"),(7,"0.17"),(8,"0.14"),(9,"0.11"),(10,"0.08"),(11,"0.06"),(12,"0.03")]),
    (t $ norm $ sum <$> selectI 2 die', show [((2::Integer),"0.01"),(3,"0.02"),(4,"0.04"),(5,"0.06"),(6,"0.09"),(7,"0.15"),(8,"0.14"),(9,"0.15"),(10,"0.13"),(11,"0.12"),(12,"0.09")])]

  -- Part 8: Num (Dist e)
  , [(t $ die + die, show [((2::Integer),"0.03"),(3,"0.06"),(4,"0.08"),(5,"0.11"),(6,"0.14"),(7,"0.17"),(8,"0.14"),(9,"0.11"),(10,"0.08"),(11,"0.06"),(12,"0.03")]),
    (t $ die * die, show [(1,"0.03"),((2::Integer),"0.06"),(3,"0.06"),(4,"0.08"),(5,"0.06"),(6,"0.11"),(8,"0.06"),(9,"0.03"),(10,"0.06"),(12,"0.11"),(15,"0.06"),(16,"0.03"),(18,"0.06"),(20,"0.06"),(24,"0.06"),(25,"0.03"),(30,"0.06"),(36,"0.03")]),
    (t $ negate $ uniform [-2..2::Int], show [((2::Integer),"0.20"),(1,"0.20"),(0,"0.20"),(-1,"0.20"),(-2,"0.20")]),
    (t $ signum $ uniform [-2..2::Int], show [((-1::Integer),"0.40"),(0,"0.20"),(1,"0.40")]),
    (t $ ((fromInteger 10)::Dist Int), show [((10::Integer),"1.00")]),
    (t $ die * die', show [((1::Integer),"0.02"),(2,"0.03"),(3,"0.04"),(4,"0.06"),(5,"0.05"),(6,"0.11"),(8,"0.04"),(9,"0.02"),(10,"0.05"),(12,"0.12"),(15,"0.06"),(16,"0.02"),(18,"0.08"),(20,"0.06"),(24,"0.08"),(25,"0.03"),(30,"0.08"),(36,"0.05")])]

  -- Part 9: condDist
  , [(t2 $ condDist $ Dist [('A',0.30), ('B',0.30), ('C',0.40)], show [(('A',  [('B',"0.43"),('C',"0.57")]),"0.30"),(('B', [('A',"0.43"),('C',"0.57")]),"0.30"),(('C', [('A',"0.50"),('B',"0.50")]),"0.40")]),
    (t2 $ condDist die, show [(((1::Integer), [((2::Integer),"0.20"),(3,"0.20"),(4,"0.20"),(5,"0.20"),(6,"0.20")]),"0.17"),((2, [((1::Integer),"0.20"),(3,"0.20"),(4,"0.20"),(5,"0.20"),(6,"0.20")]),"0.17"),((3, [((1::Integer),"0.20"),(2,"0.20"),(4,"0.20"),(5,"0.20"),(6,"0.20")]),"0.17"),((4, [((1::Integer),"0.20"),(2,"0.20"),(3,"0.20"),(5,"0.20"),(6,"0.20")]),"0.17"),((5, [((1::Integer),"0.20"),(2,"0.20"),(3,"0.20"),(4,"0.20"),(6,"0.20")]),"0.17"),((6, [((1::Integer),"0.20"),(2,"0.20"),(3,"0.20"),(4,"0.20"),(5,"0.20")]),"0.17")])]

  -- Part 10: selectD (draws without replacement)
  , [(t $ norm $ sum <$> selectD 2 die, show [((3::Integer),"0.07"),(4,"0.07"),(5,"0.13"),(6,"0.13"),(7,"0.20"),(8,"0.13"),(9,"0.13"),(10,"0.07"),(11,"0.07")]),
    (t $ norm $ sum <$> selectD 3 die,show [(6,"0.05"),((7::Integer),"0.05"),(8,"0.10"),(9,"0.15"),(10,"0.15"),(11,"0.15"),(12,"0.15"),(13,"0.10"),(14,"0.05"),(15,"0.05")])]

  -- Part 11: DistF / applyDistF
  , [(t $ applyDistF (DistF condDist) die, t die)]

  -- Part 12: selectD'' via the DistF monad
  , [(t $ norm $ sum <$> applyDistF (selectD'' 3) die, show [((6::Integer),"0.05"),(7,"0.05"),(8,"0.10"),(9,"0.15"),(10,"0.15"),(11,"0.15"),(12,"0.15"),(13,"0.10"),(14,"0.05"),(15,"0.05")])]

  -- Part 13: pickNumber
  , [(t $ pickNumber 15, show [((1::Integer),"0.19"),(2,"0.53"),(3,"0.28")]),
    (t $ pickNumber 10, show [((1::Integer),"0.70"),(2,"0.25"),(3,"0.06")]),
    (t $ pickNumber 100, show [(1,"0.01"),(2,"0.23"),((3::Integer),"0.76")])]]

-- | Summarise one part's test cases.  Cases are numbered from 1; a case
-- fails when its actual string differs from its expected string.  The
-- report shows either a success line or the details of the first failure.
test2str :: (Int,[(String,String)]) -> String
test2str (part, cases) =
  case failures of
    [] -> "Part " ++ show part ++ " works on given examples."
    ((idx, actual, expected) : _) ->
      intercalate "\n"
        [ "---------------------"
        , "Part " ++ show part ++ " FAILED"
        , "  Test case " ++ show idx ++ " failed"
        , "  Expected Output: " ++ expected
        , "    Actual Output: " ++ actual
        , "---------------------"
        ]
  where
    -- lazily built, so only the first mismatch is ever inspected
    failures = [ (i, a, e) | (i, (a, e)) <- zip [1 :: Integer ..] cases, a /= e ]

-- | Print one summary line (or failure report) per part, numbering the
-- parts from 1 in the order they appear in 'tests'.
main :: IO ()
main = putStr (unlines (zipWith (curry test2str) [1 ..] tests))