UNB/ CS/ David Bremner/ teaching/ cs4613/ lectures/ lecture13/ tpfae.rkt
#lang plait
;; TPFAE: abstract syntax of the typed language with numbers, booleans,
;; pairs and one-argument functions.  Produced by parse, consumed by
;; eval and typecheck.
(define-type TPFAE
  [Num (n : Number)]
  [Bool (b : Boolean)]
  [Not (expr : TPFAE)]
  [Add (lhs : TPFAE)
       (rhs : TPFAE)]
  [Sub (lhs : TPFAE)
       (rhs : TPFAE)]
  [Id (name : Symbol)]
  ;; Pair construction and its two projections.
  [Pair (left : TPFAE)  (right : TPFAE)]
  [Fst (v : TPFAE)]
  [Snd (v : TPFAE)]
  ;; One-argument function; argty is the declared parameter type
  ;; (used only by typecheck, ignored by eval).
  [Lam (param : Symbol)
       (argty : TE)
       (body : TPFAE)]
  [Call (fun-expr : TPFAE)
        (arg-expr : TPFAE)])

;; TE: type expressions as written in program text (e.g. in a lam
;; annotation).  parse-type converts these to the internal Type.
(define-type TE
  [NumTE]
  [BoolTE]
  ;; (left * right)
  [PairTE (left : TE)
          (right : TE)]
  ;; (arg -> result)
  [ArrowTE (arg : TE)
           (result : TE)])

;; FAE-Value: run-time values produced by eval.
(define-type FAE-Value
  [NumV (n : Number)]
  [BoolV (b : Boolean)]
  [PairV (l : FAE-Value) (r : FAE-Value)]
  ;; A function value: its parameter, its body, and the environment in
  ;; which the lam was evaluated (eval extends this env on each call).
  [ClosureV (param : Symbol)
            (body : TPFAE)
            (env : ValueEnv)])

;; ValueEnv: run-time environment as a chain of bindings.  lookup checks
;; the outermost BindValue first, so the most recently added binding of a
;; name shadows earlier ones.
(define-type ValueEnv
  [EmptyValueEnv]
  [BindValue (name : Symbol)
             (value : FAE-Value)
             (rest : ValueEnv)])

;; Type: internal representation of types, as computed by typecheck.
(define-type Type
  [NumT]
  [BoolT]
  ;; NOTE(review): IdT is not constructed by parse-type or typecheck in
  ;; this file; presumably reserved for a later extension.
  [IdT (name : Symbol)]
  [PairT (left : Type) (right : Type)]
  [ArrowT (arg : Type)
          (result : Type)])

;; TypeEnv: static environment mapping identifiers to their types; searched
;; by type-lookup starting from the outermost (most recent) BindType.
(define-type TypeEnv
  [EmptyTypeEnv]
  [BindType (name : Symbol)
            (type : Type)
            (rest : TypeEnv)])

;; ----------------------------------------

