;; UNB/ CS/ David Bremner/ teaching/ cs4613/ lectures/ lecture13/ tfae.rkt
#lang plait
;; FAE: abstract syntax of the typed expression language.
;; Values of this type are built by `parse` and consumed by `eval` and
;; `typecheck`.
(define-type FAE
  ;; numeric literal
  [Num (n : Number)]
  ;; boolean literal (written `true` / `false` in concrete syntax)
  [Bool (b : Boolean)]
  ;; boolean negation of a sub-expression
  [Not (expr : FAE)]
  ;; addition of two sub-expressions
  [Add (lhs : FAE)
       (rhs : FAE)]
  ;; subtraction of two sub-expressions
  [Sub (lhs : FAE)
       (rhs : FAE)]
  ;; variable reference
  [Id (name : Symbol)]
  ;; one-argument function; `argty` is the declared (unparsed-to-Type)
  ;; type expression of the parameter
  [Lam (param : Symbol)
       (argty : TE)
       (body : FAE)]
  ;; application of a function expression to one argument expression
  [Call (fun-expr : FAE)
        (arg-expr : FAE)])

;; TE: type expressions as they appear in source programs (the annotation
;; on a `lam` parameter). Built by `parse-te`; converted to `Type` by
;; `parse-type`.
(define-type TE
  ;; the type written `num`
  [NumTE]
  ;; the type written `bool`
  [BoolTE]
  ;; a function type written `{arg -> result}`
  [ArrowTE (arg : TE)
           (result : TE)])

;; FAE-Value: results produced by `eval`.
(define-type FAE-Value
  ;; numeric result
  [NumV (n : Number)]
  ;; boolean result
  [BoolV (b : Boolean)]
  ;; function value: parameter name, body, and the environment that was
  ;; current when the `Lam` was evaluated
  [ClosureV (param : Symbol)
            (body : FAE)
            (env : ValueEnv)])
;; ValueEnv: run-time environment mapping identifiers to values.
;; Represented as a chain of bindings; `lookup` searches it from the
;; outermost constructor inward, so later bindings shadow earlier ones.
(define-type ValueEnv
  ;; environment with no bindings
  [EmptyValueEnv]
  ;; one binding in front of the remaining environment `rest`
  [BindValue (name : Symbol)
             (value : FAE-Value)
             (rest : ValueEnv)])

;; Type: the types computed by `typecheck`. Structurally parallel to TE;
;; `parse-type` maps each TE variant to the corresponding Type variant.
(define-type Type
  ;; type of numbers
  [NumT]
  ;; type of booleans
  [BoolT]
  ;; type of one-argument functions from `arg` to `result`
  [ArrowT (arg : Type)
          (result : Type)])

;; TypeEnv: type-checking environment mapping identifiers to their types.
;; Same chain-of-bindings shape as ValueEnv; searched by `type-lookup`.
(define-type TypeEnv
  ;; environment with no bindings
  [EmptyTypeEnv]
  ;; one identifier/type binding in front of the remaining environment
  [BindType (name : Symbol)
            (type : Type)
            (rest : TypeEnv)])

