;; UNB/ CS/ David Bremner/ teaching/ cs3613/ tutorials/ tutorial8/ fae-cont.rkt
;; Originally by Matthew Flatt
;; Translated to plait, subjectively "improved" by David Bremner

#lang plait

;; FAE : abstract syntax of the interpreted language, as produced by `parse`.
(define-type FAE
  ;; numeric literal
  [Num (n : Number)]
  ;; binary addition
  [Add (lhs : FAE)
       (rhs : FAE)]
  ;; binary subtraction
  [Sub (lhs : FAE)
       (rhs : FAE)]
  ;; unary form, written {-- e} in concrete syntax
  [Sub1 (arg : FAE)]
  ;; binary multiplication
  [Mul (lhs : FAE)
       (rhs : FAE)]
  ;; variable reference
  [Id (name : Symbol)]
  ;; one-parameter function, written {fun {x} body}
  [Fun (param : Symbol)
       (body : FAE)]
  ;; application of fun-expr to a single argument
  [Call (fun-expr : FAE)
       (arg-expr : FAE)]
  ;; conditional: `then` is chosen when `test` evaluates to zero
  [If0 (test : FAE)
       (then : FAE)
       (else : FAE)])

;; FAE-Value : results of evaluation.
(define-type FAE-Value
  ;; a number
  [NumV (n : Number)]
  ;; a function value: parameter, body, and the environment that was
  ;; current when the Fun expression was evaluated
  [ClosureV (param : Symbol)
            (body : FAE)
            (env : Env)])

;; Env : environment mapping identifiers to values, as a linked chain of
;; bindings; `lookup` searches it from the front.
(define-type Env
  ;; the empty environment
  [mtEnv]
  ;; one binding of `name` to `value`, followed by the remaining bindings
  [aBind (name : Symbol)
        (value : FAE-Value)
        (rest : Env)])

;; FAE-Cont : continuations represented as data. Each variant records what
;; remains to be done once a value is delivered to it by `continue`.
(define-type FAE-Cont
  ;; nothing left to do: the delivered value is the final result
  [mtK]
  ;; left operand of an addition is done; next evaluate r in env
  [addSecondK (r : FAE)
              (env : Env)
              (k : FAE-Cont)]
  ;; both operands of an addition are available: v1 plus the delivered value
  [doAddK (v1 : FAE-Value)
          (k : FAE-Cont)]
  ;; left operand of a subtraction is done; next evaluate r in env
  [subSecondK (r : FAE)
              (env : Env)
              (k : FAE-Cont)]
  ;; both operands of a subtraction are available: v1 minus the delivered value
  [doSubK (v1 : FAE-Value)
          (k : FAE-Cont)]
  ;; function position of a call is done; next evaluate arg-expr in env
  [CallArgK (arg-expr : FAE)
           (env : Env)
           (k : FAE-Cont)]
  ;; argument of a call is done; apply fun-val to the delivered value
  [doCallK (fun-val : FAE-Value)
          (k : FAE-Cont)]
  ;; test of an if0 is done; evaluate then-expr if the delivered value is
  ;; zero, else-expr otherwise, both in env
  [doIfK (then-expr : FAE)
         (else-expr : FAE)
         (env : Env)
         (k : FAE-Cont)])

;; ----------------------------------------

