UNB/ CS/ David Bremner/ teaching/ cs4613/ assignments/ A4/ skeleton.rkt
#lang plait
;; Surface syntax of type annotations, as produced by parse-te.
;; interp-te converts a TypeExp into the internal Type representation.
(define-type TypeExp
  [numTE]
  [boolTE]
  [arrowTE (arg : TypeExp)
           (result : TypeExp)]
  ;; {obj {field T} ...}: field names paired with their annotated types,
  ;; kept in source order.
  [objTE (fields : (Listof (Symbol * TypeExp)))])

;; Abstract syntax of expressions, as produced by parse.
(define-type Exp
  [numE (n : Number)]
  [boolE (b : Boolean)]
  [plusE (left : Exp) (right : Exp)]
  [timesE (left : Exp) (right : Exp)]
  [minusE (left : Exp) (right : Exp)]
  [leqE (left : Exp) (right : Exp)]
  ;; {lam {var : te} body}: one-argument function with an annotated parameter.
  [lamE (var : Symbol) (te : TypeExp) (body : Exp)]
  [appE (fun : Exp) (arg : Exp)]
  [varE (name : Symbol)]
  ;; {if check A B}: typecheck requires check to be bool; parse stores the
  ;; second operand (A) in `zero` and the third (B) in `non-zero`.
  ;; NOTE(review): the field names look inherited from an if0-style
  ;; language; renaming them would change the generated accessors.
  [ifE (check : Exp) (zero : Exp) (non-zero : Exp)]
  ;; {let1 {var : te} value body}: non-recursive binding; typecheck checks
  ;; value in the outer environment.
  [let1E (var : Symbol) (te : TypeExp) (value : Exp) (body : Exp)]
  ;; {rec {var : te} value body}: like let1E, but typecheck also binds var
  ;; (at type te) while checking value.
  [recE (var : Symbol) (te : TypeExp) (value : Exp) (body : Exp)]
  ;; {obj {field e} ...}: object literal; fields kept in source order.
  [objE (fields : (Listof (Symbol * Exp)))]
  ;; {msg obj selector}: per the tests, has the type of field `selector`
  ;; in obj's object type.
  [msgE (obj : Exp) (selector : Symbol)]
  )

;; Internal representation of types, produced by interp-te and typecheck.
(define-type Type
  [numT]
  [boolT]
  [arrowT (arg : Type)
           (result : Type)]
  ;; Object type: field name -> field type. Stored as a hash, so the order
  ;; in which fields were written does not affect equal? comparisons.
  [objT (fields : (Hashof Symbol Type))])

;; A type environment maps variable names to their types.
(define-type-alias TypeEnv (Hashof Symbol Type))

;; The type environment with no bindings (built with `hash`); bindings are
;; added with type-extend, which uses hash-set.
(define mt-type-env (hash empty)) ;; "empty type environment"
;; Return the type bound to s in the type environment n.
;; Raises an error tagged with s itself (message "not bound") when s has
;; no binding in n.
(define (type-lookup [s : Symbol] [n : TypeEnv]) : Type
  (type-case (Optionof Type) (hash-ref n s)
    [(none) (error s "not bound")]
    [(some b) b]))

