#lang plait
;; Abstract syntax of the language: number and boolean literals, boolean
;; negation, addition/subtraction, identifiers, pairs with fst/snd
;; projections, single-parameter functions whose parameter carries a type
;; annotation (TE), and function application.
(define-type TPFAE
  [Num (n : Number)]
  [Bool (b : Boolean)]
  [Not (expr : TPFAE)]
  [Add (lhs : TPFAE) (rhs : TPFAE)]
  [Sub (lhs : TPFAE) (rhs : TPFAE)]
  [Id (name : Symbol)]
  [Pair (left : TPFAE) (right : TPFAE)]
  [Fst (v : TPFAE)]
  [Snd (v : TPFAE)]
  [Lam (param : Symbol) (argty : TE) (body : TPFAE)]
  [Call (fun-expr : TPFAE) (arg-expr : TPFAE)])
;; Type expressions as written in source annotations (produced by parse-te);
;; parse-type translates them into Type.
(define-type TE
  [NumTE]
  [BoolTE]
  [PairTE (left : TE) (right : TE)]
  [ArrowTE (arg : TE) (result : TE)])
;; Runtime values produced by eval. A ClosureV records a lam's parameter and
;; body together with the environment in effect when the lam was evaluated.
(define-type FAE-Value
  [NumV (n : Number)]
  [BoolV (b : Boolean)]
  [PairV (l : FAE-Value) (r : FAE-Value)]
  [ClosureV (param : Symbol) (body : TPFAE) (env : ValueEnv)])
;; Runtime environment: a linked chain of name/value bindings. lookup
;; searches it front to back, so the most recently added binding wins.
(define-type ValueEnv
  [EmptyValueEnv]
  [BindValue (name : Symbol) (value : FAE-Value) (rest : ValueEnv)])
;; Types computed by typecheck. Note: IdT is not constructed by parse-type
;; or typecheck.
(define-type Type
  [NumT]
  [BoolT]
  [IdT (name : Symbol)]
  [PairT (left : Type) (right : Type)]
  [ArrowT (arg : Type) (result : Type)])
;; Static environment: a linked chain of name/type bindings consulted by
;; type-lookup, most recent binding first.
(define-type TypeEnv
  [EmptyTypeEnv]
  [BindType (name : Symbol) (type : Type) (rest : TypeEnv)])