;; parse : S-expr -> FAE
;; (parse-error and sx-ref, defined first, are helpers used by parse)
;; parse-error : S-Exp -> 'a
;; Raise a 'parse error whose message includes the printed form of sx.
(define (parse-error sx)
  (let ([shown (to-string sx)])
    (error 'parse (string-append "parse error: " shown))))

;; sx-ref : S-Exp Number -> S-Exp
;; Convert the list-shaped s-expression sx to a list and return its
;; element at zero-based position n.
(define (sx-ref sx n)
  (let ([elems (s-exp->list sx)])
    (list-ref elems n)))

;; parse : S-Exp -> FAE
;; Translate concrete syntax into an FAE tree.
;;
;; Every compound form is recognised by an explicit s-exp-match? pattern, so
;; any list that is not one of the known shapes (empty list, wrong number of
;; operands, non-symbol head, ...) is reported through parse-error rather
;; than failing inside list-ref or s-exp->symbol.
;;
;; Clause order matters: `fun` and `--` must be tried before the generic
;; two-element application pattern, which would otherwise claim {-- e}.
(define (parse sx)
  (local
      ;; px : parse the i-th (zero-based) sub-expression of sx
      [(define (px i) (parse (sx-ref sx i)))]
    (cond
      [(s-exp-number? sx) (Num (s-exp->number sx))]
      [(s-exp-symbol? sx) (Id (s-exp->symbol sx))]
      [(s-exp-match? `(fun (SYMBOL) ANY) sx)
       (let* ([params (sx-ref sx 1)]
              ;; named `param`, not `Id`, to avoid shadowing the Id constructor
              [param (s-exp->symbol (sx-ref params 0))]
              [body (px 2)])
         (Fun param body))]
      [(s-exp-match? `(-- ANY) sx) (Sub1 (px 1))]
      ;; any other two-element list is an application
      [(s-exp-match? `(ANY ANY) sx) (Call (px 0) (px 1))]
      [(s-exp-match? `(+ ANY ANY) sx) (Add (px 1) (px 2))]
      [(s-exp-match? `(- ANY ANY) sx) (Sub (px 1) (px 2))]
      [(s-exp-match? `(* ANY ANY) sx) (Mul (px 1) (px 2))]
      [(s-exp-match? `(if0 ANY ANY ANY) sx) (If0 (px 1) (px 2) (px 3))]
      [else (parse-error sx)])))
;; Parser tests: one case per syntactic form, then a nested example.
(module+ test
  ;; only report failing tests
  (print-only-errors #t)
  (test (parse `3) (Num 3))
  (test (parse `x) (Id 'x))
  (test (parse `{+ 1 2}) (Add (Num 1) (Num 2)))
  (test (parse `{- 1 2}) (Sub (Num 1) (Num 2)))
  (test (parse `{* 1 2}) (Mul (Num 1) (Num 2)))
  (test (parse `{-- 2}) (Sub1 (Num 2)))
  (test (parse `{fun {x} x}) (Fun 'x (Id 'x)))
  (test (parse `{1 2}) (Call (Num 1) (Num 2)))
  (test (parse `{if0 0 1 2}) (If0 (Num 0) (Num 1) (Num 2)))
  ;; nested functions and applications
  (test (parse `{{fun {x}
                      {{fun {f}
                            {+ {f 1}
                               {{fun {x} {f 2}} 3}}}
                       {fun {y} {+ x y}}}}
                 0})
               (Call (Fun 'x
                          (Call (Fun 'f
                                     (Add (Call (Id 'f) (Num 1))
                                          (Call (Fun 'x
                                                     (Call (Id 'f) (Num 2)))
                                                (Num 3))))
                                (Fun 'y (Add (Id 'x) (Id 'y)))))
                     (Num 0)))


  )

;; ----------------------------------------

;; eval : FAE Env FAE-Cont -> FAE-Value
;; Evaluate a-fae in environment env and deliver the result to the
;; continuation k. Work that remains after a sub-expression is evaluated is
;; stored in an FAE-Cont value instead of on the Racket stack; `continue`
;; resumes it once the sub-expression's value is known.
(define (eval a-fae env k)
  (type-case FAE a-fae
    [(Num n) (continue k (NumV n))]
    [(Add l r) (eval l env (addSecondK r env k))]
    [(Sub l r) (eval l env (subSecondK r env k))]
    ;; {-- e} is evaluated as {- e 1}; this reuses the existing
    ;; subSecondK/doSubK continuations, so no new FAE-Cont variant is needed.
    [(Sub1 arg) (eval (Sub arg (Num 1)) env k)]
    ;; Still a placeholder: multiplication needs its own continuation
    ;; variants in FAE-Cont and matching cases in `continue`.
    [(Mul l r) ....]
    [(Id name) (continue k (lookup name env))]
    ;; a function evaluates to a closure over the current environment
    [(Fun param body-expr)
     (continue k (ClosureV param body-expr env))]
    ;; evaluate the function position first; CallArgK remembers the argument
    [(Call fun-expr arg-expr)
     (eval fun-expr env (CallArgK arg-expr env k))]
    ;; evaluate the test first; doIfK remembers both branches
    [(If0 test-expr then-expr else-expr)
     (eval test-expr env (doIfK then-expr else-expr env k))]))

;; continue : FAE-Cont FAE-Value -> FAE-Value
;; Deliver the value v to the continuation k: perform the step that k
;; recorded, then carry on with the continuation stored inside k.
(define (continue [k : FAE-Cont] [v : FAE-Value]) : FAE-Value
  (type-case FAE-Cont k
    ;; no pending work: v is the final answer
    [(mtK) v]
    ;; v is the left operand; go evaluate the right one
    [(addSecondK rhs saved-env rest-k)
     (eval rhs saved-env (doAddK v rest-k))]
    ;; v is the right operand
    [(doAddK left-val rest-k)
     (continue rest-k (num+ left-val v))]
    [(subSecondK rhs saved-env rest-k)
     (eval rhs saved-env (doSubK v rest-k))]
    [(doSubK left-val rest-k)
     (continue rest-k (num- left-val v))]
    ;; v is the function value; go evaluate the argument
    [(CallArgK arg saved-env rest-k)
     (eval arg saved-env (doCallK v rest-k))]
    ;; v is the argument value; run the body with the parameter bound to v
    ;; on top of the closure's own environment
    [(doCallK closure rest-k)
     (let* ([body (ClosureV-body closure)]
            [call-env (aBind (ClosureV-param closure)
                             v
                             (ClosureV-env closure))])
       (eval body call-env rest-k))]
    ;; v is the test value; zero selects the then-branch
    [(doIfK then-branch else-branch saved-env rest-k)
     (eval (if (numzero? v) then-branch else-branch)
           saved-env
           rest-k)]))

;; num-op : (Number Number -> Number) Symbol -> (FAE-Value FAE-Value -> FAE-Value)
;; Lift a binary operation on numbers to one on NumV values.
;; op-name is accepted but not used by the body.
(define (num-op op op-name)
  (lambda (x y)
    (let ([a (NumV-n x)]
          [b (NumV-n y)])
      (NumV (op a b)))))

;; num+, num- : FAE-Value FAE-Value -> FAE-Value
;; Addition and subtraction on NumV values, built with num-op.
(define num+ (num-op + '+))
(define num- (num-op - '-))

;; numzero? : FAE-Value -> Boolean
;; True when the number carried by the NumV x is zero.
(define (numzero? x)
  (let ([n (NumV-n x)])
    (= n 0)))

;; lookup : Symbol Env -> FAE-Value
;; Return the value of the first binding for name in env; raise a
;; "free variable" error when the end of the environment is reached.
(define (lookup name env)
  (type-case Env env
    [(mtEnv) (error 'lookup "free variable")]
    [(aBind bound-name bound-val outer-env)
     (cond
       [(symbol=? bound-name name) bound-val]
       [else (lookup name outer-env)])]))

;; init-k : FAE-Cont
;; The empty continuation, used as the starting continuation for top-level
;; calls to eval.
(define init-k (mtK))

;; Evaluator tests, all run from the empty continuation init-k.
(module+ test
  ;; literals and arithmetic
  (test (eval (Num 10)
                (mtEnv)
                init-k)
        (NumV 10))
  (test (eval (Add (Num 10) (Num 7))
                (mtEnv)
                init-k)
        (NumV 17))
  (test (eval (Sub (Num 10) (Num 7))
                (mtEnv)
                init-k)
        (NumV 3))
  ;; applying a function to a computed argument
  (test (eval (Call (Fun 'x (Add (Id 'x) (Num 12)))
                      (Add (Num 1) (Num 17)))
                (mtEnv)
                init-k)
        (NumV 30))
  ;; identifier lookup in a non-empty environment
  (test (eval (Id 'x)
                (aBind 'x (NumV 10) (mtEnv))
                init-k)
        (NumV 10))

  ;; same expression and expected value as the function-application test above
  (test (eval (Call (Fun 'x (Add (Id 'x) (Num 12)))
                      (Add (Num 1) (Num 17)))
                (mtEnv)
                init-k)
        (NumV 30))


  ;; unbound identifiers raise "free variable"
  (test/exn (eval (Id 'x) (mtEnv) init-k)
            "free variable")

  (test/exn
   (eval
    (parse
     `{ {fun {x} {+ x y}} 0})
    (mtEnv)
    init-k)
   "free variable")

  ;; closures see the x that was bound where they were created
  (test
   (eval
    (parse
     `{{fun {x}
         {{fun {f} {f 2}}
          {fun {y} {+ x y}}}}
       0})
    (mtEnv)
    init-k)
   (NumV 2))

  ;; the inner rebinding of x to 3 does not affect f's captured x
  (test (eval
         (parse `{{fun {x}
                      {{fun {f}
                            {+ {f 1}
                               {{fun {x} {f 2}} 3}}}
                       {fun {y} {+ x y}}}}
                 0})
         (mtEnv)
         init-k)
        (NumV 3))

  ;; if0 takes the first branch on zero, the second otherwise
  (test (eval (If0 (Num 0)
                     (Num 1)
                     (Num 2))
                (mtEnv)
                init-k)
        (NumV 1))
  (test (eval (If0 (Num 1)
                     (Num 0)
                     (Num 2))
                (mtEnv)
                init-k)
        (NumV 2))

  ;; recursion through self-application (mkrec), computing fib of 4
  (test (eval (parse
                 `{{fun {mkrec}
                        {{fun {fib}
                              ;; Call fib on 4:
                              {fib 4}}
                         ;; Create recursive fib:
                         {mkrec
                          {fun {fib}
                               ;; Fib:
                               {fun {n}
                                    {if0 n
                                         1
                                         {if0 {- n 1}
                                              1
                                              {+ {fib {- n 1}}
                                                 {fib {- n 2}}}}}}}}}}
                   ;; mkrec:
                   {fun {body-proc}
                        {{fun {fX}
                              {fX fX}}
                         {fun {fX}
                              {body-proc {fun {x} {{fX fX} x}}}}}}})
                (mtEnv)
                init-k)
        (NumV 5))

)