;; Kept in the test submodule, like every other test in this file.
(module+ test
  (test/exn (type-lookup 'x mt-type-env) "not bound"))

;; Return env with s bound to t; in the result, any earlier binding of s
;; is replaced.
(define (type-extend [env : TypeEnv] [s : Symbol] [t : Type]) : TypeEnv
  (hash-set env s t))

;; Convert a parsed type annotation (TypeExp) into an internal Type.
;; Object field lists become a hash from field name to field Type.
(define (interp-te [te : TypeExp]) : Type
  (local [(define (interp-field fld)
            (pair (fst fld) (interp-te (snd fld))))]
    (type-case TypeExp te
      [(numTE) (numT)]
      [(boolTE) (boolT)]
      [(arrowTE dom rng) (arrowT (interp-te dom) (interp-te rng))]
      [(objTE flds) (objT (hash (map interp-field flds)))])))

(module+ test
  (test (interp-te (objTE
                    (list (pair 'add1 (arrowTE (numTE) (numTE)))
                          (pair 'compare (arrowTE (numTE) (boolTE))))))
        (objT (hash (list (pair 'add1 (arrowT (numT) (numT)))
                          (pair 'compare (arrowT (numT) (boolT))))))))

;; (subtype? X Y): may a value of type X be used wherever a Y is expected?
;;  - num and bool are subtypes only of themselves.
;;  - Function types: the argument is contravariant (Y's argument must be a
;;    subtype of X's argument) and the result is covariant.
;;  - Object types: X may have extra fields (width subtyping), and every
;;    field that Y requires must appear in X at a subtype of Y's field
;;    type (depth subtyping).
(define (subtype? [X : Type] [Y : Type]) : Boolean
  (type-case Type Y
    [(numT) (numT? X)]
    [(boolT) (boolT? X)]
    [(arrowT y-arg y-result)
     (type-case Type X
       [(arrowT x-arg x-result)
        (and (subtype? y-arg x-arg)
             (subtype? x-result y-result))]
       [else #f])]
    [(objT y-fields)
     (type-case Type X
       [(objT x-fields)
        (andmap (lambda (name)
                  (type-case (Optionof Type) (hash-ref x-fields name)
                    [(none) #f]
                    [(some x-field-t)
                     ;; name comes from y-fields' own keys, so this succeeds.
                     (subtype? x-field-t (some-v (hash-ref y-fields name)))]))
                (hash-keys y-fields))]
       [else #f])]))

(module+ test
  (define hello-t (objT (hash (list (pair 'hello (numT))))))
  (define hello-goodbye-t (objT (hash (list
                                       (pair 'hello (numT))
                                       (pair 'goodbye (boolT))))))
  (test (subtype? (numT) (boolT)) #f)
  (test (subtype? (numT) (numT)) #t)
  (test (subtype? (numT) hello-t) #f)
  (test (subtype? hello-t (objT (hash (list (pair 'hello (boolT)))))) #f)
  (test (subtype? hello-goodbye-t hello-t) #t)
  (test (subtype? hello-t hello-goodbye-t) #f)
  ;; Every object type is a subtype of the empty object type, not vice versa.
  (test (subtype? hello-t (objT (hash empty))) #t)
  (test (subtype? (objT (hash empty)) hello-t) #f)
  ;; Functions: contravariant in the argument, covariant in the result.
  (test (subtype? (arrowT hello-t (numT)) (arrowT hello-goodbye-t (numT))) #t)
  (test (subtype? (arrowT hello-goodbye-t (numT)) (arrowT hello-t (numT))) #f)
  (test (subtype? (arrowT (numT) hello-goodbye-t) (arrowT (numT) hello-t)) #t))

;; Compute the static type of exp under the type environment env, raising a
;; 'typecheck error if exp is ill-typed.
;; Where a value flows into an expected type (a function argument, or the
;; right-hand side of an annotated let1/rec), the check is subtype? rather
;; than exact equality. For example, an object with extra fields is accepted
;; where fewer fields are required. Arithmetic operands must be exactly num,
;; and the two branches of an if must have equal types.
(define (typecheck [exp : Exp] [env : TypeEnv]) : Type
  (local
      [;; Both l and r must be num; the whole expression then has type `type`.
       (define (num2 l r type)
         (let ([left-t (typecheck l env)]
               [right-t (typecheck r env)])
           (if (and (equal? (numT) left-t) (equal? (numT) right-t))
               type
               (error 'typecheck "expected 2 numbers"))))]
    (type-case Exp exp
      [(numE n) (numT)]
      [(boolE b) (boolT)]
      [(plusE l r) (num2 l r (numT))]
      [(minusE l r) (num2 l r (numT))]
      [(timesE l r) (num2 l r (numT))]
      [(leqE l r) (num2 l r (boolT))]
      [(varE s) (type-lookup s env)]
      [(lamE name te body)
       (let* ([arg-type (interp-te te)]
              [body-type (typecheck body (type-extend env name arg-type))])
         (arrowT arg-type body-type))]
      [(appE fn arg)
       (type-case Type (typecheck fn env)
         [(arrowT arg-type result-type)
          (let ([actual-type (typecheck arg env)])
            (if (subtype? actual-type arg-type)
                result-type
                (error 'typecheck "argument type")))]
         [else (error 'typecheck "not function")])]
      [(ifE c t f)
       (if (equal? (typecheck c env) (boolT))
           (let ([t-type (typecheck t env)]
                 [f-type (typecheck f env)])
             (if (equal? f-type t-type)
                 f-type
                 (error 'typecheck "branches must have same type")))
           (error 'typecheck "expected boolean"))]
      [(let1E var te val body)
       ;; val is checked in the outer env; only body sees var.
       (let* ([var-t (interp-te te)]
              [val-t (typecheck val env)]
              [body-t (typecheck body (type-extend env var var-t))])
         (if (subtype? val-t var-t)
             body-t
             (error 'typecheck "type does not match annotation")))]
      [(recE var te val body)
       ;; var is in scope, at its annotated type, while checking val too.
       (let* ([var-t (interp-te te)]
              [rec-env (type-extend env var var-t)]
              [val-t (typecheck val rec-env)]
              [body-t (typecheck body rec-env)])
         (if (subtype? val-t var-t)
             body-t
             (error 'typecheck "type does not match annotation")))]
      [(objE fields)
       ;; An object literal's type maps each field name to its field's type.
       (objT (hash (map (lambda (field)
                          (pair (fst field) (typecheck (snd field) env)))
                        fields)))]
      [(msgE obj selector)
       (type-case Type (typecheck obj env)
         [(objT field-types)
          (type-case (Optionof Type) (hash-ref field-types selector)
            [(none) (error 'typecheck "unknown field")]
            [(some t) t])]
         [else (error 'typecheck "expected object")])])))

;; Abort parsing with an error tagged 'parse; the message includes the
;; printed form of the offending s-expression sx.
(define (parse-error sx)
  (let ([shown (to-string sx)])
    (error 'parse (string-append "parse error: " shown))))

;; Inputs outside the surface syntax must be rejected with a 'parse error.
(module+ test
  ;; A string literal is not an expression form.
  (test/exn (parse `"strings are not in our language") "parse")
  ;; Unknown operator symbol in head position.
  (test/exn (parse `{& 1 2}) "parse")
)

;; Return element n (0-based) of the list s-expression sx.
(define (sx-ref sx n)
  (let ([elems (s-exp->list sx)])
    (list-ref elems n)))

;; Parse a type-annotation s-expression into a TypeExp:
;;   num | bool | {T -> T} | {obj {field T} ...}
;; Any other symbol or shape is rejected via parse-error, consistent with
;; parse, instead of failing with an unrelated "no matching clause" error.
(define (parse-te sx)
  (cond
    [(s-exp-symbol? sx)
     (case (s-exp->symbol sx)
       [(num) (numTE)]
       [(bool) (boolTE)]
       [else (parse-error sx)])]
    [(s-exp-match? `(ANY -> ANY) sx)
     (arrowTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    [(s-exp-match? `(obj (SYMBOL ANY) ...) sx)
     (objTE
      (map (lambda (element)
             (pair (s-exp->symbol (sx-ref element 0))
                   (parse-te (sx-ref element 1))))
           (rest (s-exp->list sx))))]
    [else (parse-error sx)]))

(module+ test
  (test (parse-te `{num -> bool}) (arrowTE (numTE) (boolTE)))
  (test/exn (parse-te `int) "parse")
  (test/exn (parse-te `{num num}) "parse"))

;; Parse a surface-syntax s-expression into an Exp.
;; Recognized forms (curly and square brackets are interchangeable):
;;   NUMBER | true | false | SYMBOL
;;   {msg E SYMBOL}           {obj {SYMBOL E} ...}
;;   {lam {SYMBOL : T} E}     {let1 {SYMBOL : T} E E}   {rec {SYMBOL : T} E E}
;;   {E E}                    (function application)
;;   {+ E E} {- E E} {* E E} {<= E E} {if E E E}
;; Every other shape, including an operator given the wrong number of
;; operands, is reported via parse-error.
(define (parse sx)
  (local
      [(define (px i) (parse (sx-ref sx i)))]
    (cond
      [(s-exp-number? sx) (numE (s-exp->number sx))]
      [(s-exp-symbol? sx)
       (let ([sym (s-exp->symbol sx)])
         (case sym
           [(true) (boolE #t)]
           [(false) (boolE #f)]
           [else (varE sym)]))]
      [(s-exp-match? `(msg ANY SYMBOL) sx)
       (msgE (px 1) (s-exp->symbol (sx-ref sx 2)))]
      [(s-exp-match? `(obj (SYMBOL ANY) ...) sx)
       (objE
        (map
         (lambda (element)
           (pair (s-exp->symbol (sx-ref element 0))
                 (parse (sx-ref element 1))))
         (rest (s-exp->list sx))))]
      [(s-exp-match? `(lam (SYMBOL : ANY) ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [body (px 2)])
         (lamE id te body))]
      [(s-exp-match? `(let1 (SYMBOL : ANY) ANY ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [rhs (px 2)]
              [body (px 3)])
         (let1E id te rhs body))]
      [(s-exp-match? `(rec (SYMBOL : ANY) ANY ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [rhs (px 2)]
              [body (px 3)])
         (recE id te rhs body))]
      ;; Checked before the fixed-arity forms below, so any two-element
      ;; list that is not a special form above is a function application.
      [(s-exp-match? `(ANY ANY) sx)
       (appE (px 0) (px 1))]
      [(s-exp-match? `(if ANY ANY ANY) sx)
       (ifE (px 1) (px 2) (px 3))]
      ;; Binary operators take exactly two operands. Extra or missing
      ;; operands, or a non-symbol head, fall through to parse-error.
      [(s-exp-match? `(SYMBOL ANY ANY) sx)
       (case (s-exp->symbol (sx-ref sx 0))
         [(+) (plusE (px 1) (px 2))]
         [(-) (minusE (px 1) (px 2))]
         [(*) (timesE (px 1) (px 2))]
         [(<=) (leqE (px 1) (px 2))]
         [else (parse-error sx)])]
      [else (parse-error sx)])))

(module+ test
  (test (parse `{obj {hello true} {goodbye 42}})
        (objE (list (pair 'hello (boolE #t))
                    (pair 'goodbye (numE 42)))))
  (test (parse `{lam {x : (obj (n-func (num -> num)))} x})
        (lamE 'x (objTE (list (pair 'n-func (arrowTE (numTE) (numTE)))))
              (varE 'x)))
  (test (parse `{+ 1 2}) (plusE (numE 1) (numE 2)))
  (test (parse `{if true 1 2}) (ifE (boolE #t) (numE 1) (numE 2)))
  ;; Wrong arity or malformed lists are parse errors, not silent drops
  ;; or list-ref / s-exp->symbol failures.
  (test/exn (parse `{+ 1 2 3}) "parse")
  (test/exn (parse `{if true 1}) "parse")
  (test/exn (parse `{if true 1 2 3}) "parse")
  (test/exn (parse `{1 2 3}) "parse")
  (test/exn (parse `{}) "parse")

  )


;; Parse the s-expression s, then typecheck the resulting Exp in the
;; empty type environment.
(define (tc [s : S-Exp]) : Type
  (let ([ast (parse s)])
    (typecheck ast mt-type-env)))

(module+ test
  (test (tc `{+ 1 2}) (numT))
  (test/exn (tc `{+ true 2}) "numbers")
  (test/exn (tc `{1 1}) "function")
  (test/exn (tc `{{lam {b : bool} false} 1}) "argument type")
  (test/exn (tc `{if false 1 true}) "branches")
  (test/exn (tc `{if 1 false true}) "boolean")
  (test/exn (tc `{let1 [x : num] true x}) "annotation")
  (test/exn (tc `{rec [x : num] true x}) "annotation"))

;; Tests for object literals, field selection (msg), object-typed
;; parameters, and a recursive object (fact).
(module+ test
  ;; An object literal with fields of several types; reused below.
  (define sampler `{obj {hello true}
                        {goodbye false}
                        {a-num 42}
                        {n-func {lam {x : num} x}}
                        {b-func {lam {x : bool} x}}
                        })
  (test (tc sampler)
        (objT (hash (list (pair 'hello (boolT))
                          (pair 'goodbye (boolT))
                          (pair 'a-num (numT))
                          (pair 'n-func (arrowT (numT) (numT)))
                          (pair 'b-func (arrowT (boolT) (boolT)))))))
  (test (tc `{msg ,sampler hello}) (boolT))
  ;; msg on a non-object, and msg naming a field the object lacks.
  (test/exn (tc `{msg 1 hello}) "object")
  (test/exn (tc `{msg ,sampler blah}) "unknown field")
  ;; A function whose parameter is an object with a single n-func field.
  (define obj-fun `{lam {x : (obj (n-func (num -> num)))} {{msg x n-func} 3}})
  (test (tc obj-fun) (arrowT
                      (objT (hash (list (pair 'n-func (arrowT (numT) (numT))))))
                      (numT)))
  (test (tc `{,obj-fun {obj {n-func {lam {x : num} x}}}}) (numT))
  (test/exn (tc `{,obj-fun 2}) "argument type")
  (test/exn (tc `{if true ,obj-fun 2}) "branches")
  ;; rec lets the object's method refer to the object itself.
  (test (tc `{rec {fact : (obj (run (num -> num)))}
                  {obj {run {lam {n : num}
                                 {if {<= n 0} 1 {* n {{msg fact run} {- n 1}}}}}}}
                  {{msg fact run} 10}})
        (numT))

  )

;; Tests that need width subtyping: an object with more fields than a
;; parameter or annotation requires must still be accepted.
;; obj-fun and sampler are defined in an earlier (module+ test ...) form;
;; all such forms share one test submodule.
(module+ test
    ;; Argument object has an extra b-func field.
    (test (tc `{,obj-fun {obj {n-func {lam {x : num} x}}
                              {b-func {lam {x : bool} x}}}}) (numT))

    ;; f's annotated parameter type requires only n-func; sampler has more.
    (test (tc `{let1 {f : {(obj (n-func (num -> num))) -> num}}
                     ,obj-fun
                     {f ,sampler}})
          (numT)))