;; parse-te : S-Exp -> TE
;; Converts concrete type syntax into a type expression:
;;   num            => (NumTE)
;;   bool           => (BoolTE)
;;   {arg -> result} => (ArrowTE <arg> <result>), parsed recursively
;; Any other symbol or shape is rejected through `parse-error`, so bad
;; type annotations are reported with the same "parse error" message as
;; bad expressions.
(define (parse-te sx)
  (cond
    [(s-exp-symbol? sx)
     (case (s-exp->symbol sx)
       [(num) (NumTE)]
       [(bool) (BoolTE)]
       ;; unknown base-type name
       [else (parse-error sx)])]
    [(s-exp-match? `(ANY -> ANY) sx)
     (ArrowTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    ;; neither a symbol nor a three-element arrow form
    [else (parse-error sx)]))

;; ----------------------------------------
;; parse : S-expr -> FAE
;; parse-error : S-Exp -> 'a
;; Aborts parsing by raising an error tagged 'parse whose message is
;; "parse error: " followed by the printed form of the offending syntax.
(define (parse-error sx)
  (let ([shown (to-string sx)])
    (error 'parse
           (string-append "parse error: " shown))))

;; sx-ref : S-Exp Number -> S-Exp
;; Returns element `n` (zero-based) of the list wrapped by `sx`.
(define (sx-ref sx n)
  (let ([elems (s-exp->list sx)])
    (list-ref elems n)))

;; parse : S-Exp -> FAE
;; Converts concrete syntax into abstract syntax:
;;   <number>                  => Num
;;   true / false              => Bool
;;   <symbol>                  => Id
;;   {lam {<id> : <te>} <body>} => Lam (type annotation read by parse-te)
;;   {not <e>}                 => Not
;;   {+ <e> <e>} / {- <e> <e>} => Add / Sub
;;   {<e> <e>}                 => Call
;; Every other input raises a "parse error" via `parse-error`.
;; The operator forms are matched by exact shape, so an empty list, a
;; longer list with a non-symbol head, or `+`/`-` with extra operands is
;; reported as a parse error rather than failing inside `list-ref` /
;; `s-exp->symbol` or silently dropping operands.
(define (parse sx)
  (local
      [;; parse the i-th element of the list form currently being parsed
       (define (px i) (parse (sx-ref sx i)))]
    (cond
      [(s-exp-number? sx) (Num (s-exp->number sx))]
      [(s-exp-symbol? sx)
       (let ([sym (s-exp->symbol sx)])
         (case sym
           [(true) (Bool #t)]
           [(false) (Bool #f)]
           [else (Id sym)]))]
      [(s-exp-match? `(lam (SYMBOL : ANY) ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [body (px 2)])
         (Lam id te body))]
      ;; `not` must be tested before the generic two-element Call form
      [(s-exp-match? `(not ANY) sx) (Not (px 1))]
      [(s-exp-match? `(+ ANY ANY) sx) (Add (px 1) (px 2))]
      [(s-exp-match? `(- ANY ANY) sx) (Sub (px 1) (px 2))]
      [(s-exp-match? `(ANY ANY) sx) (Call (px 0) (px 1))]
      [else (parse-error sx)])))
