;; UNB/CS/David Bremner/teaching/cs4613/assignments/A3/skeleton.rkt
#lang plait
;; Exp: abstract syntax produced by `parse` and consumed by `interp` and
;; `typecheck`.
(define-type Exp
  [numE (n : Number)]
  [boolE (b : Boolean)]
  [notE (expr : Exp)]
  [plusE (lhs : Exp) (rhs : Exp)]
  [minusE (lhs : Exp) (rhs : Exp)]
  [timesE (lhs : Exp) (rhs : Exp)]
  [listE (elements : (Listof Exp))] ;; New: list literal, {list e ...}
  [if0E (test-expr : Exp) (then-expr : Exp) (else-expr : Exp)]
  ;; recursive binding: `name` (annotated with `ty`) is in scope in both
  ;; rhs-expr and body-expr
  [recE (name : Symbol) (ty : TE) (rhs-expr : Exp) (body-expr : Exp)]
  [idE (name : Symbol)]
  [lamE (param : Symbol) (argty : TE) (body : Exp)]
  [appE (lam-expr : Exp) (arg-expr : Exp)])

;; Value: the results produced by `interp`.
(define-type Value
  [numV (n : Number)]
  [boolV (b : Boolean)]
  [listV (elements : (Listof Value))]
  ;; a lam paired with the environment in effect when it was evaluated
  [closureV (param : Symbol)
            (body : Exp)
            (env : ValueEnv)])

;; TE: type annotations as written in the source (built by parse-te);
;; converted to Type by parse-type.
(define-type TE
  [numTE]
  [boolTE]
  [arrowTE (arg : TE) (result : TE)]
  [listTE (element : TE)] ;; New: {listof T}
  [guessTE]) ;; written `?`: the type is to be inferred

;; Type: the types manipulated by the type checker.
(define-type Type
  [numT]
  [boolT]
  [arrowT (arg : Type) (result : Type)]
  [listT (element : Type)] ;; New: type of list values
  ;; type variable: id comes from gen-tvar-id!; val holds (none) until
  ;; unification solves the variable, then (some solution)
  [varT (id : Number) (val : (Boxof (Optionof Type)))])

;; ValueEnv: run-time environments, searched innermost-first by lookup.
(define-type ValueEnv
  [EmptyValueEnv]
  [BindValue (name : Symbol)
             (value : Value)
             (rest : ValueEnv)]
  ;; binding created by rec: the box is filled in after the right-hand
  ;; side is evaluated, so the rhs can refer to the name (see interp, recE)
  [RecBindValue (name : Symbol)
                (value-box : (Boxof Value))
                (rest : ValueEnv)])

;; TypeEnv: maps identifiers to types during typecheck; searched
;; innermost-first by type-lookup.
(define-type TypeEnv
  [EmptyTypeEnv]
  [BindType (name : Symbol)
            (type : Type)
            (rest : TypeEnv)])

;; num-op : (Number Number -> Number) Symbol Value Value -> Value
;; Applies the arithmetic operator op to the numbers inside two numV
;; values and wraps the result in a numV. op-name names the operator in
;; the error raised when either operand is not a numV.
(define (num-op op op-name x y)
  (if (and (numV? x) (numV? y))
      (numV (op (numV-n x) (numV-n y)))
      (error op-name "expects two numbers")))
