;; UNB/ CS/ David Bremner/ teaching/ cs3613/ lectures/ trcfae-t.rkt
#lang plai-typed

;; FAE: abstract syntax of the typed FAE language.
;; Type annotations (TE) appear on fun parameters and rec bindings;
;; interp ignores them, while typecheck converts them with parse-type.
(define-type FAE
  [num (n : number)]
  [add (lhs : FAE)
       (rhs : FAE)]
  [sub (lhs : FAE)
       (rhs : FAE)]
  [id (name : symbol)]
  ;; single-parameter function; argty is the declared parameter type
  [fun (param : symbol)
       (argty : TE)
       (body : FAE)]
  [call (fun-expr : FAE)
       (arg-expr : FAE)]
  ;; evaluates then-expr when test-expr evaluates to 0, otherwise else-expr
  [if0 (test-expr : FAE)
       (then-expr : FAE)
       (else-expr : FAE)]
  ;; recursive binding: name (declared to have type ty) is in scope in
  ;; both rhs-expr and body-expr
  [rec (name : symbol)
    (ty : TE)
    (rhs-expr : FAE)
    (body-expr : FAE)])

;; TE: type annotations as written inside FAE programs.
;; parse-type translates a TE into the checker's Type.
(define-type TE
  [numTE]
  [boolTE]
  ;; function type: argument type -> result type
  [arrowTE (arg : TE)
           (result : TE)])

;; FAE-Value: results produced by interp.
(define-type FAE-Value
  [numV (n : number)]
  ;; a closure records the function's parameter, its body, and the
  ;; DefrdSub in effect where the fun expression was evaluated
  [closureV (param : symbol)
            (body : FAE)
            (ds : DefrdSub)])

;; DefrdSub: the runtime environment (deferred substitutions) that
;; lookup searches, innermost binding first.
(define-type DefrdSub
  [mtSub]
  [aSub (name : symbol)
        (value : FAE-Value)
        (rest : DefrdSub)]
  ;; binding introduced by rec: interp creates the box holding a placeholder,
  ;; evaluates the rhs in an environment containing this binding, then
  ;; set-box!es the result in, so closures built by the rhs can refer to
  ;; themselves; lookup returns the box's current contents
  [aRecSub (name : symbol)
           (value-box : (boxof FAE-Value))
           (rest : DefrdSub)])