;; ----------------------------------------
;; eval : TPFAE ValueEnv -> FAE-Value
;; Evaluates `a-fae` under the runtime environment `env`.
;; - Application is call-by-value: both the function and the argument are
;;   evaluated before the body runs.
;; - Scope is static: a lam captures the environment in which it is
;;   evaluated, and a call extends that captured environment, not the
;;   caller's.
;; - The lam's type annotation is ignored here; only typecheck uses it.
;; `run` evaluates without typechecking, so ill-typed programs can reach
;; the runtime errors below ("not a pair", "not a function").
(define (eval a-fae env)
  (type-case TPFAE a-fae
    [(Num n) (NumV n)]
    [(Bool b) (BoolV b)]
    [(Not e) (BoolV (not (BoolV-b (eval e env))))]
    [(Add l r) (num+ (eval l env) (eval r env))]
    [(Sub l r) (num- (eval l env) (eval r env))]
    [(Id name) (lookup name env)]
    [(Pair l r) (PairV (eval l env) (eval r env))]
    [(Fst ex)
     (type-case FAE-Value (eval ex env)
       [(PairV l r) l]
       ;; A runtime failure, not a type error, so report it as such.
       [else (error 'eval "not a pair")])]
    [(Snd ex)
     (type-case FAE-Value (eval ex env)
       [(PairV l r) r]
       [else (error 'eval "not a pair")])]
    [(Lam param arg-te body-expr)
     (ClosureV param body-expr env)]
    [(Call fun-expr arg-expr)
     (local [(define fun-val (eval fun-expr env))
             (define arg-val (eval arg-expr env))]
       (type-case FAE-Value fun-val
         [(ClosureV param body closure-env)
          (eval body (BindValue param arg-val closure-env))]
         [else (error 'eval "not a function")]))]))
;; num-op : (Number Number -> Number) Symbol FAE-Value FAE-Value -> FAE-Value
;; Applies the arithmetic function `op` to the numbers inside two NumV
;; values and wraps the result in NumV. `op-name` names the operator in the
;; error raised when an operand is not a NumV; this can happen because `run`
;; evaluates without typechecking first.
(define (num-op op op-name x y)
  (local [(define (as-number v)
            (type-case FAE-Value v
              [(NumV n) n]
              [else (error 'eval
                           (string-append "not a number: operand of "
                                          (symbol->string op-name)))]))]
    (NumV (op (as-number x) (as-number y)))))
;; num+ / num- : FAE-Value FAE-Value -> FAE-Value
(define (num+ x y) (num-op + '+ x y))
(define (num- x y) (num-op - '- x y))
;; lookup : Symbol ValueEnv -> FAE-Value
;; Returns the value of the front-most binding of `name` in `env`; raises a
;; "free variable" error when `name` is not bound.
(define (lookup name env)
  (cond
    [(EmptyValueEnv? env) (error 'lookup "free variable")]
    [(equal? (BindValue-name env) name) (BindValue-value env)]
    [else (lookup name (BindValue-rest env))]))
;; ----------------------------------------
;; type-lookup : Symbol TypeEnv -> Type
;; Returns the type of the front-most binding of `name-to-find` in `env`;
;; raises an error mentioning "no type" when the name is unbound.
(define (type-lookup name-to-find env)
  (type-case TypeEnv env
    [(EmptyTypeEnv)
     (error 'type-lookup "free variable, so no type")]
    [(BindType bound-name bound-type older)
     (if (not (equal? name-to-find bound-name))
         (type-lookup name-to-find older)
         bound-type)]))
;; ----------------------------------------
;; parse-type : TE -> Type
;; Structurally translates a source type annotation into a Type.
(define (parse-type te)
  (type-case TE te
    [(NumTE) (NumT)]
    [(BoolTE) (BoolT)]
    [(PairTE left-te right-te)
     (PairT (parse-type left-te) (parse-type right-te))]
    [(ArrowTE domain-te range-te)
     (ArrowT (parse-type domain-te) (parse-type range-te))]))
;; type-error : 'a String -> 'b
;; Raises a 'typecheck error with message "no type: <fae> not <msg>", where
;; <fae> is rendered with to-string. `msg` should describe the expected
;; type, since " not " is inserted here.
(define (type-error fae msg)
  (let* ([subject (to-string fae)]
         [tail (string-append " not " msg)])
    (error 'typecheck
           (string-append "no type: " (string-append subject tail)))))
;; type-assert : (Listof TPFAE) Type TypeEnv Type -> Type
;; Typechecks each expression in `exprs` left to right under `env`. The
;; first one whose type differs from `type` triggers a type-error naming
;; that expression; if every one matches, returns `result`.
(define (type-assert exprs type env result) : Type
  (if (empty? exprs)
      result
      (let ([head (first exprs)])
        (if (equal? (typecheck head env) type)
            (type-assert (rest exprs) type env result)
            (type-error head (type-to-string type))))))
;; type-to-string : Type -> String
;; Short names for the base types; any other type is rendered with
;; to-string (e.g. "(PairT (NumT) (NumT))").
(define (type-to-string [type : Type])
  (cond
    [(NumT? type) "num"]
    [(BoolT? type) "bool"]
    [else (to-string type)]))
;; typecheck : TPFAE TypeEnv -> Type
;; Computes the static type of `fae`, resolving identifiers through `env`.
;; Ill-typed expressions are reported through type-error, whose message has
;; the form "no type: <expr> not <expected>"; unbound identifiers are
;; reported by type-lookup.
(define (typecheck [fae : TPFAE] [env : TypeEnv]) : Type
  (type-case TPFAE fae
    [(Num n) (NumT)]
    [(Bool b) (BoolT)]
    [(Not e) (type-assert (list e) (BoolT) env (BoolT))]
    [(Add l r) (type-assert (list l r) (NumT) env (NumT))]
    [(Sub l r) (type-assert (list l r) (NumT) env (NumT))]
    [(Id name) (type-lookup name env)]
    [(Pair l r) (PairT (typecheck l env) (typecheck r env))]
    [(Fst ex)
     (type-case Type (typecheck ex env)
       [(PairT l r) l]
       ;; type-error already inserts " not ", so pass only the expectation
       ;; (previously this produced "... not not a pair").
       [else (type-error ex "a pair")])]
    [(Snd ex)
     (type-case Type (typecheck ex env)
       [(PairT l r) r]
       [else (type-error ex "a pair")])]
    [(Lam name te body)
     ;; The body is checked with the parameter bound to its annotated type.
     (let* ([arg-type (parse-type te)]
            [body-type (typecheck body (BindType name arg-type env))])
       (ArrowT arg-type body-type))]
    [(Call fn arg)
     (type-case Type (typecheck fn env)
       [(ArrowT arg-type result-type)
        ;; The argument must have exactly the declared parameter type.
        (type-assert (list arg) arg-type env result-type)]
       [else (type-error fn "function")])]))
;; ----------------------------------------
;; parse-error : S-Exp -> 'a
;; Raises a 'parse error that includes the offending s-expression.
(define (parse-error sx)
  (let ([rendered (to-string sx)])
    (error 'parse (string-append "parse error: " rendered))))
;; sx-ref : S-Exp Number -> S-Exp
;; The n-th (0-based) element of a list s-expression.
(define (sx-ref sx n)
  (let ([items (s-exp->list sx)])
    (list-ref items n)))
;; parse : S-Exp -> TPFAE
;; Converts concrete syntax into a TPFAE. Recognized forms:
;;   NUMBER, true, false, SYMBOL
;;   {lam {SYMBOL : type} body}
;;   {not e}  {fst e}  {snd e}
;;   {+ e e}  {- e e}  {pair e e}
;;   {e e}                          ; application
;; Each form is matched by its complete shape. Wrong arities (e.g. {+ 1 2 3})
;; and non-symbol heads (e.g. {1 2 3}) are reported via parse-error instead
;; of being silently truncated or crashing in s-exp->symbol / list-ref.
(define (parse sx)
  (local
    [(define (px i) (parse (sx-ref sx i)))]
    (cond
      [(s-exp-number? sx) (Num (s-exp->number sx))]
      [(s-exp-symbol? sx)
       (let ([sym (s-exp->symbol sx)])
         (case sym
           [(true) (Bool #t)]
           [(false) (Bool #f)]
           [else (Id sym)]))]
      [(s-exp-match? `(lam (SYMBOL : ANY) ANY) sx)
       (let* ([args (sx-ref sx 1)]
              [id (s-exp->symbol (sx-ref args 0))]
              [te (parse-te (sx-ref args 2))]
              [body (px 2)])
         (Lam id te body))]
      [(s-exp-match? `(not ANY) sx) (Not (px 1))]
      [(s-exp-match? `(fst ANY) sx) (Fst (px 1))]
      [(s-exp-match? `(snd ANY) sx) (Snd (px 1))]
      [(s-exp-match? `(+ ANY ANY) sx) (Add (px 1) (px 2))]
      [(s-exp-match? `(- ANY ANY) sx) (Sub (px 1) (px 2))]
      [(s-exp-match? `(pair ANY ANY) sx) (Pair (px 1) (px 2))]
      ;; Any other two-element form is a function application.
      [(s-exp-match? `(ANY ANY) sx) (Call (px 0) (px 1))]
      [else (parse-error sx)])))
;; parse-te : S-Exp -> TE
;; Parses a type annotation: `num`, `bool`, `(t1 * t2)` for pairs and
;; `(t1 -> t2)` for functions. Unknown type names and unrecognized shapes
;; are reported via parse-error rather than falling off the end of the
;; case/cond.
(define (parse-te sx)
  (cond
    [(s-exp-symbol? sx)
     (case (s-exp->symbol sx)
       [(num) (NumTE)]
       [(bool) (BoolTE)]
       [else (parse-error sx)])]
    [(s-exp-match? `(ANY * ANY) sx)
     (PairTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    [(s-exp-match? `(ANY -> ANY) sx)
     (ArrowTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    [else (parse-error sx)]))
;; run : S-Exp -> FAE-Value
;; Parses and evaluates a program in the empty environment (no typecheck).
(define (run s-expr)
  (let ([program (parse s-expr)])
    (eval program (EmptyValueEnv))))
;; check : S-Exp -> Type
;; Parses a program and typechecks it in the empty type environment.
(define (check s-expr)
  (let ([program (parse s-expr)])
    (typecheck program (EmptyTypeEnv))))
;; test/type: asserts that `program` typechecks to exactly `expected-type`.
;; test/notype: asserts that typechecking `program` raises an error whose
;; message contains "no type".
(define-syntax-rule (test/type program expected-type)
  (test (check program) expected-type))
(define-syntax-rule (test/notype program)
  (test/exn (check program) "no type"))
(module+ test
  (print-only-errors #t)

  ;; Typechecker smoke test.
  (test (check `{+ 1 2}) (NumT))

  ;; Evaluation: unbound names, literals, negation, arithmetic.
  (test/exn (run `x) "free variable")
  (test (run `{not false}) (BoolV #t))
  (test (run `10) (NumV 10))
  (test (run `{+ 10 17}) (NumV 27))
  (test (run `{- 10 7}) (NumV 3))

  ;; Evaluation: application, and eval against a hand-built environment.
  (test (run `{{lam {x : num} {+ x 12}} {+ 1 17}}) (NumV 30))
  (test (eval (Id 'x) (BindValue 'x (NumV 10) (EmptyValueEnv)))
        (NumV 10))

  ;; Static scope: f closes over the outer x = 0, so rebinding x to 3
  ;; around {f 2} has no effect and the result is {f 1} + {f 2} = 1 + 2.
  (define lam-lam
    `{{lam {x : num}
        {{lam {f : (num -> num)}
           {+ {f 1}
              {{lam {x : num} {f 2}}
               3}}}
         {lam {y : num} {+ x y}}}}
      0})
  (test (run lam-lam) (NumV 3)))
(module+ test
  (print-only-errors #t)

  ;; Unbound identifiers and negation.
  (test/notype `x)
  (test/notype `{not 1})
  (test/type `{not false} (BoolT))

  ;; Static-scope example defined in the evaluation tests.
  (test/type lam-lam (NumT))

  ;; Literals and arithmetic, including ill-typed operands on either side.
  (test/type `10 (NumT))
  (test/type `{+ 10 17} (NumT))
  (test/type `{- 10 7} (NumT))
  (test/notype `{+ false 17})
  (test/notype `{- false 17})
  (test/notype `{+ 17 false})
  (test/notype `{- 17 false})

  ;; Functions and application.
  (test/type `{lam {x : num} {+ x 12}}
             (ArrowT (NumT) (NumT)))
  (test/notype `{{lam {x : num} x} true})
  (test/type `{lam {x : num} {lam {y : bool} x}}
             (ArrowT (NumT) (ArrowT (BoolT) (NumT))))
  (test/type `{{lam {x : num} {+ x 12}} {+ 1 17}} (NumT))
  (test/notype `{1 2})
  (test/notype `{+ {lam {x : num} 12} 2})

  ;; Passing a num where a (num -> num) is expected exercises the
  ;; non-base-type branch of type-to-string.
  (test/notype `{{lam {f : {num -> num}} {f 1}}
                 1}))
(module+ test
  ;; Pairs: construction, projection, and their typing.
  (define test-pair-ex1 `{pair 1 false})
  (define pair-lam `{lam {x : {num * num}} {fst x}})
  (test (check test-pair-ex1) (PairT (NumT) (BoolT)))
  (test (run test-pair-ex1) (PairV (NumV 1) (BoolV #f)))
  (test (check `{fst ,test-pair-ex1}) (NumT))
  (test (check `{snd ,test-pair-ex1}) (BoolT))
  (test (run `{fst ,test-pair-ex1}) (NumV 1))
  (test (run `{snd ,test-pair-ex1}) (BoolV #f))
  ;; Projecting from a non-pair fails at run time (run does not typecheck)...
  (test/exn (run `{fst {fst ,test-pair-ex1}}) "not a pair")
  (test/exn (run `{snd {snd ,test-pair-ex1}}) "not a pair")
  ;; ...and is rejected by the typechecker.
  (test/notype `{fst 1})
  (test/notype `{snd true})
  ;; Functions over pairs: annotation typing, rejection, and a valid call.
  (test (check pair-lam) (ArrowT (PairT (NumT) (NumT)) (NumT)))
  (test/exn (check `{,pair-lam 1}) "not (PairT (NumT) (NumT))")
  (test/type `{,pair-lam {pair 3 4}} (NumT))
  (test (run `{,pair-lam {pair 3 4}}) (NumV 3)))