UNB/ CS/ David Bremner/ teaching/ cs4613/ lectures/ lecture20/ mark-sweep.rkt
#lang plai/gc2/collector
; init-allocator : -> void?
; Marks every heap cell as 'free, leaving an empty heap.
(define (init-allocator)
  (let fill ([addr 0])
    (when (< addr (heap-size))
      (heap-set! addr 'free)
      (fill (add1 addr)))))

; gc:flat? : location? -> boolean?
; True when the cell at addr holds the 'flat tag.
(define (gc:flat? addr)
  (eq? 'flat (heap-ref addr)))
; gc:deref : location? -> heap-value?
; Returns the payload stored in the cell after the 'flat tag at addr.
; Raises an error (reported under gc:deref) if addr is not a flat.
(define (gc:deref addr)
  (unless (gc:flat? addr)
    (error 'gc:deref "not a flat: ~a" addr))
  (heap-ref (+ addr 1)))

; gc:cons? : location? -> boolean?
; True when the cell at addr holds the 'cons tag.
(define (gc:cons? addr)
  (eq? 'cons (heap-ref addr)))
; gc:first : location? -> location?
; Reads the first field (offset 1) of the cons cell at addr;
; errors if addr is not a cons.
(define (gc:first addr)
  (if (gc:cons? addr)
      (heap-ref (add1 addr))
      (error 'gc:first "not a cons: ~a" addr)))
; gc:rest : location? -> location?
; Reads the second field (offset 2) of the cons cell at addr;
; errors if addr is not a cons.
(define (gc:rest addr)
  (cond
    [(not (gc:cons? addr))
     (error 'gc:rest "not a cons: ~a" addr)]
    [else (heap-ref (+ 2 addr))]))

; gc:set-first! : location? location? -> void?
; Overwrites the first field (offset 1) of the cons cell at addr with v;
; errors if addr is not a cons.
(define (gc:set-first! addr v)
  (if (gc:cons? addr)
      (heap-set! (add1 addr) v)
      (error 'gc:set-first! "not a cons: ~a" addr)))
; gc:set-rest! : location? location? -> void?
; Overwrites the second field (offset 2) of the cons cell at addr with v;
; errors if addr is not a cons.
(define (gc:set-rest! addr v)
  (cond
    [(not (gc:cons? addr))
     (error 'gc:set-rest! "not a cons: ~a" addr)]
    [else (heap-set! (+ 2 addr) v)]))

; gc:closure? : location? -> boolean?
; True when the cell at addr holds the 'clos tag.
(define (gc:closure? addr)
  (eq? 'clos (heap-ref addr)))
; gc:closure-code-ptr : location? -> heap-value?
; Reads the code pointer stored right after the 'clos tag at addr;
; errors if addr is not a closure.
(define (gc:closure-code-ptr addr)
  (if (gc:closure? addr)
      (heap-ref (add1 addr))
      (error 'gc:closure-code-ptr "not a closure: ~a" addr)))
; gc:closure-env-ref : location? integer? -> location?
; Reads free-variable slot i of the closure at addr. Slots start at
; addr+3, after the tag, code pointer and variable count. Errors if
; addr is not a closure.
(define (gc:closure-env-ref addr i)
  (cond
    [(gc:closure? addr) (heap-ref (+ 3 i addr))]
    [else (error 'gc:closure-env-ref "not a closure: ~a" addr)]))

; gc:alloc-flat : heap-value? -> location?
; Allocates a 2-cell flat object (tag, payload v) and returns its address.
(define (gc:alloc-flat v)
  (let ([addr (malloc 2)])
    (heap-set! addr 'flat)
    (heap-set! (add1 addr) v)
    addr))

; gc:cons : root? root? -> location?
; Allocates a 3-cell cons object (tag, first, rest) and returns its address.
; v1 and v2 are handed to malloc as extra roots, so a collection triggered
; by this allocation keeps them alive. They are dereferenced with
; read-root only after malloc returns.
(define (gc:cons v1 v2)
  (let ([addr (malloc 3 v1 v2)])
    (heap-set! addr 'cons)
    (for ([offset (in-list '(1 2))]
          [r (in-list (list v1 v2))])
      (heap-set! (+ addr offset) (read-root r)))
    addr))

;; Test that gc:cons passes its argument roots to the collector: the
;; collection triggered inside gc:cons must keep the 'first and 'rest
;; flats alive even though the heap is full.
(module+ test
  (with-heap (make-vector 10)
    (init-allocator)
    ;; pre-fill heap with garbage (three unrooted flats in cells 0-5)
    (for ([i (in-range 3)])
      (gc:alloc-flat i))
    (test (current-heap) #(flat 0 flat 1 flat 2 free free free free))
    ;; Force collection in the middle of cons allocation: the two flats
    ;; below fill cells 6-9, so the 3-cell cons finds no free space
    (gc:cons (simple-root (gc:alloc-flat 'first))
             (simple-root (gc:alloc-flat 'rest)))
    ;; garbage in cells 0-5 was reclaimed; the cons reuses cells 0-2
    (test (current-heap) #(cons 6 8 free free free flat first flat rest)))
  )

; gc:closure : heap-value? (listof root?) -> location?
; Allocates a closure laid out as: tag 'clos, code pointer, number of
; free variables, then one cell per free variable. Returns its address.
; The free-variable roots are passed to malloc so a collection during
; this allocation keeps them alive.
(define (gc:closure code-ptr free-variables)
  (define n-vars (length free-variables))
  (define addr (malloc (+ 3 n-vars) free-variables))
  (heap-set! addr 'clos)
  (heap-set! (+ addr 1) code-ptr)
  (heap-set! (+ addr 2) n-vars)
  (for ([fv (in-list free-variables)]
        [slot (in-naturals (+ addr 3))])
    (heap-set! slot (read-root fv)))
  addr)

;; Test that gc:closure passes its free-variable roots to the collector:
;; the collection triggered inside gc:closure must keep the 'var1 and
;; 'var2 flats alive even though the heap is full.
(module+ test
  (with-heap (make-vector 10)
    (init-allocator)
    ;; pre-fill heap with garbage (three unrooted flats in cells 0-5)
    (for ([i (in-range 3)])
      (gc:alloc-flat i))
    (test (current-heap) #(flat 0 flat 1 flat 2 free free free free))
    ;; Force collection in the middle of closure allocation: the two
    ;; flats below fill cells 6-9, so the 5-cell closure finds no space
    (gc:closure 'code
                (list (simple-root (gc:alloc-flat 'var1))
                      (simple-root (gc:alloc-flat 'var2))))
    ;; garbage was reclaimed; the closure reuses cells 0-4
    (test (current-heap) #(clos code 2 6 8 free flat var1 flat var2)))
  )

; malloc : integer? root? ... -> location?
; Returns the address of the first run of n consecutive free cells.
; If none exists, collects garbage once (treating extra-roots as
; additional roots) and searches again. Errors with "out of memory" if
; that also fails. The returned cells are not tagged; the caller does that.
(define (malloc n . extra-roots)
  (or (find-free-space n)
      (begin
        (collect-garbage extra-roots)
        (find-free-space n))
      (error 'alloc "out of memory")))

; find-free-space : integer? -> (or/c location? #f)
; Scans the heap from address 0, hopping over allocated objects by their
; size, and returns the start of the first run of n free cells, or #f if
; there is none. Errors on a cell that is neither a known tag nor 'free.
; The error reports both the offending tag and its address.
(define (find-free-space n)
  (define (loop start)
    (and
     (< start (heap-size))
     (case (heap-ref start)
       [(flat) (loop (+ start 2))]
       [(cons) (loop (+ start 3))]
       ;; closure size = 3 header cells + variable count stored at start+2
       [(clos) (loop
                (+ start 3 (heap-ref (+ start 2))))]
       [(free) (if (n-free-blocks? start n)
                   start
                   (loop (+ start 1)))]
       [else (error 'find-free-space
                    "unexpected tag ~a at ~a" (heap-ref start) start)])))
  (loop 0))

; n-free-blocks? : location? integer? -> boolean?
; True when the n cells starting at start all lie inside the heap and are
; 'free. Trivially true for n <= 0.
(define (n-free-blocks? start n)
  (for/and ([addr (in-range start (+ start n))])
    (and (< addr (heap-size))
         (equal? (heap-ref addr) 'free))))

;; malloc returns the start of the first large-enough free run; it does
;; not tag the cells itself, so this test tags cells 0-1 by hand.
(module+ test
  (with-heap (make-vector 6 #f)
    (init-allocator)
    ;; empty heap: the first fit for 4 cells starts at 0
    (test (malloc 4) 0)
    ;; malloc left the cells untagged; hand-build a flat in cells 0-1
    (heap-set! 0 'flat)
    (heap-set! 1 42)
    (test (current-heap)
         #(flat 42 free free free free))
    ;; the scan skips the 2-cell flat and returns cell 2
    (test (malloc 2) 2)))

; collect-garbage : roots? ... -> void?
; Mark-and-sweep collection:
; 1. Whiten every object.
; 2. Re-mark everything reachable from the mutator's root set and from
;    extra-roots (roots, or nested lists of roots, held by an allocator).
; 3. Free whatever is still white.
; The heap is validated before and after.
(define (collect-garbage . extra-roots)
  (validate-heap)
  ;; mark phase
  (mark-white!)
  (traverse/roots (list (get-root-set) extra-roots))
  ;; sweep phase
  (free-white!)
  (validate-heap))

; validate-heap : -> void?
; Walks the heap object by object and raises an error if it finds:
; - an unknown tag, or
; - a cons field or closure free-variable slot that is not an in-bounds
;   location whose cell holds 'flat, 'cons or 'clos.
; Flat payloads and closure code pointers are not checked.
(define (validate-heap)
  ;; check-pointer! : any/c -> void?  (raises on a bad pointer)
  (define (check-pointer! p)
    ;; Reject non-locations first, so a garbage field yields a
    ;; validate-heap error instead of a contract error from < or heap-ref.
    (unless (and (exact-nonnegative-integer? p) (< p (heap-size)))
      (error 'validate-heap "pointer out of bounds ~a" p))
    (unless (member (heap-ref p) '(flat cons clos))
      (error 'validate-heap "pointer to non-tag ~a" p)))
  (let loop ([i 0])
    (when (< i (heap-size))
      (case (heap-ref i)
        [(flat) (loop (+ i 2))]
        [(cons)
         (check-pointer! (heap-ref (+ i 1)))
         (check-pointer! (heap-ref (+ i 2)))
         (loop (+ i 3))]
        [(clos)
         ;; layout: tag, code ptr, n-vars, then n-vars env pointers
         (let ([n-vars (heap-ref (+ i 2))])
           (for ([j (in-range 0 n-vars)])
             (check-pointer! (heap-ref (+ i 3 j))))
           (loop (+ i 3 n-vars)))]
        [(free) (loop (+ i 1))]
        [else (error 'validate-heap
                     "unexpected tag ~a at ~a" (heap-ref i) i)]))))

; mark-white! : -> void?
; Walks the heap object by object and retags every allocated object
; ('flat, 'cons, 'clos) with its 'white- counterpart. Free cells are
; skipped. Errors on any other tag, reporting the tag and its address.
(define (mark-white!)
  (let loop ([i 0])
    (when (< i (heap-size))
      (case (heap-ref i)
        [(cons)
         (heap-set! i 'white-cons) (loop (+ i 3))]
        [(flat)
         (heap-set! i 'white-flat) (loop (+ i 2))]
        ;; closure size = 3 header cells + variable count stored at i+2
        [(clos) (heap-set! i 'white-clos)
                (loop (+ i 3 (heap-ref (+ i 2))))]
        [(free) (loop (+ i 1))]
        [else (error 'mark-white! "bad tag ~a at ~a" (heap-ref i) i)]))))

; traverse/roots : (or/c root? (listof ...)) -> void?
; Marks everything reachable from roots. roots may be a single root or an
; arbitrarily nested list of roots. Any other value is an error.
(define (traverse/roots roots)
  (cond
    [(root? roots)
     (traverse/loc (read-root roots))]
    [(list? roots)
     (for ([r (in-list roots)])
       (traverse/roots r))]
    [else
     (error 'traverse/roots
            "unexpected roots: ~a" roots)]))

; traverse/loc : location? -> void?
; Marks the object at loc and, recursively, everything it points to.
; - A plain tag ('flat/'cons/'clos) means the object has already been
;   reached, so it is left alone.
; - A 'gray- tag means the object is currently being traversed; skipping
;   it stops cycles.
; - A 'white- object is restored to its plain tag, after visiting its
;   children for cons cells and closures.
; Errors on any other tag, reporting the tag and its location.
(define (traverse/loc loc)
  (case (heap-ref loc)
    [(flat gray-flat cons gray-cons clos gray-clos) (void)]
    [(white-flat) (heap-set! loc 'flat)]
    [(white-cons) (heap-set! loc 'gray-cons)
     (traverse/loc (heap-ref (+ loc 1)))
     (traverse/loc (heap-ref (+ loc 2)))
     (heap-set! loc 'cons)]
    [(white-clos) (heap-set! loc 'gray-clos)
     ;; visit each free-variable slot; the count is stored at loc+2
     (for ([i (in-range 0 (heap-ref (+ loc 2)))])
       (traverse/loc (heap-ref (+ loc i 3))))
     (heap-set! loc 'clos)]
    [else (error 'traverse/loc
                 "unexpected tag ~a at ~a" (heap-ref loc) loc)]))

; free-white! : -> void?
; Sweep phase: walks the heap object by object and overwrites every cell
; of each object still tagged 'white- with 'free. Objects with plain tags
; and free cells are skipped. Errors on any other tag, reporting the tag
; and its address.
(define (free-white!)
  ;; free-cells! : location? location? -> void?  -- frees cells [from, to)
  (define (free-cells! from to)
    (for ([j (in-range from to)])
      (heap-set! j 'free)))
  (let loop ([i 0])
    (when (< i (heap-size))
      (case (heap-ref i)
        [(cons) (loop (+ i 3))]
        [(flat) (loop (+ i 2))]
        [(clos) (loop (+ i 3 (heap-ref (+ i 2))))]
        [(free) (loop (+ i 1))]
        [(white-flat) (free-cells! i (+ i 2))
                      (loop (+ i 2))]
        [(white-cons) (free-cells! i (+ i 3))
                      (loop (+ i 3))]
        [(white-clos)
         ;; compute the end before the count cell at i+2 is overwritten
         (let ([end (+ i 3 (heap-ref (+ i 2)))])
           (free-cells! i end)
           (loop end))]
        [else (error 'free-white! "unexpected tag ~a at ~a" (heap-ref i) i)]))))