;; Type: types computed by typecheck (the checker's counterpart of TE).
(define-type Type
  [numT]
  [boolT]
  ;; function type: argument type -> result type
  [arrowT (arg : Type)
          (result : Type)])

;; TypeEnv: maps identifiers to their Types during typecheck;
;; searched innermost-first by get-type.
(define-type TypeEnv
  [mtEnv]
  [aBind (name : symbol)
         (type : Type)
         (rest : TypeEnv)])

;; ----------------------------------------

;; interp : FAE DefrdSub -> FAE-Value
;; Evaluates a-fae using the bindings in ds.  Type annotations (the TE on
;; fun and rec) are ignored here; only typecheck looks at them.
(define (interp a-fae ds)
  (type-case FAE a-fae
    [num (n) (numV n)]
    [add (l r) (num+ (interp l ds) (interp r ds))]
    [sub (l r) (num- (interp l ds) (interp r ds))]
    [id (name) (lookup name ds)]
    [fun (param arg-te body-expr)
         (closureV param body-expr ds)]
    [call (fun-expr arg-expr)
          (local [(define fun-val
                    (interp fun-expr ds))
                  (define arg-val
                    (interp arg-expr ds))]
            ;; Dispatch on fun-val explicitly so that calling a non-function
            ;; reports a clear error instead of failing inside a closureV
            ;; accessor.  Operator and operand are still evaluated first.
            (type-case FAE-Value fun-val
              [closureV (param body closure-ds)
                        (interp body
                                (aSub param
                                      arg-val
                                      closure-ds))]
              [else (error 'interp "call: not a function")]))]
    [if0 (test-expr then-expr else-expr)
         (if (numzero? (interp test-expr ds))
             (interp then-expr ds)
             (interp else-expr ds))]
    [rec (bound-id type named-expr body-expr)
      ;; value-holder starts with a dummy (numV 42) and is overwritten with
      ;; the real value once named-expr has been evaluated in new-ds, so
      ;; closures created by named-expr see themselves through the box.
      ;; NOTE(review): if named-expr reads bound-id directly (not under a
      ;; fun) it observes the dummy, e.g. (rec 'a (numTE) (id 'a) (id 'a))
      ;; evaluates to (numV 42).
      (local [(define value-holder (box (numV 42)))
              (define new-ds (aRecSub bound-id
                                      value-holder
                                      ds))]
        (begin
          (set-box! value-holder (interp named-expr new-ds))
          (interp body-expr new-ds)))]))


;; num-op : (number number -> number) symbol FAE-Value FAE-Value -> FAE-Value
;; Applies op to the numbers carried by two numV values and wraps the
;; result in numV.  op-name labels the error raised when an operand is
;; not a numV (e.g. a closure), instead of failing inside numV-n.
(define (num-op op op-name x y)
  (type-case FAE-Value x
    [numV (x-n)
          (type-case FAE-Value y
            [numV (y-n) (numV (op x-n y-n))]
            [else (error op-name "right operand is not a number")])]
    [else (error op-name "left operand is not a number")]))

;; num+ : FAE-Value FAE-Value -> FAE-Value -- sum of two numV values
(define num+
  (lambda (lhs rhs)
    (num-op + '+ lhs rhs)))
;; num- : FAE-Value FAE-Value -> FAE-Value -- difference of two numV values
(define num-
  (lambda (lhs rhs)
    (num-op - '- lhs rhs)))

;; numzero? : FAE-Value -> boolean
;; True when x is the number 0 (used by if0).  A non-number test value
;; raises a descriptive error instead of failing inside numV-n.
(define (numzero? x)
  (type-case FAE-Value x
    [numV (n) (= 0 n)]
    [else (error 'numzero? "if0 test is not a number")]))

;; lookup : symbol DefrdSub -> FAE-Value
;; Returns the value bound to name in ds, innermost binding first.  A rec
;; binding yields the current contents of its box.
(define (lookup name ds)
  (type-case DefrdSub ds
    [aSub (bound val more)
          (cond [(symbol=? name bound) val]
                [else (lookup name more)])]
    [aRecSub (bound holder more)
             (cond [(symbol=? name bound) (unbox holder)]
                   [else (lookup name more)])]
    [mtSub () (error 'lookup "free variable")]))


;; ----------------------------------------

;; get-type : symbol TypeEnv -> Type
;; Finds the type bound to name-to-find in env, innermost binding first;
;; an unbound name is an error.
(define (get-type name-to-find env)
  (type-case TypeEnv env
    [aBind (bound bound-ty outer)
           (cond [(symbol=? bound name-to-find) bound-ty]
                 [else (get-type name-to-find outer)])]
    [mtEnv () (error 'get-type "free variable, so no type")]))

;; ----------------------------------------

;; parse-type : TE -> Type
;; Translates a type annotation written in a program into the checker's Type.
(define (parse-type te)
  (type-case TE te
    [arrowTE (dom rng)
             (local [(define dom-t (parse-type dom))
                     (define rng-t (parse-type rng))]
               (arrowT dom-t rng-t))]
    [numTE () (numT)]
    [boolTE () (boolT)]))

;; type-error : FAE string -> 'a
;; Raises a typecheck error reading "no type: <fae> not <msg>", where msg
;; describes what fae was expected to be.  Never returns.
(define (type-error fae msg)
  (local [(define expected (string-append " not " msg))
          (define culprit (string-append (to-string fae) expected))]
    (error 'typecheck (string-append "no type: " culprit))))

;; typecheck-num-operands : FAE FAE TypeEnv -> Type
;; Shared rule for add and sub: both operands must have type numT, and the
;; result is numT.  The left operand is checked (and reported) first.
(define (typecheck-num-operands l r env)
  (type-case Type (typecheck l env)
    [numT ()
          (type-case Type (typecheck r env)
            [numT () (numT)]
            [else (type-error r "num")])]
    [else (type-error l "num")]))

;; typecheck : FAE TypeEnv -> Type
;; Computes the type of fae under env.  Ill-typed programs raise a
;; "no type: ..." error through type-error; a free identifier raises from
;; get-type.
(define typecheck : (FAE TypeEnv -> Type)
  (lambda (fae env)
    (type-case FAE fae
      [num (n) (numT)]
      [add (l r) (typecheck-num-operands l r env)]
      [sub (l r) (typecheck-num-operands l r env)]
      [id (name) (get-type name env)]
      [fun (name te body)
           (local [(define arg-type (parse-type te))]
             (arrowT arg-type
                     (typecheck body (aBind name
                                            arg-type
                                            env))))]
      [call (fn arg)
            (type-case Type (typecheck fn env)
              [arrowT (arg-type result-type)
                      ;; the actual argument must match the declared type exactly
                      (if (equal? arg-type
                                  (typecheck arg env))
                          result-type
                          (type-error arg
                                      (to-string arg-type)))]
              [else (type-error fn "function")])]
      [if0 (test-expr then-expr else-expr)
           (type-case Type (typecheck test-expr env)
             ;; both branches must agree; that common type is the result
             [numT () (local [(define then-ty (typecheck then-expr env))]
                        (if (equal? then-ty (typecheck else-expr env))
                            then-ty
                            (type-error else-expr
                                        (to-string then-ty))))]
             [else (type-error test-expr "num")])]
      [rec (name ty rhs-expr body-expr)
        ;; name is bound to its declared type while checking rhs-expr as
        ;; well as body-expr, so recursive references type-check
        (local [(define rhs-ty (parse-type ty))
                (define rec-env (aBind name
                                       rhs-ty
                                       env))]
          (if (equal? rhs-ty (typecheck rhs-expr rec-env))
              (typecheck body-expr rec-env)
              (type-error rhs-expr (to-string rhs-ty))))])))


;; ----------------------------------------

(test (interp (num 10)
              (mtSub))
      (numV 10))
(test (interp (add (num 10) (num 17))
              (mtSub))
      (numV 27))
(test (interp (sub (num 10) (num 7))
              (mtSub))
      (numV 3))
(test (interp (call (fun 'x (numTE) (add (id 'x) (num 12)))
                   (add (num 1) (num 17)))
              (mtSub))
      (numV 30))
(test (interp (id 'x)
              (aSub 'x (numV 10) (mtSub)))
      (numV 10))

(test (interp (call (fun 'x (numTE)
                        (call (fun 'f (arrowTE (numTE) (numTE))
                                  (add (call (id 'f) (num 1))
                                       (call (fun 'x (numTE)
                                                 (call (id 'f)
                                                      (num 2)))
                                            (num 3))))
                             (fun 'y (numTE)
                                  (add (id 'x) (id 'y)))))
                   (num 0))
              (mtSub))
      (numV 3))

(test (interp (if0 (num 0) (num 1) (num 2))
              (mtSub))
      (numV 1))
(test (interp (if0 (num 1) (num 1) (num 2))
              (mtSub))
      (numV 2))
(test (interp (rec 'a (numTE)
                (num 10)
                (add (id 'a) (num 1)))
              (mtSub))
      (numV 11))
(test (interp (rec 'fib (arrowTE (numTE) (numTE))
                (fun 'x (numTE)
                     (if0 (id 'x)
                          (num 1)
                          (if0 (sub (id 'x) (num 1))
                               (num 1)
                               (add (call (id 'fib) (sub (id 'x) (num 1)))
                                    (call (id 'fib) (sub (id 'x) (num 2)))))))
                (call (id 'fib) (num 4)))
              (mtSub))
      (numV 5))


;; test/exn takes the expression itself (not a thunk): it evaluates the
;; expression and checks that it raises an error mentioning the string.
(test/exn (interp (id 'x) (mtSub))
          "free variable")

(test (typecheck (num 10) (mtEnv))
      (numT))

(test (typecheck (add (num 10) (num 17)) (mtEnv))
      (numT))
(test (typecheck (sub (num 10) (num 7)) (mtEnv))
      (numT))

(test (typecheck (fun 'x (numTE) (add (id 'x) (num 12))) (mtEnv))
      (arrowT (numT) (numT)))

(test (typecheck (fun 'x (numTE) (fun 'y (boolTE) (id 'x))) (mtEnv))
      (arrowT (numT) (arrowT (boolT)  (numT))))

(test (typecheck (call (fun 'x (numTE) (add (id 'x) (num 12)))
                      (add (num 1) (num 17)))
                 (mtEnv))
      (numT))

(test (typecheck (call (fun 'x (numTE)
                           (call (fun 'f (arrowTE (numTE) (numTE))
                                     (add (call (id 'f) (num 1))
                                          (call (fun 'x (numTE) (call (id 'f) (num 2)))
                                               (num 3))))
                                (fun 'y (numTE)
                                     (add (id 'x)
                                          (id 'y)))))
                      (num 0))
                 (mtEnv))
      (numT))

(test (typecheck (if0 (num 0) (num 1) (num 2))
                 (mtEnv))
      (numT))
(test (typecheck (if0 (num 0) 
                      (fun 'x (numTE) (id 'x))
                      (fun 'y (numTE) (num 3)))
                 (mtEnv))
      (arrowT (numT) (numT)))
(test (typecheck (rec 'a (numTE)
                   (num 10)
                   (add (id 'a) (num 1)))
                 (mtEnv))
      (numT))
(test (typecheck (rec 'fib (arrowTE (numTE) (numTE))
                   (fun 'x (numTE)
                        (if0 (id 'x)
                             (num 1)
                             (if0 (sub (id 'x) (num 1))
                                  (num 1)
                                  (add (call (id 'fib) (sub (id 'x) (num 1)))
                                       (call (id 'fib) (sub (id 'x) (num 2)))))))
                   (call (id 'fib) (num 4)))
                 (mtEnv))
      (numT))


(test/exn (typecheck (call (num 1) (num 2)) (mtEnv))
          "no type")

(test/exn (typecheck (add (fun 'x (numTE) (num 12))
                          (num 2))
                     (mtEnv))
          "no type")
(test/exn (typecheck (if0 (num 0) 
                          (num 7)
                          (fun 'y (numTE) (num 3)))
                     (mtEnv))
          "no type")
(test/exn (typecheck (rec 'x (numTE)
                       (fun 'y (numTE) (num 3))
                       (num 10))
                     (mtEnv))
          "no type")
(test/exn (typecheck (rec 'x (arrowTE (numTE) (numTE))
                       (fun 'y (numTE) (num 3))
                       (add (num 1) (id 'x)))
                     (mtEnv))
          "no type")

;; error branches not covered above
(test/exn (typecheck (id 'x) (mtEnv))
          "free variable")
(test/exn (typecheck (sub (num 1) (fun 'x (numTE) (id 'x)))
                     (mtEnv))
          "no type")
(test/exn (typecheck (call (fun 'x (numTE) (id 'x))
                           (fun 'y (numTE) (id 'y)))
                     (mtEnv))
          "no type")
(test/exn (typecheck (if0 (fun 'x (numTE) (id 'x))
                          (num 1)
                          (num 2))
                     (mtEnv))
          "no type")