(define (num+ x y) (num-op + '+ x y))
(define (num- x y) (num-op - '- x y))
(define (num* x y) (num-op * '* x y))
;; numzero? : Value -> Boolean
;; Is x the number 0? (x must be a numV; used by if0.)
(define (numzero? x) (= 0 (numV-n x)))

;; interp : Exp ValueEnv -> Value
;; Evaluates a-exp in env. interp itself performs no type checks; it
;; presumably expects programs that have passed typecheck.
(define (interp a-exp env)
  (type-case Exp a-exp
    [(numE n) (numV n)]
    [(boolE b) (boolV b)]
    [(notE e) (boolV (not (boolV-b (interp e env))))]
    [(plusE l r) (num+ (interp l env) (interp r env))]
    [(minusE l r) (num- (interp l env) (interp r env))]
    [(timesE l r) (num* (interp l env) (interp r env))]
    ;; every element is evaluated in the same environment
    [(listE elements)
     (listV (map (lambda (elt) (interp elt env)) elements))]
    [(idE name) (lookup name env)]
    [(if0E test then-part else-part)
     (if (numzero? (interp test env))
         (interp then-part env)
         (interp else-part env))]
    [(recE bound-id type named-expr body-expr)
     ;; bind the name to a placeholder box first so named-expr can refer
     ;; to it, then fill the box with named-expr's value
     (let* ([value-holder (box (numV 42))]
            [new-env (RecBindValue bound-id value-holder env)])
       (begin
         (set-box! value-holder (interp named-expr new-env))
         (interp body-expr new-env)))]
    [(lamE param arg-te body-expr)
     (closureV param body-expr env)]
    [(appE lam-expr arg-expr)
     (local [(define lam-val
               (interp lam-expr env))
             (define arg-val
               (interp arg-expr env))]
       ;; static scope: extend the closure's environment, not env
       (interp (closureV-body lam-val)
             (BindValue (closureV-param lam-val)
                        arg-val
                        (closureV-env lam-val))))]))

;; run : S-Exp -> Value
;; Parses s-expr and evaluates it in the empty environment (no typecheck).
(define (run s-expr)
  (let ([ast (parse s-expr)])
    (interp ast (EmptyValueEnv))))

;; check : S-Exp -> Type
;; Parses s-expr and infers its type in the empty type environment.
(define (check s-expr)
  (let ([ast (parse s-expr)])
    (typecheck ast (EmptyTypeEnv))))

;; lookup : Symbol ValueEnv -> Value
;; Returns the value of the innermost binding of name; for a rec binding
;; that is the current contents of its box. Raises "free variable" when
;; name is unbound.
(define (lookup name env)
  (type-case ValueEnv env
    [(EmptyValueEnv) (error 'lookup "free variable")]
    [(BindValue bound-name v others)
     (cond
       [(equal? bound-name name) v]
       [else (lookup name others)])]
    [(RecBindValue bound-name cell others)
     (cond
       [(equal? bound-name name) (unbox cell)]
       [else (lookup name others)])]))

;; parse-error : S-Exp -> 'a
;; Always raises an error from 'parse that shows the offending s-expression.
(define (parse-error sx)
  (let ([msg (string-append "parse error: " (to-string sx))])
    (error 'parse msg)))

;; sx-ref : S-Exp Number -> S-Exp
;; The n-th (0-based) element of the list s-expression sx.
(define (sx-ref sx n)
  (let ([elems (s-exp->list sx)])
    (list-ref elems n)))

;; parse : S-Exp -> Exp
;; Converts concrete syntax to an Exp. Each compound form is matched with
;; its exact shape, so malformed input (wrong number of sub-forms, a
;; non-symbol operator, unknown keyword) raises a "parse error" instead of
;; failing inside list-ref or s-exp->symbol, and extra operands are no
;; longer silently ignored.
(define (parse sx)
  (local
      [(define (px i) (parse (sx-ref sx i)))]
    (cond
      [(s-exp-number? sx) (numE (s-exp->number sx))]
      [(s-exp-symbol? sx)
       (let ([sym (s-exp->symbol sx)])
         (case sym
           [(true) (boolE #t)]
           [(false) (boolE #f)]
           [else (idE sym)]))]
      [(s-exp-match? `(list ANY ...) sx)
       (listE (map parse (rest (s-exp->list sx))))]
      [(s-exp-match? `(lam (SYMBOL : ANY) ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [body (px 2)])
         (lamE id te body))]
      [(s-exp-match? `(rec (SYMBOL : ANY) ANY ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [rhs (px 2)]
              [body (px 3)])
         (recE id te rhs body))]
      ;; {not e} must be tried before the generic two-element application
      [(s-exp-match? `(not ANY) sx) (notE (px 1))]
      [(s-exp-match? `(ANY ANY) sx) (appE (px 0) (px 1))]
      [(s-exp-match? `(+ ANY ANY) sx) (plusE (px 1) (px 2))]
      [(s-exp-match? `(- ANY ANY) sx) (minusE (px 1) (px 2))]
      [(s-exp-match? `(* ANY ANY) sx) (timesE (px 1) (px 2))]
      [(s-exp-match? `(if0 ANY ANY ANY) sx) (if0E (px 1) (px 2) (px 3))]
      [else (parse-error sx)])))

;; parse-te : S-Exp -> TE
;; Parses a type annotation: num, bool, ? (type to be inferred),
;; {T1 -> T2}, or {listof T}. Any other annotation raises a "parse error"
;; (previously it fell off the end of the case/cond).
(define (parse-te sx)
  (cond
    [(s-exp-symbol? sx)
     (case (s-exp->symbol sx)
       [(num) (numTE)]
       [(bool) (boolTE)]
       [(?) (guessTE)]
       [else (parse-error sx)])]
    [(s-exp-match? `(ANY -> ANY) sx)
     (arrowTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    [(s-exp-match? `(listof ANY) sx)
     (listTE (parse-te (sx-ref sx 1)))]
    [else (parse-error sx)]))

(module+ test
  ;; Factorial of 5 as abstract syntax; fact-rec-concrete is the same
  ;; program in concrete syntax (a later parse test checks they agree).
  (define fact-rec
    (recE 'fact (arrowTE (numTE) (numTE))
         (lamE 'n (numTE)
              (if0E (idE 'n)
                   (numE 1)
                   (timesE (idE 'n) (appE (idE 'fact) (minusE (idE 'n) (numE 1))))))
         (appE (idE 'fact) (numE 5))))

  (define fact-rec-concrete
    `{rec {fact : {num -> num}}
          {lam {n : num}
               {if0 n 1
                    {* n {fact {- n 1}}}}}
          {fact 5}})

  ;; Fibonacci with fib 0 = fib 1 = 1, applied to 4; fib-rec-concrete is
  ;; the concrete-syntax version of the same program.
  (define fib-rec
    (recE 'fib (arrowTE (numTE) (numTE))
         (lamE 'x (numTE)
              (if0E (idE 'x)
                   (numE 1)
                   (if0E (minusE (idE 'x) (numE 1))
                        (numE 1)
                        (plusE (appE (idE 'fib) (minusE (idE 'x) (numE 1)))
                             (appE (idE 'fib) (minusE (idE 'x) (numE 2)))))))
         (appE (idE 'fib) (numE 4))))

  (define fib-rec-concrete
    `{rec {fib : {num -> num}}
          {lam {x : num}
               {if0 x 1
                    {if0 {- x 1}
                         1
                         {+ {fib {- x 1}}
                            {fib {- x 2}}}}}}
          {fib 4}})
  )

(module+ test
  ;; Parser tests: one per syntactic form, then malformed input, then a
  ;; larger nested program.
  (print-only-errors #t)
  (test (parse `3) (numE 3))
  (test (parse `x) (idE 'x))
  (test (parse `{+ 1 2}) (plusE (numE 1) (numE 2)))
  (test (parse `{- 1 2}) (minusE (numE 1) (numE 2)))
  (test (parse `{lam {x : num} x}) (lamE 'x (numTE) (idE 'x)))
  (test (parse `{f 2}) (appE (idE 'f) (numE 2)))
  (test (parse `{if0 0 1 2}) (if0E (numE 0) (numE 1) (numE 2)))

  ;; strings and one-element forms are not valid syntax
  (test/exn (parse `"foo") "parse error")
  (test/exn (parse `{foo}) "parse error")
  (test (parse
         `{{lam {x : num}
                {{lam {f : {num -> num}}
                      {+ {f 1}
                         {{lam {x : num}
                               {f 2}}
                          3}}}
                 {lam {y : num} {+ x y}}}}
           0})
        (appE (lamE 'x (numTE)
                   (appE (lamE 'f (arrowTE (numTE) (numTE))
                              (plusE (appE (idE 'f) (numE 1))
                                   (appE (lamE 'x (numTE)
                                              (appE (idE 'f)
                                                    (numE 2)))
                                         (numE 3))))
                         (lamE 'y (numTE)
                              (plusE (idE 'x) (idE 'y)))))
              (numE 0)))

  (test (parse fib-rec-concrete) fib-rec))

;; parse-type : TE -> Type
;; Converts a source annotation into a Type; each `?` (guessTE) becomes a
;; fresh, unsolved type variable.
(define (parse-type te)
  (type-case TE te
    [(numTE) (numT)]
    [(boolTE) (boolT)]
    [(arrowTE dom rng)
     (let* ([dom-t (parse-type dom)]
            [rng-t (parse-type rng)])
       (arrowT dom-t rng-t))]
    [(guessTE) (varT (gen-tvar-id!) (box (none)))]
    [(listTE elem-te) (listT (parse-type elem-te))]))
;; type-lookup : Symbol TypeEnv -> Type
;; Returns the type of the innermost binding of name-to-find; raises
;; "free variable, so no type" when it is unbound.
(define (type-lookup name-to-find env)
  (type-case TypeEnv env
    [(EmptyTypeEnv) (error 'type-lookup "free variable, so no type")]
    [(BindType bound-name bound-ty others)
     (cond
       [(symbol=? name-to-find bound-name) bound-ty]
       [else (type-lookup name-to-find others)])]))

;; gen-tvar-id! : -> Number
;; Returns a fresh type-variable id on each call: 1, 2, 3, ...
(define gen-tvar-id!
  (let ([next-id 0])
    (lambda ()
      (begin
        (set! next-id (+ next-id 1))
        next-id))))

;; resolve : Type -> Type
;; Follows solved type variables until reaching a type that is not a
;; solved varT. Only the outermost layer is resolved; components of
;; arrow/list types are left as they are.
(define (resolve t)
  (if (varT? t)
      (type-case (Optionof Type) (unbox (varT-val t))
        [(some solved) (resolve solved)]
        [(none) t])
      t))

;; uses-type-var? : Number Type -> Boolean
;; Does the type variable with this id appear anywhere inside t (after
;; resolving solved variables)? Recurses through arrow AND list types;
;; previously list types were skipped, so cycles through a list went
;; undetected.
(define (uses-type-var? id t)
  (type-case Type (resolve t)
    [(varT t-id val) (= id t-id)]
    [(arrowT a b)
     (or (uses-type-var? id a)
         (uses-type-var? id b))]
    [(listT elem) (uses-type-var? id elem)]
    [else #f]))

;; occurs? : Type Type -> Boolean
;; Occurs check used before binding the type variable r to t: returns #t
;; when t resolves to a compound type (arrow or list) that mentions r,
;; since binding r to it would create a cyclic type. A t that resolves to
;; a variable or a base type yields #f. Raises "not a type variable" when
;; r is not a varT.
(define (occurs? r t)
  (type-case Type r
    [(varT id val)
     (type-case Type (resolve t)
       [(arrowT a b) (uses-type-var? id t)]
       [(listT elem) (uses-type-var? id t)]
       [else #f])]
    [else (expected-type-var 'occurs? r)]))

;; type-error : 'a 'b 'c -> 'd
;; Raises a 'typecheck error with the message
;;   "no type: <exp> type <t1> vs. <t2>"
;; (plait's string-append is binary, so the pieces are joined with foldr).
(define (type-error exp t1 t2)
  (error 'typecheck
         (foldr string-append ""
                (list "no type: "
                      (to-string exp)
                      " type "
                      (to-string t1)
                      " vs. "
                      (to-string t2)))))

;; expected-type-var : Symbol 'a -> 'b
;; Raises an error attributed to `where` saying `type` is not a type variable.
(define (expected-type-var where type)
  (let ([shown (to-string type)])
    (error where (string-append "not a type variable " shown))))

;; unify-type-var! : Type Type 'a -> Void
;; Unifies the type variable T with tau2. If T is already solved, its
;; solution is unified with tau2 instead. Otherwise T is bound to the
;; resolved tau2, unless that is T itself (nothing to do) or the occurs
;; check fails (type error). expr is only used in error messages.
;; Raises "not a type variable" when T is not a varT.
(define (unify-type-var! T tau2 expr)
  (if (not (varT? T))
      (expected-type-var 'unify-type-var! T)
      (let ([cell (varT-val T)])
        (type-case (Optionof Type) (unbox cell)
          [(some solved) (unify! solved tau2 expr)]
          [(none)
           (let ([target (resolve tau2)])
             (cond
               [(equal? T target) (void)]
               [(occurs? T target) (type-error expr T target)]
               [else (set-box! cell (some target))]))]))))

;; unify-assert! : Type Type 'a -> Void
;; For variable-free base types: succeeds when tau equals type-val,
;; otherwise raises a type error mentioning expr.
(define (unify-assert! tau type-val expr)
  (if (equal? tau type-val)
      (void)
      (type-error expr tau type-val)))

;; unify! : Type Type 'a -> Void
;; Makes t1 and t2 the same type, solving type variables by mutation;
;; raises a "no type" error when they cannot be unified.
;; The third argument is only used for error reporting.
(define (unify! t1 t2 expr)
  (cond
    [(varT? t1) (unify-type-var! t1 t2 expr)]
    [(varT? t2) (unify-type-var! t2 t1 expr)]
    [else
     (type-case Type t2
       [(arrowT a2 b2)
        (if (arrowT? t1)
            (begin
              (unify! (arrowT-arg t1) a2 expr)
              (unify! (arrowT-result t1) b2 expr))
            (type-error expr t1 t2))]
       [(listT elem2)
        (if (listT? t1)
            (unify! elem2 (listT-element t1) expr)
            (type-error expr t1 t2))]
       ;; numT / boolT: no variables inside, so plain equality decides
       [else (unify-assert! t1 t2 expr)])]))

(module+ test
  ;; unification of list types: element types must agree, and a list type
  ;; never unifies with a base type
  (test (unify! (listT (boolT)) (listT (boolT)) 'test) (void))
  (test/exn (unify! (listT (boolT)) (listT (numT)) 'test) "no type")
  (test/exn (unify! (boolT) (listT (numT)) 'test) "no type")
  (test/exn (unify! (listT (numT)) (boolT) 'test) "no type"))

;; typecheck : Exp TypeEnv -> Type
;; Infers the type of exp under env, solving type variables (from `?`
;; annotations and from applications) by unification as a side effect.
;; Raises a "no type" error for ill-typed expressions and for unbound
;; identifiers. The result may be a type variable whose box holds the
;; solution rather than a concrete type.
(define (typecheck [exp : Exp] [env : TypeEnv]) : Type
  (local
      [;; arithmetic forms: both operands must be numbers; result is num
       (define (check-arith l r)
         (begin
           (unify! (typecheck l env) (numT) l)
           (unify! (typecheck r env) (numT) r)
           (numT)))]
    (type-case Exp exp
      [(numE n) (numT)]
      [(boolE b) (boolT)]
      [(notE ex) (begin
                   (unify! (typecheck ex env) (boolT) ex)
                   (boolT))]
      [(timesE l r) (check-arith l r)]
      [(plusE l r) (check-arith l r)]
      [(minusE l r) (check-arith l r)]
      [(if0E test-expr then-expr else-expr)
       (let ([test-ty (typecheck test-expr env)]
             [then-ty (typecheck then-expr env)]
             [else-ty (typecheck else-expr env)])
         (begin
           (unify! test-ty (numT) test-expr)
           (unify! then-ty else-ty else-expr)
           then-ty))]
      [(idE name) (type-lookup name env)]
      [(recE name ty rhs-expr body-expr)
       ;; the annotated name is visible while checking its own rhs
       (let* ([type-ann (parse-type ty)]
              [new-env (BindType name type-ann env)]
              [rhs-ty (typecheck rhs-expr new-env)])
         (begin
           (unify! type-ann rhs-ty rhs-expr)
           (typecheck body-expr new-env)))]
      [(appE fn arg)
       (let ([r-type (varT (gen-tvar-id!) (box (none)))]
             [a-type (typecheck arg env)]
             [fn-type (typecheck fn env)])
         (begin
           (unify! (arrowT a-type r-type) fn-type fn)
           r-type))]
      [(lamE name te body)
       (let* ([arg-type (parse-type te)]
              [res-type (typecheck body (BindType name arg-type env))])
         (arrowT arg-type res-type))]
      [(listE elements)
       ;; all elements share one element type; starting from a fresh
       ;; variable lets the empty list have an unconstrained element type
       (let ([elem-ty (varT (gen-tvar-id!) (box (none)))])
         (begin
           (for-each (lambda (elt)
                       (unify! elem-ty (typecheck elt env) elt))
                     elements)
           (listT elem-ty)))])))
;; ----------------------------------------

;; (test/type expr type): passes when the type inferred for the concrete
;; program expr unifies with type. Unification (rather than equal?) is used
;; because check may return type variables; note that unify! solves those
;; variables as a side effect. The `type` expression is evaluated more than
;; once by this expansion.
(define-syntax-rule (test/type expr type)
  (test
   (begin (unify! (check expr) type expr) type)
   type))

;; (test/notype expr): passes when checking the concrete program expr
;; raises an error whose message contains "no type".
(define-syntax-rule (test/notype expr) (test/exn (check expr) "no type"))

(module+ test
  ;; Interpreter tests (run parses and interprets without typechecking).
  (print-only-errors #t)

  (test (check `{+ 1 2}) (numT))
  (test/exn (run `x) "free variable")
  (test (run `{not false}) (boolV #t))
  (test (run `10) (numV 10))
  (test (run `{+ 10 17}) (numV 27))
  (test (run `{- 10 7}) (numV 3))
  (test (run `{{lam {x : num} {+ x 12}} {+ 1 17}}) (numV 30))

  (test (interp (idE 'x)
              (BindValue 'x (numV 10) (EmptyValueEnv)))
        (numV 10))

  ;; static scoping: f captures x = 0, so the inner {f 2} ignores the
  ;; x = 3 in scope at the call; the result is (0 + 1) + (0 + 2) = 3
  (define lam-lam
    `{{lam {x : num}
           {{lam {f : (num -> num)}
                 {+ {f 1}
                    {{lam {x : num}
                          {f 2}}
                     3}}}
            {lam {y : num}
                 {+ x y}}}}
      0})

  (test (run lam-lam)
        (numV 3))
  )

(module+ test
  ;; evaluating list literals produces listV of the element values
  (test (run `{list 1 2}) (listV (list (numV 1) (numV 2))))
  (test (run `{list false true false}) (listV (list (boolV #f) (boolV #t) (boolV #f)))))

(module+ test
  ;; Type-checker tests for the base language.
  (print-only-errors #t)

  (test/notype `x)

  (test/notype `{not 1})

  (test/type `{not false} (boolT))

  (test/type lam-lam (numT))

  (test/type  `10 (numT))

  (test/type `{+ 10 17} (numT))
  (test/type `{- 10 7} (numT))

  (test/notype `{+ false 17})
  (test/notype `{- false 17})
  (test/notype `{+ 17 false})
  (test/notype `{- 17 false})

  (test/type `{lam {x : num} {+ x 12}}
             (arrowT (numT) (numT)))

  (test/notype `{{lam {x : num} x} true})

  (test/type `{lam {x : num} {lam {y : bool} x}}
             (arrowT (numT) (arrowT (boolT)  (numT))))

  (test/type `{{lam {x : num} {+ x 12}} {+ 1 17}} (numT))

  ;; a number cannot be applied
  (test/notype `{1 2})

  (test/notype `{+ {lam {x : num}  12} 2})

  ;; passing a number where a {num -> num} function is expected is a type
  ;; error (the error message prints both mismatched types)
  (test/notype `{{lam {f : {num -> num}}
                      {f 1}}
                 1})
  )

(module+ test
  (print-only-errors #t)
  ;; Tests for if0: the test must be a number and both branches must
  ;; have the same type
  (test (run `{if0 0 1 0}) (numV 1))
  (test (run `{if0 1 1 0}) (numV 0))

  (test/type `{if0 0 1 0} (numT))
  (test/type `{if0 1 1 0} (numT))

  (test/notype `{if0 0 {lam {x : num} x} 0})
  (test/notype `{if0 {lam {x : num} x} 0 0})
  )

(module+ test
  (print-only-errors #t)

  ;; Tests for Rec
  (test (parse fact-rec-concrete) fact-rec)

  (test/type fib-rec-concrete (numT))
  (test (interp fib-rec (EmptyValueEnv)) (numV 5))

  (test/type fact-rec-concrete (numT))
  (test (interp fact-rec (EmptyValueEnv)) (numV 120))

  ;; the rhs must match the annotation
  (test/notype `{rec {x : num} {lam {y : num} 3} 10})

  ;; Contrived test to get full coverage of lookup
  (test (interp (recE 'x (numTE)
                   (numE 10)
                   (recE 'y (numTE)
                        (numE 10)
                        (idE 'x)))
              (EmptyValueEnv))
        (numV 10)))

(module+ test
  ;; Type inference for `?` annotations.

  (test/type `{{lam {x : ?} {+ x 12}} {+ 1 17}} (numT))

  ;; illustrate that the return of our typecheck function can be a bit messy
  ;; (a solved type variable rather than a bare numT)
  (define wrapped-type (check `{{lam {x : ?} {+ x 12}} {+ 1 17}}))
  (test (varT? wrapped-type) #t)
  (test (varT-val wrapped-type) (box (some (numT))))

  (test/type  `{lam {x : ?} {+ x 12}} (arrowT (numT) (numT)))

  (test/type  `{lam {x : ?} {if0 0 x x}} (arrowT (numT) (numT)))

  ;; coverage for occurs check
  (test/notype `{lam {x : ?} {x x}})
  (test
   (let ([T (varT (gen-tvar-id!) (box (none)))])
     (occurs? T (arrowT (boolT) T))) #t)
  (test/exn (occurs? (boolT) (arrowT (boolT) (numT))) "not a type variable")

  ;; coverage for unify-type-var
  (test/exn (unify-type-var! (boolT) (boolT) 'x) "not a type variable")

  ;; the inferred num -> num cannot be unified with bool -> num
  (test/exn (unify! (typecheck (lamE 'x (guessTE) (plusE (idE 'x) (numE 12)))
                               (EmptyTypeEnv))
                    (arrowT (boolT) (numT))
                    (numE -1))
            "no type")

  ;; soundness bug still exists
  #;(test/exn (typecheck (recE 'f (arrowTE (numTE) (numTE)) (idE 'f) (appE (idE 'f) (numE 10)))
                         (EmptyTypeEnv))
              "no type"))
(module+ test
  ;; Type checking and inference for list literals and {listof T}.

  ;; lists of numbers
  (test/type `{list 1 2} (listT (numT)))

  ;; report error for mixed types
  (test/notype  `{list 1 true})

  ;; infer type of list
  (test/type `{lam {x : num} {list x}} (arrowT (numT) (listT (numT))))

  ;; functions taking list parameters
  (test/type `{lam {x : {listof num}} x}
             (arrowT (listT (numT)) (listT (numT))))

  ;; report error for mixed inferred types
  (test/notype `{lam {x : num} {list true x}})

  ;; infer type of function parameter from list element
  (test/type `{lam {x : ?} {list x 1}} (arrowT (numT) (listT (numT))))

  ;; complain about cyclic type (Y-combinator) inside list
  (test/notype `{lam {x : ?} {list {x x}}})

  ;; infer type of list from function application
  (test/type `{{lam {x : ?} {list x}} 2} (listT (numT)))
)