;; Source: UNB / CS / David Bremner / teaching / cs4613 / lectures / lecture22 / generational.rkt
#lang plai/gc2/collector
;; Originally by Yixi Zhang
;; simplified for class by db

;; Fixed heap addresses used by the collector.
;; Word 0 holds the next free address of the young generation (bump pointer).
(define (alloc-word) 0)
;; The word just past the old generation holds the head of the
;; old-generation free list (#f once the free list is exhausted).
(define (free-list-head) (2nd-gen-size))
;; The following word holds the address of the next free entry of the
;; table of old-to-young (intergenerational) pointers; entries follow it.
(define (table-start) (+ (2nd-gen-size) 1))

;; When true, init-allocator prints the values the heap layout is built from.
(define debug #t)
(define (print-heap-layout)
  (when debug
    (printf "init allocator 1-gen ~a 2-gen ~a table ~a fl ~a\n"
            (1st-gen-size)
            (2nd-gen-size)
            (table-start)
            (free-list-head))))

;; init-allocator : -> void
;; Lays out an empty heap: the young-generation bump pointer starts at 1,
;; every other word is marked 'free, the whole old generation becomes one
;; free-n block at the head of the free list, and the pointer table is empty.
(define (init-allocator)
  (heap-set! (alloc-word) 1)
  (for ([addr (in-range 1 (heap-size))])
    (heap-set! addr 'free))
  ;; the entire old generation as a single free block: tag, next (#f), length
  (define old-start (1st-gen-size))
  (heap-set! old-start 'free-n)
  (heap-set! (add1 old-start) #f)
  (heap-set! (+ old-start 2) (- (2nd-gen-size) old-start))
  (heap-set! (free-list-head) old-start)
  ;; next free table entry is the first slot after (table-start)
  (heap-set! (table-start) (+ (table-start) 1))
  (print-heap-layout))

;; Size of the young (1st) generation: a quarter of the heap, rounded to the
;; nearest word and then bumped up to an even number of words.
(define (1st-gen-size)
  (define quarter (round (* (heap-size) 1/4)))
  (if (even? quarter)
      quarter
      (+ quarter 1)))
;; End (exclusive) of the old (2nd) generation, which occupies
;; [(1st-gen-size), (2nd-gen-size)): the first 7/8 of the heap.
(define (2nd-gen-size)
  (round (* 7/8 (heap-size))))

;; to-space : -> (values loc loc)
;; The young generation is split into two halves at `half`; the "to-space"
;; is the half that currently contains the allocation pointer. Returns its
;; start (inclusive) and end (exclusive).
(define (to-space)
  (define next (heap-ref (alloc-word)))
  (define half (+ (round (/ (1st-gen-size) 2)) 1))
  (cond
    [(>= next half) (values half (1st-gen-size))]
    [else (values 1 half)]))
;; from-space : -> (values loc loc)
;; Start (inclusive) and end (exclusive) of the young half that does NOT
;; contain the allocation pointer.
(define (from-space)
  (define half (add1 (round (/ (1st-gen-size) 2))))
  (if (>= (heap-ref (alloc-word)) half)
      (values 1 half)
      (values half (1st-gen-size))))

;; at-to-space? : loc -> boolean
;; Is loc inside the young half currently used for allocation?
(define (at-to-space? loc)
  (define-values (lo hi) (to-space))
  (and (<= lo loc)
       (< loc hi)))

;; 2nd-gen? : loc -> boolean
;; Is loc inside the old generation [(1st-gen-size), (2nd-gen-size))?
(define (2nd-gen? loc)
  (and (<= (1st-gen-size) loc)
       (< loc (2nd-gen-size))))

;; gc:deref : loc -> heap-value
;; Returns the value stored in the flat object at fl-loc, following
;; forwarding objects; signals an error if fl-loc does not lead to a flat.
(define (gc:deref fl-loc)
  (define tag (heap-ref fl-loc))
  (cond
    [(eq? tag 'flat) (heap-ref (add1 fl-loc))]
    [(eq? tag 'frwd) (gc:deref (heap-ref (add1 fl-loc)))]
    [else (error 'gc:deref
                 "non-flat @ ~s"
                 fl-loc)]))

;; track/loc : (or/c loc root) -> loc
;; Resolves a location (or a root's location) to where the object lives now:
;; a flat, pair or proc location is returned unchanged, and a forwarding
;; object yields the address it forwards to (a single hop). Any other tag
;; is an error.
(define (track/loc loc-or-root)
  (define loc (->location loc-or-root))
  (define tag (heap-ref loc))
  (cond
    [(memq tag '(flat pair proc)) loc]
    [(eq? tag 'frwd) (heap-ref (add1 loc))]
    [else (error 'track/loc "wrong tag ~s at ~a" tag loc)]))

;; gc:alloc-flat : heap-value -> loc
;; Allocates a two-word flat object ('flat, value) in the young generation.
(define (gc:alloc-flat fv)
  (let ([addr (malloc 2)])
    (heap-set! addr 'flat)
    (heap-set! (add1 addr) fv)
    addr))

;; ->location : (or/c location? root?) -> location?
;; A location is returned as-is; a root is replaced by the location it
;; currently holds. Any other argument yields void.
(define (->location thing)
  (if (location? thing)
      thing
      (when (root? thing)
        (read-root thing))))

;; gc:cons : (or/c loc root) (or/c loc root) -> loc
;; Allocates a three-word pair ('pair, head, tail). hd and tl (earlier
;; results of gc:alloc-flat or gc:cons) are handed to malloc as extra roots
;; so a collection triggered by this allocation keeps them alive; they are
;; then re-resolved with track/loc in case that collection moved them.
(define (gc:cons hd tl)
  (define cell (malloc 3 hd tl))
  (define head (track/loc hd))
  (define tail (track/loc tl))
  (heap-set! cell 'pair)
  (heap-set! (+ cell 1) head)
  (heap-set! (+ cell 2) tail)
  cell)

;; gc:first : loc -> loc
;; Returns the head field of the pair at pr-loc. As in gc:set-first! and the
;; gc:*? predicates, a forwarding object that leads to a pair is accepted
;; (checked with gc:cons?, then resolved with track/loc).
;; Signals an error if pr-loc does not point to a pair.
(define (gc:first pr-loc)
  (if (gc:cons? pr-loc)
      (heap-ref (+ (track/loc pr-loc) 1))
      (error 'first "non pair @ ~s" pr-loc)))

;; gc:rest : loc -> loc
;; Returns the tail field of the pair at pr-loc. As in gc:set-rest! and the
;; gc:*? predicates, a forwarding object that leads to a pair is accepted
;; (checked with gc:cons?, then resolved with track/loc).
;; Signals an error if pr-loc does not point to a pair.
(define (gc:rest pr-loc)
  (if (gc:cons? pr-loc)
      (heap-ref (+ (track/loc pr-loc) 2))
      (error 'rest "non pair @ ~s" pr-loc)))

;; gc:flat? : loc -> boolean
;; Does loc (an earlier result of gc:alloc-flat or gc:cons), possibly via
;; forwarding objects, hold a flat value?
(define (gc:flat? loc)
  (let ([tag (heap-ref loc)])
    (if (eq? tag 'frwd)
        (gc:flat? (heap-ref (add1 loc)))
        (eq? tag 'flat))))

;; gc:cons? : loc -> boolean
;; Does loc (an earlier result of gc:alloc-flat or gc:cons), possibly via
;; forwarding objects, hold a pair?
(define (gc:cons? loc)
  (let ([tag (heap-ref loc)])
    (if (eq? tag 'frwd)
        (gc:cons? (heap-ref (add1 loc)))
        (eq? tag 'pair))))

;; gc:set-first! : loc loc -> void
;; Overwrites the head field of the pair at pr-loc with `new`; signals an
;; error if pr-loc does not point to a pair. Write barrier: when the pair is
;; in the old generation and `new` is in the young allocation half, the
;; (field address, new) pair is recorded in the intergenerational table.
(define (gc:set-first! pr-loc new)
  (unless (gc:cons? pr-loc)
    (error 'set-first! "non pair at ~s" pr-loc))
  (define pair-loc (track/loc pr-loc))
  (define slot (+ pair-loc 1))
  (heap-set! slot new)
  (when (and (2nd-gen? pair-loc)
             (at-to-space? new))
    (table/alloc slot new)))

;; gc:set-rest! : loc loc -> void
;; Overwrites the tail field of the pair at pr-loc with `new`; signals an
;; error if pr-loc does not point to a pair. Write barrier: when the pair is
;; in the old generation and `new` is in the young allocation half, the
;; (field address, new) pair is recorded in the intergenerational table.
(define (gc:set-rest! pr-loc new)
  (unless (gc:cons? pr-loc)
    (error 'set-rest! "non pair @ ~s" pr-loc))
  (define pair-loc (track/loc pr-loc))
  (define slot (+ pair-loc 2))
  (heap-set! slot new)
  (when (and (2nd-gen? pair-loc)
             (at-to-space? new))
    (table/alloc slot new)))

;; gc:closure : heap-value (listof (or/c loc root)) -> loc
;; Allocates a closure laid out as
;;   ('proc, code-ptr, count, fv_0, ..., fv_{count-1})
;; where the fv_i are the free variables in `free-vars` (a list).
;; The free variables are handed to malloc as extra roots so a collection
;; triggered by this allocation keeps them alive; they are then re-resolved
;; with track/loc in case that collection moved them.
(define (gc:closure code-ptr free-vars)
  (define fv-count (length free-vars))
  (define next (malloc (+ fv-count 3)
                       free-vars
                       '()))
  ;; resolve every free variable before writing anything into the new object
  (define updated-free-vars (map track/loc free-vars))
  (heap-set! next 'proc)
  (heap-set! (+ next 1) code-ptr)
  (heap-set! (+ next 2) fv-count)
  ;; walk values and destination slots together (linear, unlike list-ref)
  (for ([fv (in-list updated-free-vars)]
        [slot (in-naturals (+ next 3))])
    (heap-set! slot fv))
  next)

;; gc:closure-code-ptr : loc -> heap-value
;; Returns the code pointer of the closure at loc (following a forwarding
;; object if present); signals an error if loc does not hold a closure.
(define (gc:closure-code-ptr loc)
  (unless (gc:closure? loc)
    (error 'gc:closure-code-ptr "non closure at ~a" loc))
  (heap-ref (add1 (track/loc loc))))

;; gc:closure-env-ref : loc number -> loc
;; Returns the i-th free variable of the closure at loc (following a
;; forwarding object if present); signals an error if loc does not hold a
;; closure.
(define (gc:closure-env-ref loc i)
  (unless (gc:closure? loc)
    (error 'gc:closure-env-ref "non closure at ~a" loc))
  (heap-ref (+ (track/loc loc) 3 i)))

;; gc:closure? : loc -> boolean
;; Does the previously allocated location loc, possibly via forwarding
;; objects, hold a closure?
(define (gc:closure? loc)
  (let ([tag (heap-ref loc)])
    (if (eq? tag 'frwd)
        (gc:closure? (heap-ref (add1 loc)))
        (eq? tag 'proc))))


;; table/alloc : loc loc -> void
;; Appends a two-word (pointer, target) entry to the intergenerational
;; pointer table; (table-start) holds the address of the next free entry.
;; When the entry would not fit before the end of the heap, the table is
;; first processed and emptied with move/pointers, and the new entry then
;; becomes the first one.
(define (table/alloc pointer target)
  (define next (heap-ref (table-start)))
  (define slot
    (cond
      [(< (+ next 2) (heap-size)) next]
      [else
       (move/pointers (+ 1 (table-start)))
       (+ 1 (table-start))]))
  (heap-set! slot pointer)
  (heap-set! (+ slot 1) target)
  (heap-set! (table-start) (+ slot 2)))

;; malloc : number[size] roots ... -> loc
;; Bump-allocates n words in the young allocation half. If they do not fit,
;; runs a young-generation collection (treating extra-roots as roots) and
;; then allocates at the start of the other, freshly cleared half.
(define (malloc n . extra-roots)
  (define addr (heap-ref (alloc-word)))
  (if (enough-to-space? addr n)
      (begin
        (heap-set! (alloc-word) (+ addr n))
        addr)
      (begin
        (collect-garbage extra-roots)
        (switch/sweep-tospace n))))

;; enough-to-space? : loc number -> boolean
;; Can `size` words starting at `start` be allocated while leaving the
;; allocation pointer strictly inside the current young half? The strict
;; bound keeps (to-space) from flipping to the other half afterwards.
(define (enough-to-space? start size)
  (let-values ([(lo hi) (to-space)])
    (> hi (+ start size))))

;; switch/sweep-tospace : number -> loc
;; Called after a young collection: clears the young half that does not
;; hold the allocation pointer, makes it the new allocation half, and
;; returns the address of a fresh `number`-word block at its start.
;; Signals an error if the block does not fit in that half. In the lower
;; half the allocation pointer must stay strictly below the boundary,
;; otherwise (to-space) would classify it as the upper half; the upper half
;; may be filled exactly up to (1st-gen-size).
(define (switch/sweep-tospace number)
  (define-values (lo hi) (from-space))
  (define new-next (+ lo number))
  (unless (if (= hi (1st-gen-size))
              (<= new-next hi)
              (< new-next hi))
    (error 'switch/sweep-tospace
           "cannot allocate ~a words: young half [~a, ~a) is too small"
           number lo hi))
  (for ([i (in-range lo hi)])
    (heap-set! i 'free))
  (heap-set! (alloc-word) new-next)
  lo)

;; find-free-space : (or/c loc #f) (or/c loc #f) number -> (or/c loc #f)
;; Walks the old-generation free list from `start` and returns the address
;; of the first block that can supply exactly `size` words, or #f if none
;; can. `prev` is the free-list block preceding `start` (#f when `start` is
;; the head of the list).
;;
;; Free-block layouts:
;;   free-2 next        ; a two-word block
;;   free-n next size   ; a block of `size` words
;; next := (or/c location? #f)
;; (free-list-head) holds the address of the first free block, or #f once
;; the free list is exhausted.
;;
;; Before a block is handed out, the free list is updated: the block is
;; unlinked, or replaced by its unused remainder when a larger free-n block
;; is split. If the block was the first in the list, (free-list-head) is
;; updated instead of the predecessor's next field.
(define (find-free-space start prev size)
  (local [;; the `next` field of the free block at loc
          (define (next-in-free-list loc)
            (heap-ref (+ loc 1)))
          ;; make prev's next field (or the list head, when prev is #f) point to loc
          (define (update-next-in-prev prev loc)
            (heap-set! (if prev
                              (+ prev 1)
                              (free-list-head))
                          loc))]
    (cond
      [(not start) #f] ; reached the end of the free list without a fit
      [else
       (case (heap-ref start)
         [(free-2)
          ;; a two-word block can only satisfy an exact two-word request
          (cond
            [(= size 2)
             (update-next-in-prev prev (next-in-free-list start))
             start]
            [else (find-free-space (heap-ref (+ start 1)) start size)])]
         [(free-n)
          (define length (heap-ref (+ start 2)))
          (cond
            [(= size length)
             ;; exact fit: unlink the whole block
             (update-next-in-prev prev (next-in-free-list start))
             start]
            [(< size length)
             ;; split: hand out the front, keep the remainder free
             (define new-free (+ start size))
             (define new-size (- length size))
             (cond
               [(= new-size 1)
                ;; a single leftover word cannot hold a free-list header,
                ;; so it is tagged 'free and left out of the list
                (update-next-in-prev prev (next-in-free-list start))
                (heap-set! new-free 'free)
                start]
               [else
                ;; the remainder takes the old block's place in the list
                (update-next-in-prev prev new-free)
                (cond
                  [(= new-size 2)
                   (heap-set! new-free 'free-2)
                   (heap-set! (+ new-free 1) (heap-ref (+ start 1)))]
                  [else (heap-set! new-free 'free-n)
                        (heap-set! (+ new-free 1) (heap-ref (+ start 1)))
                        (heap-set! (+ new-free 2) new-size)])
                start])]
            [else (find-free-space (heap-ref (+ start 1)) start size)])]
         [else (error 'find-free-space "wrong tag @ ~s" start)])])))

;; 2nd-gen-gc : roots ... -> void
;; Mark-sweep collection of the old generation. Every old object is first
;; marked white. Objects reachable from the mutator's roots, from
;; extra-roots, or from young-generation objects are then re-marked live.
;; Whatever is still white is swept into a rebuilt free list.
(define (2nd-gen-gc . extra-roots)
  (define old-start (1st-gen-size))
  (mark-white! old-start)
  (traverse/roots (list (get-root-set) extra-roots))
  (make-pointers-to-2nd-gen-roots)
  (free-white! old-start #f #f #f))

;; make-pointers-to-2nd-gen-roots : -> void
;; Part of the old-generation mark phase. It scans the whole young
;; generation, [1, (1st-gen-size)), object by object, covering both halves.
;; Every old-generation object referenced by a young object is marked live
;; (via traverse/roots).
;; A forwarding object (a young object already promoted) keeps its promoted
;; copy live, and the copy's tag also gives the size of the forwarding
;; object to skip over.
;; NOTE(review): if a forwarding target's tag is not flat/pair/proc, the
;; inner case yields void and the scan silently stops there.
;; NOTE(review): the local binding named `begin` shadows Racket's `begin`.
(define (make-pointers-to-2nd-gen-roots)
  (define-values (begin end) (values 1 (1st-gen-size)))
  (let loop ([start begin])
    (cond
      [(= start end) (void)]
      [else
        (case (heap-ref start)
          [(flat) (loop (+ start 2))]
          [(pair) (define one-loc (heap-ref (+ start 1)))
                  (when (2nd-gen? one-loc) (traverse/roots one-loc))
                  (define another-loc (heap-ref (+ start 2)))
                  (when (2nd-gen? another-loc) (traverse/roots another-loc))
                  (loop (+ start 3))]
          [(proc) (define fv-counts (heap-ref (+ start 2)))
                  (for ([i (in-range fv-counts)])
                       (define loc (heap-ref (+ start 3 i)))
                       (when (2nd-gen? loc) (traverse/roots loc)))
                  (loop (+ start 3 fv-counts))]
          ;; forwarding object: mark the promoted copy live, then skip over
          ;; the old footprint using the copy's size
          [(frwd) (define loc (heap-ref (+ start 1)))
                  (traverse/roots loc)
                  (case (heap-ref loc)
                    [(flat) (loop (+ start 2))]
                    [(pair) (loop (+ start 3))]
                    [(proc) (loop (+ start 3 (heap-ref (+ loc 2))))])]
          [(free) (loop (+ start 1))]
          [else (error 'make-pointers-to-2nd-gen-roots "wrong tag at ~a" start)])])))

;; mark-white! : loc -> void
;; Walks the old generation object by object from i up to (2nd-gen-size).
;; Each flat/pair/proc tag becomes its white-* form (tentatively dead), and
;; free blocks are stepped over. Any other tag is an error.
(define (mark-white! i)
  (let walk ([addr i])
    (when (< addr (2nd-gen-size))
      (let ([whiten! (lambda (white-tag size)
                       (heap-set! addr white-tag)
                       (walk (+ addr size)))])
        (case (heap-ref addr)
          [(flat) (whiten! 'white-flat 2)]
          [(pair) (whiten! 'white-pair 3)]
          [(proc) (whiten! 'white-proc (+ 3 (heap-ref (+ addr 2))))]
          [(free) (walk (+ addr 1))]
          [(free-2) (walk (+ addr 2))]
          [(free-n) (walk (+ addr (heap-ref (+ addr 2))))]
          [else (error 'mark-white! "wrong tag at ~a" addr)])))))

;; traverse/roots : (or/c root loc list) -> void
;; Marks live every old-generation object reachable from `thing`. `thing`
;; may be a root, a location, or an arbitrarily nested list of them;
;; anything else (e.g. a vector) is ignored.
(define (traverse/roots thing)
  (cond
    [(root? thing) (traverse/loc (read-root thing))]
    [(number? thing) (traverse/loc thing)]
    [(list? thing)
     (for ([item (in-list thing)])
       (traverse/roots item))]
    [else (void)]))

;; traverse/loc : loc -> void
;; If loc is in the old generation and still white, re-marks it live and
;; recurses into the locations it references. Objects that are already live
;; stop the recursion, and locations outside the old generation are ignored.
(define (traverse/loc loc)
  (when (2nd-gen? loc)
    (case (heap-ref loc)
      [(pair flat proc) (void)]
      [(white-flat) (heap-set! loc 'flat)]
      [(white-pair)
       (heap-set! loc 'pair)
       (for ([offset (in-list '(1 2))])
         (traverse/loc (heap-ref (+ loc offset))))]
      [(white-proc)
       (heap-set! loc 'proc)
       (for ([slot (in-range (+ loc 3) (+ loc 3 (heap-ref (+ loc 2))))])
         (traverse/loc (heap-ref slot)))]
      [else (error 'traverse/loc "wrong tag at ~a" loc)])))

;; object-length : location -> number
;; Number of heap words occupied by the object or free block at loc.
;; white-* objects have the same size as their live counterparts.
(define (object-length loc)
  (define tag (heap-ref loc))
  (cond
    [(eq? tag 'free) 1]
    [(memq tag '(free-2 flat white-flat)) 2]
    [(memq tag '(pair white-pair)) 3]
    [(eq? tag 'free-n) (heap-ref (+ loc 2))]
    [(memq tag '(proc white-proc)) (+ 3 (heap-ref (+ loc 2)))]
    [else (error 'object-length "wrong tag ~s @ ~s" tag loc)]))

;; free-white! : location (or/c location #f) (or/c location #f) (or/c number #f) -> void
;; Sweep phase of the old-generation collection. It walks the old
;; generation from `loc` to (2nd-gen-size). Each maximal run of adjacent
;; dead (white-*) objects and free blocks is coalesced into one free block
;; (free-2 or free-n), and those blocks are linked into a fresh free list
;; in address order.
;;   prev          - the most recently linked free block; #f means none yet,
;;                   so the next block found becomes (free-list-head)
;;   last-start    - start of the run currently being accumulated, or #f
;;   spaces-so-far - number of words in that run, or #f
;; last-start and spaces-so-far must be both #f or both set.
;; A run of a single word cannot hold a free-list header, so it is left
;; unlinked.
(define (free-white! loc prev last-start spaces-so-far)
  ;; sanity check: the two accumulators are set and cleared together
  (unless (or (and last-start spaces-so-far)
              (not (or last-start spaces-so-far)))
    (error 'free-white!
           "cumulating info are incorrect, last-start: ~s, spaces-so-far: ~s"
           last-start spaces-so-far))

  (cond
    ;; end of the old generation: close off any pending run
    [(>= loc (2nd-gen-size))
     (cond
       [(and last-start spaces-so-far)
        (cond
          [(= 1 spaces-so-far) (void)]
          [(= 2 spaces-so-far) (heap-set! last-start 'free-2)
                               (heap-set! (+ 1 last-start) #f)
                               (heap-set! (if prev (+ prev 1) (free-list-head))
                                             last-start)]
          [else (heap-set! last-start 'free-n)
                (heap-set! (+ 1 last-start) #f)
                (heap-set! (+ 2 last-start) spaces-so-far)
                (heap-set! (if prev (+ prev 1) (free-list-head))
                              last-start)])]
       [else (void)])]
    [else
      (define tag (heap-ref loc))
      (case tag
        ;; a live object ends any pending run
        [(flat pair proc)
         (define length (object-length loc))
         (cond
           [(and last-start
                 spaces-so-far
                 (= 1 spaces-so-far))
            (free-white! (+ loc length) prev #f #f)]
           [(and last-start
                 spaces-so-far
                 (>= spaces-so-far 2))
            ;; write the coalesced block's header and link it after prev
            (cond
              [(= 2 spaces-so-far) (heap-set! last-start 'free-2)
                                   (heap-set! (+ last-start 1) #f)]
              [else (heap-set! last-start 'free-n)
                    (heap-set! (+ last-start 1) #f)
                    (heap-set! (+ last-start 2) spaces-so-far)])
            (if prev
              (heap-set! (+ prev 1) last-start)
              (heap-set! (free-list-head) last-start))
            (free-white! (+ loc length) last-start #f #f)]
           [else (free-white! (+ loc length) prev #f #f)])]
        ;; dead object or existing free space: extend (or start) the run
        [(white-flat white-pair white-proc free free-2 free-n)
         (define length (object-length loc))
         (cond
           [(and last-start spaces-so-far)
            (free-white! (+ loc length) prev last-start (+ spaces-so-far length))]
           [else (free-white! (+ loc length) prev loc length)])]
        [else (error 'free-white! "wrong tag at ~a" loc)])]))

;; collect-garbage : roots ... -> void
;; Young-generation collection. Every object in the current allocation half
;; that is reachable from the mutator's roots or from extra-roots is
;; promoted into the old generation. Then the targets recorded in the
;; intergenerational pointer table are promoted and the table is emptied.
(define (collect-garbage . extra-roots)
  (move/roots (list (get-root-set) extra-roots))
  (move/pointers (add1 (table-start))))

;; move/roots : (or/c root loc list) -> void
;; Promotes everything reachable from `thing`: a root, a location, or an
;; arbitrarily nested list of them. Roots are updated to the promoted
;; address. Plain locations cannot be updated in place, so callers re-resolve
;; them with track/loc afterwards.
(define (move/roots thing)
  (cond
    [(root? thing)
     (let ([moved (move/loc (read-root thing))])
       (set-root! thing moved)
       (move/ref moved))]
    [(number? thing)
     (move/ref (move/loc thing))]
    [(list? thing)
     (for ([item (in-list thing)])
       (move/roots item))]))

;; move/loc : loc -> loc
;; Promotes the object at loc when loc lies in the current young allocation
;; half. The object is copied into the old generation (space from
;; copy/alloc), a forwarding object ('frwd, new-address) is left behind, and
;; the new address is returned.
;; Fields of a copied pair/proc are passed through track/loc, so children
;; that were already promoted are redirected. Children still in the young
;; generation are moved later by move/ref.
;; An object that was already forwarded yields its forwarding address.
;; Locations outside the allocation half are returned unchanged.
;; NOTE(review): in the proc case the extra roots passed to copy/alloc are
;; a vector, which traverse/roots ignores; the still-intact young closure is
;; presumably covered by make-pointers-to-2nd-gen-roots instead — confirm.
(define (move/loc loc)
  (cond
    [(at-to-space? loc)
     (case (heap-ref loc)
       [(flat) (define new-addr (copy/alloc 2))
               (heap-set! new-addr 'flat)
               (heap-set! (+ new-addr 1) (heap-ref (+ loc 1)))
               (heap-set! loc 'frwd)
               (heap-set! (+ loc 1) new-addr)
               new-addr] ; address of the promoted flat
       [(pair)
        ;; the fields are passed as extra roots in case copy/alloc collects
        (define new (copy/alloc 3 (heap-ref (+ loc 1))
                                (heap-ref (+ loc 2))))
        (heap-set! new 'pair)
        (heap-set! (+ new 1)
                   (track/loc (heap-ref (+ loc 1))))
        (heap-set! (+ new 2)
                   (track/loc (heap-ref (+ loc 2))))
        (heap-set! loc 'frwd) (heap-set! (+ loc 1) new)
        new]
       [(proc) (define length (+ 3 (heap-ref (+ loc 2))))
               (define free-vars (build-vector (- length 3)
                                               (lambda (i)
                                                 (heap-ref (+ loc 3 i)))))
               (define new-addr (copy/alloc length free-vars '()))
               ;; header words: tag, code pointer, free-variable count
               (for ([x (in-range 0 3)])
                 (heap-set! (+ new-addr x) (heap-ref (+ loc x))))
               ;; free variables, redirected if already promoted
               (for ([x (in-range 3 length)])
                 (heap-set! (+ new-addr x) (track/loc (heap-ref (+ loc x)))))
               (heap-set! loc 'frwd)
               (heap-set! (+ loc 1) new-addr)
               new-addr] ; address of the promoted closure
       [(frwd) (heap-ref (+ loc 1))]
       [else (error 'move/loc "wrong tag ~s at ~a" (heap-ref loc) loc)])]
    [else loc]))

;; copy/alloc : number roots ... -> loc
;; Takes n words from the old generation's free list. If no block fits, it
;; runs an old-generation collection (treating extra-roots as roots) and
;; tries once more. If there is still no room, it signals an error.
(define (copy/alloc n . extra-roots)
  (define (try-free-list)
    (find-free-space (heap-ref (free-list-head)) #f n))
  (or (try-free-list)
      (begin
        (2nd-gen-gc extra-roots)
        (or (try-free-list)
            (error 'copy/alloc "no space")))))

;; move/ref : loc -> void
;; For the object at loc, moves every referenced object with move/loc,
;; rewrites loc's fields to the returned addresses, and then recurses into
;; each referenced object. A forwarding object is followed to its target
;; first. Pair fields are written through gc:set-first!/gc:set-rest!, so the
;; write barrier applies; closure fields are written directly.
;; NOTE(review): the recursion is unconditional. It also descends into
;; objects already in the old generation and into objects visited earlier in
;; the same collection. A cyclic structure (e.g. a pair whose rest is itself,
;; built with set-rest!) therefore recurses without terminating, and shared
;; substructure is re-traversed once per reference. A fix would need
;; per-collection "visited" state.
(define (move/ref loc)
  (case (heap-ref loc)
    [(flat) (void)]
    [(pair) (gc:set-first! loc (move/loc (heap-ref (+ loc 1))))
            (gc:set-rest! loc (move/loc (heap-ref (+ loc 2))))
            (move/ref (heap-ref (+ loc 1)))
            (move/ref (heap-ref (+ loc 2)))]
    [(proc) (define fv-count (heap-ref (+ loc 2)))
            (for ([x (in-range 0 fv-count)])
              (define l (+ loc 3 x))
              (heap-set! l (move/loc (heap-ref l)))
              (move/ref (heap-ref l)))]
    [(frwd) (move/ref (heap-ref (+ 1 loc)))]
    [else (error 'move/ref "wrong tag at ~a" loc)]))

;; move/pointers : loc -> void
;; Processes the intergenerational pointer table from the entry at loc
;; onward. Each entry is two words (pointer, target), written by table/alloc:
;; pointer is the address of an old-generation field, and target is the
;; young object stored there.
;; For each entry, the target is promoted with move/loc, its references are
;; moved with move/ref, and both words of the entry are reset to 'free.
;; Processing stops at the first 'free word or at the end of the heap, and
;; the table is then reset to empty.
;; NOTE(review): the old-generation field at address `pointer` is not
;; rewritten here. The new address goes into the entry's target slot and is
;; immediately overwritten with 'free. The field keeps referring to the
;; forwarding object until something else (e.g. move/ref reaching the owning
;; object from a root) updates it — confirm this is intended.
(define (move/pointers loc)
  (cond
    [(or (= loc (heap-size))
         (equal? 'free (heap-ref loc)))
     (heap-set! (table-start) (add1 (table-start)))]
    [else
     (define new-addr (move/loc (heap-ref (+ loc 1))))
     (heap-set! (+ loc 1) new-addr)
     (move/ref new-addr)
     (heap-set! loc 'free)
     (heap-set! (+ loc 1) 'free)
     (move/pointers (+ loc 2))]))

(module+ test
  (print-only-errors #t)

  ;; gc:deref rejects a location that holds a pair
  (with-heap (make-vector 1000)
    (init-allocator)
    (test/exn
     (gc:deref (gc:cons (simple-root (gc:alloc-flat #f))
                        (simple-root (gc:alloc-flat #t))))
     "non-flat"))

  ;; gc:first rejects a flat value
  (with-heap (make-vector 1000)
    (init-allocator)
    (test/exn (gc:first (gc:alloc-flat #f))
              "non pair"))

  ;; the closure accessors reject a flat value
  (with-heap (make-vector 1000)
    (init-allocator)
    (test/exn (gc:closure-code-ptr (gc:alloc-flat #f))
              "non closure"))

  (with-heap (make-vector 1000)
    (init-allocator)
    (test/exn (gc:closure-env-ref (gc:alloc-flat #f) 0)
              "non closure")))

;; Part 2: flat values
(module+ test
  (with-heap (make-vector 1000)
    (init-allocator)
    (let* ([loc (gc:alloc-flat #t)])
      ;; a flat is a flat, not a pair, and reads back its value
      (test (gc:flat? loc) #t)
      (test (gc:cons? loc) #f)
      (test (gc:deref loc) #t))))


;; Part 3: cons cells
(module+ test
  ;; first and rest read back what gc:cons stored
  (with-heap (make-vector 1000)
    (init-allocator)
    (let* ([head (simple-root (gc:alloc-flat 'first))]
           [tail (simple-root (gc:alloc-flat 'rest))]
           [cell (gc:cons head tail)])
      (test (gc:deref (gc:rest cell)) 'rest)
      (test (gc:deref (gc:first cell)) 'first)))

  ;; set-first! / set-rest! replace the corresponding field
  (with-heap (make-vector 1000)
    (init-allocator)
    (let* ([head (simple-root (gc:alloc-flat 'first))]
           [tail (simple-root (gc:alloc-flat 'rest))]
           [cell (gc:cons head tail)])
      (test (begin (gc:set-first! cell (gc:alloc-flat 'first-mutated))
                   (gc:deref (gc:first cell)))
            'first-mutated)
      (test (begin (gc:set-rest! cell (gc:alloc-flat 'rest-mutated))
                   (gc:deref (gc:rest cell)))
            'rest-mutated))))

;; Part 4: closures
(module+ test
  (with-heap (make-vector 1000)
    (init-allocator)
    (let* ([env (list (simple-root (gc:alloc-flat 'sekrit)))]
           [clo (gc:closure 'code-pointer env)])
      ;; the environment and code pointer read back unchanged
      (test (gc:deref (gc:closure-env-ref clo 0)) 'sekrit)
      (test (gc:closure-code-ptr clo) 'code-pointer))))