;; Parser tests: one case per syntactic form, two malformed inputs, and
;; one deeply nested program.
(module+ test
  ;; report only failing tests
  (print-only-errors #t)
  (test (parse `3) (Num 3))
  (test (parse `x) (Id 'x))
  (test (parse `{+ 1 2}) (Add (Num 1) (Num 2)))
  (test (parse `{- 1 2}) (Sub (Num 1) (Num 2)))
  (test (parse `{lam {x : num} x}) (Lam 'x (NumTE) (Id 'x)))
  (test (parse `{f 2}) (Call (Id 'f) (Num 2)))

  ;; a string is not a number, symbol, or list form
  (test/exn (parse `"foo") "parse error")
  ;; a one-element list matches none of the list forms
  (test/exn (parse `{foo}) "parse error")
  ;; nested lambdas, an arrow-typed parameter, and shadowing of `x`
  (test (parse
         `{{lam {x : num}
                {{lam {f : {num -> num}}
                      {+ { f 1}
                         { {lam {x : num}
                                {f 2}}
                           3}}}
                 {lam {y : num} {+ x y}}}}
           0})
        (Call (Lam 'x (NumTE)
                   (Call (Lam 'f (ArrowTE (NumTE) (NumTE))
                              (Add (Call (Id 'f) (Num 1))
                                   (Call (Lam 'x (NumTE)
                                              (Call (Id 'f)
                                                    (Num 2)))
                                         (Num 3))))
                         (Lam 'y (NumTE)
                              (Add (Id 'x) (Id 'y)))))
              (Num 0)))
  )


;; ----------------------------------------

;; eval : FAE ValueEnv -> FAE-Value
;; Environment-based interpreter. Literals evaluate to themselves,
;; identifiers are resolved with `lookup`, a `Lam` captures the current
;; environment in a closure (its type annotation plays no role here), and
;; a `Call` runs the closure body in the closure's own environment
;; extended with the argument value.
(define (eval a-fae env)
  (type-case FAE a-fae
    [(Num k) (NumV k)]
    [(Bool flag) (BoolV flag)]
    [(Not operand)
     (let ([operand-val (eval operand env)])
       (BoolV (not (BoolV-b operand-val))))]
    [(Add left right)
     (let* ([left-val (eval left env)]
            [right-val (eval right env)])
       (num+ left-val right-val))]
    [(Sub left right)
     (let* ([left-val (eval left env)]
            [right-val (eval right env)])
       (num- left-val right-val))]
    [(Id var) (lookup var env)]
    [(Lam formal declared-te fn-body)
     (ClosureV formal fn-body env)]
    [(Call rator rand)
     ;; function position is evaluated before the argument
     (let* ([closure (eval rator env)]
            [actual (eval rand env)]
            [closure-body (ClosureV-body closure)]
            [call-env (BindValue (ClosureV-param closure)
                                 actual
                                 (ClosureV-env closure))])
       (eval closure-body call-env))]))

;; num-op : (Number Number -> Number) op-name FAE-Value FAE-Value -> FAE-Value
;; Unwraps the numbers inside `x` and `y`, applies `op` to them, and wraps
;; the result as a NumV. `op-name` is accepted but not used by the body;
;; the callers in this file pass the operator's symbol.
(define (num-op op op-name x y)
  (let ([left (NumV-n x)]
        [right (NumV-n y)])
    (NumV (op left right))))

;; num+ : FAE-Value FAE-Value -> FAE-Value
;; Adds two numeric values by delegating to `num-op`.
(define num+
  (lambda (x y)
    (num-op + '+ x y)))
;; num- : FAE-Value FAE-Value -> FAE-Value
;; Subtracts the second numeric value from the first by delegating to
;; `num-op`.
(define num-
  (lambda (x y)
    (num-op - '- x y)))

;; lookup : Symbol ValueEnv -> FAE-Value
;; Returns the value of the first binding of `name` found while walking
;; `env` from the front. Raises a "free variable" error when the end of
;; the environment is reached without a match.
(define (lookup name env)
  (type-case ValueEnv env
    [(BindValue bound-name bound-val outer-env)
     (cond
       [(equal? bound-name name) bound-val]
       [else (lookup name outer-env)])]
    [(EmptyValueEnv) (error 'lookup "free variable")]))

;; ----------------------------------------

;; type-lookup : Symbol TypeEnv -> Type
;; Returns the type of the first binding of `name-to-find` in `env`.
;; An identifier with no binding has no type, which is reported as an
;; error.
(define (type-lookup name-to-find env)
  (type-case TypeEnv env
    [(BindType bound-name bound-type outer-env)
     (cond
       [(equal? name-to-find bound-name) bound-type]
       [else (type-lookup name-to-find outer-env)])]
    [(EmptyTypeEnv) (error 'type-lookup "free variable, so no type")]))

;; ----------------------------------------

;; parse-type : TE -> Type
;; Translates a source-level type expression into the checker's Type,
;; variant by variant, recursing into both sides of an arrow.
(define (parse-type te)
  (type-case TE te
    [(ArrowTE dom rng)
     (let* ([dom-type (parse-type dom)]
            [rng-type (parse-type rng)])
       (ArrowT dom-type rng-type))]
    [(BoolTE) (BoolT)]
    [(NumTE) (NumT)]))

;; type-error : FAE String -> 'a
;; Raises a 'typecheck error whose message has the form
;;   "no type: <printed expression> not <msg>"
;; where `msg` names what the expression was expected to be.
(define (type-error fae msg)
  (let* ([expectation (string-append " not " msg)]
         [detail (string-append (to-string fae) expectation)]
         [full-message (string-append "no type: " detail)])
    (error 'typecheck full-message)))

;; type-assert : (Listof FAE) Type TypeEnv Type -> Type
;; Checks, left to right, that every expression in `exprs` has exactly
;; `type` under `env`. Returns `result` when all of them do; the first
;; expression that does not is reported through `type-error`.
(define (type-assert exprs type env result) : Type
  (if (empty? exprs)
      result
      (let ([head (first exprs)])
        (if (equal? (typecheck head env) type)
            (type-assert (rest exprs) type env result)
            (type-error head (type-to-string type))))))

;; type-to-string : Type -> String
;; Renders a type for error messages: the base types use their
;; source-level names, while an arrow type is shown with the generic
;; `to-string` printer.
(define (type-to-string [type : Type])
  (type-case Type type
    [(NumT) "num"]
    [(BoolT) "bool"]
    [(ArrowT dom rng) (to-string type)]))

;; typecheck : FAE TypeEnv -> Type
;; Computes the type of `fae` under `env`, or raises an error when the
;; expression has no type.
;;  - literals have their base type
;;  - `Not` needs a bool operand; `Add`/`Sub` need num operands
;;  - an identifier has the type bound to it in `env`
;;  - a `Lam` has an arrow type from its declared parameter type to the
;;    type of its body checked with the parameter bound
;;  - a `Call` needs an arrow-typed function whose argument type matches
;;    the actual argument; the result is the arrow's result type
(define (typecheck [fae : FAE] [env : TypeEnv]) : Type
  (type-case FAE fae
    [(Num k) (NumT)]
    [(Bool flag) (BoolT)]
    [(Not operand)
     (type-assert (list operand) (BoolT) env (BoolT))]
    [(Add left right)
     (type-assert (list left right) (NumT) env (NumT))]
    [(Sub left right)
     (type-assert (list left right) (NumT) env (NumT))]
    [(Id var) (type-lookup var env)]
    [(Lam formal declared-te fn-body)
     (let* ([formal-type (parse-type declared-te)]
            [inner-env (BindType formal formal-type env)]
            [body-type (typecheck fn-body inner-env)])
       (ArrowT formal-type body-type))]
    [(Call rator rand)
     (let ([rator-type (typecheck rator env)])
       (type-case Type rator-type
         [(ArrowT dom rng)
          (type-assert (list rand) dom env rng)]
         [else (type-error rator "function")]))]))

;; ----------------------------------------


;; run : S-Exp -> FAE-Value
;; Parses `s-expr` and evaluates it in the empty environment. No type
;; checking is performed on this path.
(define (run s-expr)
  (let ([ast (parse s-expr)])
    (eval ast (EmptyValueEnv))))

;; check : S-Exp -> Type
;; Parses `s-expr` and type checks it in the empty type environment.
(define (check s-expr)
  (let ([ast (parse s-expr)])
    (typecheck ast (EmptyTypeEnv))))

;; (test/type expr type): test that `expr` type checks to exactly `type`.
(define-syntax-rule (test/type expr type) (test (check expr) type))
;; (test/notype expr): test that type checking `expr` raises an error
;; whose message contains "no type".
(define-syntax-rule (test/notype expr) (test/exn (check expr) "no type"))

;; Interpreter tests, plus the definition of the shared `lam-lam` program.
(module+ test
  (print-only-errors #t)

  (test (check `{+ 1 2}) (NumT))
  ;; `run` does not type check, so the failure comes from `lookup`
  (test/exn (run `x) "free variable")
  (test (run `{not false}) (BoolV #t))
  (test (run `10) (NumV 10))
  (test (run `{+ 10 17}) (NumV 27))
  (test (run `{- 10 7}) (NumV 3))
  (test (run `{{lam {x : num} {+ x 12}} {+ 1 17}}) (NumV 30))

  ;; calling `eval` directly with a non-empty environment
  (test (eval (Id 'x)
              (BindValue 'x (NumV 10) (EmptyValueEnv)))
        (NumV 10))

  ;; Nested program in which the inner `lam {x : num}` shadows the outer
  ;; `x`; the closure bound to `f` still sees the outer `x` (0), giving
  ;; (1 + 0) + (2 + 0) = 3.
  (define lam-lam
    `{{lam {x : num}
           {{lam {f : (num -> num)}
                 {+ {f 1}
                    {{lam {x : num}
                          {f 2}}
                     3}}}
            {lam {y : num}
                 {+ x y}}}}
      0})

  (test (run lam-lam)
        (NumV 3))
  )

;; Type checker tests, using the test/type and test/notype helpers.
(module+ test
  (print-only-errors #t)

  ;; unbound identifier has no type
  (test/notype `x)

  ;; `not` requires a bool operand
  (test/notype `{not 1})

  (test/type `{not false} (BoolT))

  ;; `lam-lam` is the program defined in the interpreter test section
  (test/type lam-lam (NumT))

  (test/type  `10 (NumT))

  (test/type `{+ 10 17} (NumT))
  (test/type `{- 10 7} (NumT))

  ;; arithmetic rejects a bool in either operand position
  (test/notype `{+ false 17})
  (test/notype `{- false 17})
  (test/notype `{+ 17 false})
  (test/notype `{- 17 false})

  (test/type `{lam {x : num} {+ x 12}}
             (ArrowT (NumT) (NumT)))

  ;; argument type does not match the declared parameter type
  (test/notype `{{lam {x : num} x} true})

  (test/type `{lam {x : num} {lam {y : bool} x}}
             (ArrowT (NumT) (ArrowT (BoolT)  (NumT))))

  (test/type `{{lam {x : num} {+ x 12}} {+ 1 17}} (NumT))

  ;; function position is not arrow-typed
  (test/notype `{1 2})

  ;; a function is not a number
  (test/notype `{+ {lam {x : num}  12} 2})

  ;; Coverage for the arrow-type branch of type-to-string: the expected
  ;; argument type here is an arrow type, so the mismatch message has to
  ;; render one.
  (test/notype `{{lam {f : {num -> num}}
                      {f 1}}
                 1})
  )