;; eval : TPFAE ValueEnv -> FAE-Value
;; Evaluates a-fae in the value environment env.  No type checking is
;; performed here, so misuse of a value (e.g. fst on a number, calling a
;; non-function) is reported as a run-time error from 'eval.  Closures
;; capture the environment in which their lam was evaluated.
(define (eval a-fae env)
  (type-case TPFAE a-fae
    [(Num n) (NumV n)]
    [(Bool b) (BoolV b)]
    [(Not e)
     (type-case FAE-Value (eval e env)
       [(BoolV b) (BoolV (not b))]
       [else (error 'eval "not a boolean")])]
    [(Add l r) (num+ (eval l env) (eval r env))]
    [(Sub l r) (num- (eval l env) (eval r env))]
    [(Pair l r) (PairV (eval l env) (eval r env))]
    ;; Run-time errors use `error` directly: type-error would label them
    ;; as "typecheck: no type: ..." and double the word "not".
    [(Fst ex)
     (type-case FAE-Value (eval ex env)
       [(PairV l r) l]
       [else (error 'eval "not a pair")])]
    [(Snd ex)
     (type-case FAE-Value (eval ex env)
       [(PairV l r) r]
       [else (error 'eval "not a pair")])]
    [(Id name) (lookup name env)]
    [(Lam param arg-te body-expr)
     ;; The declared argument type plays no role at run time.
     (ClosureV param body-expr env)]
    [(Call fun-expr arg-expr)
     (local [(define fun-val
               (eval fun-expr env))
             (define arg-val
               (eval arg-expr env))]
       (type-case FAE-Value fun-val
         [(ClosureV param body closure-env)
          ;; Static scope: extend the closure's env, not the caller's.
          (eval body (BindValue param arg-val closure-env))]
         [else (error 'eval "not a function")]))]))

;; num-op : (Number Number -> Number) Symbol FAE-Value FAE-Value -> FAE-Value
;; Applies the numeric operator op to two number values and wraps the
;; result in NumV.  op-name identifies the operator in the error raised
;; when either argument is not a NumV (x is checked before y).
(define (num-op [op : (Number Number -> Number)]
                [op-name : Symbol]
                [x : FAE-Value]
                [y : FAE-Value]) : FAE-Value
  (type-case FAE-Value x
    [(NumV a)
     (type-case FAE-Value y
       [(NumV b) (NumV (op a b))]
       [else (error op-name "not a number")])]
    [else (error op-name "not a number")]))

;; num+ / num- : FAE-Value FAE-Value -> FAE-Value
;; Addition and subtraction on number values, via num-op.
(define (num+ v1 v2)
  (num-op + '+ v1 v2))
(define (num- v1 v2)
  (num-op - '- v1 v2))

;; lookup : Symbol ValueEnv -> FAE-Value
;; Returns the value bound to name, checking the most recent binding first.
;; Raises a "free variable" error if name is not bound in env.
(define (lookup name env)
  (type-case ValueEnv env
    [(BindValue bound-name bound-val outer)
     (cond
       [(equal? bound-name name) bound-val]
       [else (lookup name outer)])]
    [(EmptyValueEnv) (error 'lookup "free variable")]))

;; ----------------------------------------

;; type-lookup : Symbol TypeEnv -> Type
;; Returns the type bound to name-to-find, checking the most recent binding
;; first.  Raises a "free variable, so no type" error if it is unbound.
(define (type-lookup name-to-find env)
  (type-case TypeEnv env
    [(BindType bound-name bound-type outer)
     (cond
       [(equal? bound-name name-to-find) bound-type]
       [else (type-lookup name-to-find outer)])]
    [(EmptyTypeEnv) (error 'type-lookup "free variable, so no type")]))

;; ----------------------------------------

;; parse-type : TE -> Type
;; Translates a surface type expression into the type checker's Type.
(define (parse-type [te : TE]) : Type
  (type-case TE te
    [(PairTE left right)
     (PairT (parse-type left) (parse-type right))]
    [(ArrowTE dom rng)
     (ArrowT (parse-type dom) (parse-type rng))]
    [(BoolTE) (BoolT)]
    [(NumTE) (NumT)]))

;; type-error : 'a String -> 'b
;; Raises a 'typecheck error with message "no type: <fae> not <msg>",
;; where <fae> is rendered with to-string.  Never returns.
(define (type-error fae msg)
  (let* ([tail (string-append " not " msg)]
         [body (string-append (to-string fae) tail)])
    (error 'typecheck (string-append "no type: " body))))

;; type-assert : (Listof TPFAE) Type TypeEnv Type -> Type
;; Checks, left to right, that every expression in exprs has exactly the
;; given type under env, raising a type error for the first one that does
;; not.  If all of them check, returns result.
(define (type-assert exprs type env result) : Type
  (if (empty? exprs)
      result
      (let ([e (first exprs)])
        (if (equal? (typecheck e env) type)
            (type-assert (rest exprs) type env result)
            (type-error e (type-to-string type))))))

;; type-to-string : Type -> String
;; Short names for the base types; any other type falls back to its
;; to-string rendering, e.g. "(PairT (NumT) (NumT))".
(define (type-to-string [type : Type]) : String
  (cond
    [(NumT? type) "num"]
    [(BoolT? type) "bool"]
    [else (to-string type)]))

;; typecheck : TPFAE TypeEnv -> Type
;; Computes the type of fae under the static environment env.  Ill-typed
;; programs raise an error containing "no type" (from type-error or
;; type-lookup).
(define (typecheck [fae : TPFAE] [env : TypeEnv]) : Type
  (type-case TPFAE fae
    [(Num n) (NumT)]
    [(Bool b) (BoolT)]
    [(Not e) (type-assert (list e) (BoolT) env (BoolT))]
    [(Add l r) (type-assert (list l r) (NumT) env (NumT))]
    [(Sub l r) (type-assert (list l r) (NumT) env (NumT))]
    [(Id name) (type-lookup name env)]
    [(Pair l r) (PairT (typecheck l env) (typecheck r env))]
    ;; type-error already inserts " not " before its message, so pass
    ;; "a pair" to get "... not a pair" (not "... not not a pair").
    [(Fst ex)
     (type-case Type (typecheck ex env)
       [(PairT l r) l]
       [else (type-error ex "a pair")])]
    [(Snd ex)
     (type-case Type (typecheck ex env)
       [(PairT l r) r]
       [else (type-error ex "a pair")])]
    [(Lam name te body)
     ;; Check the body assuming the parameter has its declared type.
     (let* ([arg-type (parse-type te)]
            [body-type (typecheck body (BindType name arg-type env))])
       (ArrowT arg-type body-type))]
    [(Call fn arg)
     (type-case Type (typecheck fn env)
       [(ArrowT arg-type result-type)
        (type-assert (list arg) arg-type env result-type)]
       [else (type-error fn "a function")])]))

;; ----------------------------------------

;; parse-error : 'a -> 'b
;; Raises a 'parse error whose message shows the offending s-expression.
(define (parse-error sx)
  (let ([shown (to-string sx)])
    (error 'parse (string-append "parse error: " shown))))

;; sx-ref : S-Exp Number -> S-Exp
;; Returns element n (0-based) of the list s-expression sx.
(define (sx-ref sx n)
  (let ([items (s-exp->list sx)])
    (list-ref items n)))
;; parse : S-Exp -> TPFAE
;; Converts concrete syntax into TPFAE abstract syntax.  Malformed input
;; raises a "parse error" via parse-error.
(define (parse sx)
  (local
      [(define (px i) (parse (sx-ref sx i)))]
    (cond
      [(s-exp-number? sx) (Num (s-exp->number sx))]
      [(s-exp-symbol? sx)
       (let ([sym (s-exp->symbol sx)])
         (case sym
           [(true) (Bool #t)]
           [(false) (Bool #f)]
           [else (Id sym)]))]
      [(s-exp-match? `(lam (SYMBOL : ANY) ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [body (px 2)])
         (Lam id te body))]
      ;; Two-element lists: unary primitives, otherwise a function call.
      [(s-exp-match? `(ANY ANY) sx)
       (cond
         [(equal? (sx-ref sx 0) `not) (Not (px 1))]
         [(equal? (sx-ref sx 0) `fst) (Fst (px 1))]
         [(equal? (sx-ref sx 0) `snd) (Snd (px 1))]
         [else (Call (px 0) (px 1))])]
      ;; Binary primitives need exactly an operator symbol and two operands.
      ;; Matching the shape first turns a non-symbol head, an empty list, or
      ;; extra operands into parse errors instead of internal errors or
      ;; silently ignored input.
      [(s-exp-match? `(SYMBOL ANY ANY) sx)
       (case (s-exp->symbol (sx-ref sx 0))
         [(+) (Add (px 1) (px 2))]
         [(-) (Sub (px 1) (px 2))]
         [(pair) (Pair (px 1) (px 2))]
         [else (parse-error sx)])]
      [else (parse-error sx)])))

;; parse-te : S-Exp -> TE
;; Parses a type expression: num, bool, (T1 * T2) for pairs, or
;; (T1 -> T2) for functions.  Any other form raises a "parse error".
(define (parse-te sx)
  (cond
    [(s-exp-symbol? sx)
     (case (s-exp->symbol sx)
       [(num) (NumTE)]
       [(bool) (BoolTE)]
       ;; Unknown type name.
       [else (parse-error sx)])]
    [(s-exp-match? `(ANY * ANY) sx)
     (PairTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    [(s-exp-match? `(ANY -> ANY) sx)
     (ArrowTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    ;; Malformed type expression (wrong shape or a non-symbol atom).
    [else (parse-error sx)]))

;; run : S-Exp -> FAE-Value
;; Parses and evaluates a program in the empty environment.  The program
;; is not type checked first.
(define (run s-expr)
  (let ([ast (parse s-expr)])
    (eval ast (EmptyValueEnv))))

;; check : S-Exp -> Type
;; Parses a program and type checks it in the empty type environment.
(define (check s-expr)
  (let ([ast (parse s-expr)])
    (typecheck ast (EmptyTypeEnv))))

;; (test/type expr type): checks that the program expr type checks to type.
(define-syntax-rule (test/type expr type) (test (check expr) type))
;; (test/notype expr): checks that type checking expr raises an error whose
;; message contains "no type".
(define-syntax-rule (test/notype expr) (test/exn (check expr) "no type"))

;; Evaluator tests.  Note that run does not type check first.
(module+ test
  (print-only-errors #t)

  (test (check `{+ 1 2}) (NumT))
  (test/exn (run `x) "free variable")
  (test (run `{not false}) (BoolV #t))
  (test (run `10) (NumV 10))
  (test (run `{+ 10 17}) (NumV 27))
  (test (run `{- 10 7}) (NumV 3))
  (test (run `{{lam {x : num} {+ x 12}} {+ 1 17}}) (NumV 30))

  ;; eval can also be called directly with a non-empty environment.
  (test (eval (Id 'x)
              (BindValue 'x (NumV 10) (EmptyValueEnv)))
        (NumV 10))

  ;; Scoping test: f's body refers to the outer x (= 0), so the inner
  ;; binding of x to 3 must not be visible inside f.  {f 1} + {f 2} = 3.
  (define lam-lam
    `{{lam {x : num}
           {{lam {f : (num -> num)}
                 {+ {f 1}
                    {{lam {x : num}
                          {f 2}}
                     3}}}
            {lam {y : num}
                 {+ x y}}}}
      0})

  (test (run lam-lam)
        (NumV 3))
  )

;; Type checker tests.
(module+ test
  (print-only-errors #t)

  ;; Free variables have no type.
  (test/notype `x)

  (test/notype `{not 1})

  (test/type `{not false} (BoolT))

  ;; lam-lam is defined in the evaluator test module above.
  (test/type lam-lam (NumT))

  (test/type  `10 (NumT))

  (test/type `{+ 10 17} (NumT))
  (test/type `{- 10 7} (NumT))

  (test/notype `{+ false 17})
  (test/notype `{- false 17})
  (test/notype `{+ 17 false})
  (test/notype `{- 17 false})

  (test/type `{lam {x : num} {+ x 12}}
             (ArrowT (NumT) (NumT)))

  ;; Argument type does not match the declared parameter type.
  (test/notype `{{lam {x : num} x} true})

  (test/type `{lam {x : num} {lam {y : bool} x}}
             (ArrowT (NumT) (ArrowT (BoolT)  (NumT))))

  (test/type `{{lam {x : num} {+ x 12}} {+ 1 17}} (NumT))

  ;; Calling a non-function.
  (test/notype `{1 2})

  (test/notype `{+ {lam {x : num}  12} 2})

  ;; Added coverage test for type-to-string
  (test/notype `{{lam {f : {num -> num}}
                      {f 1}}
                 1})
  )

(module+ test
  ;; Coverage tests for pairs
  (define test-pair-ex1  `{pair 1 false})

  (define pair-lam `{lam {x : {num * num}} {fst x}})

  (test (check test-pair-ex1) (PairT (NumT) (BoolT)))
  (test (run test-pair-ex1) (PairV (NumV 1) (BoolV #f)))
  (test (check `{fst ,test-pair-ex1}) (NumT))
  (test (check `{snd ,test-pair-ex1}) (BoolT))
  (test (run `{fst ,test-pair-ex1}) (NumV 1))
  (test (run `{snd ,test-pair-ex1}) (BoolV #f))
  ;; Projections of non-pairs are rejected statically...
  (test/notype `{fst 1})
  (test/notype `{snd true})
  ;; ...and at run time, since run does not type check first.
  (test/exn (run `{fst {fst ,test-pair-ex1}}) "not a pair")
  (test/exn (run `{snd {snd ,test-pair-ex1}}) "not a pair")
  (test (check pair-lam) (ArrowT (PairT (NumT) (NumT)) (NumT)))
  ;; Argument type mismatch; a non-base type is printed with to-string.
  (test/exn (check `{,pair-lam 1}) "not (PairT (NumT) (NumT))")
)