diff --git a/extraction/RecvBufferWrapper.fst b/extraction/RecvBufferWrapper.fst new file mode 100644 index 000000000..4cef4abe0 --- /dev/null +++ b/extraction/RecvBufferWrapper.fst @@ -0,0 +1,256 @@ +module RecvBufferWrapper + +#lang-pulse +open Pulse.Lib.Pervasives +open FStar.SizeT +module SZ = FStar.SizeT +module Seq = FStar.Seq +module A = Pulse.Lib.Array +module CB = Pulse.Lib.CircularBuffer +module Spec = Pulse.Lib.CircularBuffer.Spec +open Pulse.Lib.CircularBuffer.Spec +module GT = Pulse.Lib.CircularBuffer.GapTrack +module Pow2 = Pulse.Lib.CircularBuffer.Pow2 +module RM = Pulse.Lib.RangeVec +module RMSpec = Pulse.Lib.RangeMap.Spec +open Pulse.Lib.Trade + +type byte = Spec.byte + +/// Re-export circular_buffer and range_vec_t types +let circular_buffer = CB.circular_buffer +let range_vec_t = RM.range_vec_t +let write_result = CB.write_result + +/// Re-export read_view +let read_view = CB.read_view + +fn create + (alloc_len: SZ.t{SZ.v alloc_len > 0}) + (virt_len: SZ.t{SZ.v virt_len > 0}) + requires pure ( + Pow2.is_pow2 (SZ.v alloc_len) /\ + Pow2.is_pow2 (SZ.v virt_len) /\ + SZ.v alloc_len <= SZ.v virt_len /\ + SZ.v alloc_len <= Spec.cb_max_length /\ + SZ.v virt_len <= CB.pow2_63) + returns res : (circular_buffer & range_vec_t) + ensures exists* st. + CB.is_circular_buffer (fst res) (snd res) st ** + pure (Spec.cb_wf st /\ + st.base_offset == 0 /\ + st.alloc_length == SZ.v alloc_len /\ + st.virtual_length == SZ.v virt_len /\ + GT.contiguous_prefix_length st.contents == 0) +{ + CB.create alloc_len virt_len +} + +fn free + (cb: circular_buffer) + (rm: range_vec_t) + (#st: erased Spec.cb_state) + requires CB.is_circular_buffer cb rm st + ensures emp +{ + CB.free cb rm +} + +fn read_length + (cb: circular_buffer) (rm: range_vec_t) + (#st: erased Spec.cb_state) + requires CB.is_circular_buffer cb rm st + returns n : SZ.t + ensures CB.is_circular_buffer cb rm st ** + pure (SZ.v n == GT.contiguous_prefix_length st.contents) +{ + CB.read_length cb rm +} + +fn get_total_length + (cb: circular_buffer) (rm: range_vec_t) + (#st: erased Spec.cb_state) + requires CB.is_circular_buffer cb rm st + returns n: SZ.t + ensures CB.is_circular_buffer cb rm st ** + pure (SZ.v n <= st.base_offset + st.alloc_length) +{ + CB.get_total_length cb rm +} + +fn get_alloc_length + (cb: circular_buffer) + (rm: range_vec_t) + (#st: erased Spec.cb_state) + requires CB.is_circular_buffer cb rm st ** pure (Spec.cb_wf st) + returns n : SZ.t + ensures CB.is_circular_buffer cb rm st ** pure (SZ.v n == st.alloc_length) +{ + CB.get_alloc_length cb rm +} + +fn drain + (cb: circular_buffer) + (rm: range_vec_t) + (n: SZ.t) + (#st: erased Spec.cb_state) + requires + CB.is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ SZ.v n <= st.alloc_length /\ + SZ.v n <= GT.contiguous_prefix_length st.contents /\ + SZ.fits (st.base_offset + SZ.v n)) + returns no_more_data: bool + ensures + CB.is_circular_buffer cb rm (Spec.drain_result st (SZ.v n)) ** + pure (no_more_data == (GT.contiguous_prefix_length st.contents = SZ.v n)) +{ + CB.drain cb rm n +} + +fn write_buffer + (cb: circular_buffer) (rm: range_vec_t) + (abs_offset: SZ.t) (src: A.array byte) (write_len: SZ.t) + (#p: perm) + (#src_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + CB.is_circular_buffer cb rm st ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v write_len > 0 /\ + SZ.v abs_offset + SZ.v write_len <= st.base_offset + st.virtual_length /\ + SZ.fits (SZ.v abs_offset + SZ.v write_len)) + returns wr: write_result + ensures exists* st'. + CB.is_circular_buffer cb rm st' ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st' /\ + st'.base_offset == st.base_offset /\ + st'.virtual_length == st.virtual_length /\ + (not wr.wrote ==> st'.alloc_length == st.alloc_length /\ + st'.read_start == st.read_start /\ + st'.contents == st.contents) /\ + (wr.wrote ==> st'.alloc_length >= st.alloc_length /\ + GT.contiguous_prefix_length st'.contents >= + GT.contiguous_prefix_length st.contents)) +{ + CB.write_buffer cb rm abs_offset src write_len +} + +fn read_buffer + (cb: circular_buffer) + (rm: range_vec_t) + (dst: A.array byte) + (read_len: SZ.t) + (#dst_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + CB.is_circular_buffer cb rm st ** + A.pts_to dst dst_data ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + SZ.v read_len <= st.alloc_length /\ + SZ.v read_len <= A.length dst /\ + A.is_full_array dst) + ensures exists* (dst_data': Seq.seq byte). + CB.is_circular_buffer cb rm st ** + A.pts_to dst dst_data' ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= Seq.length st.contents /\ + SZ.v read_len <= Seq.length dst_data' /\ + Seq.length dst_data' == Seq.length dst_data /\ + (forall (i:nat{i < SZ.v read_len}). + Some? (Seq.index st.contents i) /\ + Seq.index dst_data' i == Some?.v (Seq.index st.contents i))) +{ + CB.read_buffer cb rm dst read_len +} + +fn resize + (cb: circular_buffer) (rm: range_vec_t) (new_al: SZ.t{SZ.v new_al > 0}) + (#st: erased Spec.cb_state) + requires CB.is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ Pow2.is_pow2 (SZ.v new_al) /\ + SZ.v new_al >= st.alloc_length /\ + SZ.v new_al <= st.virtual_length /\ + SZ.v new_al <= Spec.cb_max_length /\ + SZ.v new_al <= CB.pow2_63) + ensures CB.is_circular_buffer cb rm (Spec.resize_result st (SZ.v new_al)) +{ + CB.resize cb rm new_al +} + +fn set_virtual_length + (cb: circular_buffer) (rm: range_vec_t) (new_vl: SZ.t{SZ.v new_vl > 0}) + (#st: erased Spec.cb_state) + requires CB.is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ + Pow2.is_pow2 (SZ.v new_vl) /\ + SZ.v new_vl >= st.virtual_length /\ + SZ.v new_vl <= CB.pow2_63) + ensures CB.is_circular_buffer cb rm ({ st with virtual_length = SZ.v new_vl }) +{ + CB.set_virtual_length cb rm new_vl +} + +fn read_zerocopy + (cb: circular_buffer) + (rm: range_vec_t) + (read_len: SZ.t) + (#st: erased Spec.cb_state) + requires + CB.is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + SZ.v read_len <= st.alloc_length /\ + SZ.v read_len > 0) + returns rv: read_view + ensures exists* (s1 s2: Seq.seq byte). + CB.zc_segs rv s1 s2 ** + (CB.zc_segs rv s1 s2 @==> CB.is_circular_buffer cb rm st) ** + pure ( + SZ.v rv.len1 + SZ.v rv.len2 == SZ.v read_len /\ + SZ.v rv.off1 + SZ.v rv.len1 <= st.alloc_length /\ + SZ.v rv.off2 + SZ.v rv.len2 <= st.alloc_length) +{ + CB.read_zerocopy cb rm read_len +} + +fn release_read + (cb: circular_buffer) + (rm: range_vec_t) + (rv: read_view) + (#st: erased Spec.cb_state) + (#s1 #s2: erased (Seq.seq byte)) + requires + CB.zc_segs rv s1 s2 ** + (CB.zc_segs rv s1 s2 @==> CB.is_circular_buffer cb rm st) + ensures + CB.is_circular_buffer cb rm st +{ + CB.release_read cb rm rv +} + +fn drain_after_read + (cb: circular_buffer) + (rm: range_vec_t) + (rv: read_view) + (drain_len: SZ.t) + (#st: erased Spec.cb_state) + (#s1 #s2: erased (Seq.seq byte)) + requires + CB.zc_segs rv s1 s2 ** + (CB.zc_segs rv s1 s2 @==> CB.is_circular_buffer cb rm st) ** + pure (Spec.cb_wf st /\ + SZ.v drain_len <= st.alloc_length /\ + SZ.v drain_len <= GT.contiguous_prefix_length st.contents /\ + SZ.fits (st.base_offset + SZ.v drain_len)) + returns no_more_data: bool + ensures + CB.is_circular_buffer cb rm (Spec.drain_result st (SZ.v drain_len)) ** + pure (no_more_data == (GT.contiguous_prefix_length st.contents = SZ.v drain_len)) +{ + CB.drain_after_read cb rm rv drain_len +} diff --git a/extraction/TestKrmlBug.fst b/extraction/TestKrmlBug.fst new file mode 100644 index 000000000..f64f65a35 --- /dev/null +++ b/extraction/TestKrmlBug.fst @@ -0,0 +1,19 @@ +module TestKrmlBug +#lang-pulse + +open Pulse.Lib.Pervasives +open Pulse.Lib.Box +module SZ = FStar.SizeT +module AVL = Pulse.Lib.AVLTree +module T = Pulse.Lib.Spec.AVLTree + +// Wraps an AVL tree in a heap-allocated box +type my_tree = box (AVL.tree_t SZ.t unit) + +fn my_create (_u: unit) + requires emp + returns r: AVL.tree_t SZ.t unit + ensures AVL.is_tree r T.Leaf +{ + AVL.create SZ.t unit +} diff --git a/extraction/_c_output/Pulse_Lib_RangeVec.c b/extraction/_c_output/Pulse_Lib_RangeVec.c new file mode 100644 index 000000000..525bcf3a8 --- /dev/null +++ b/extraction/_c_output/Pulse_Lib_RangeVec.c @@ -0,0 +1,213 @@ +/* + This file was generated by KaRaMeL + KaRaMeL invocation: /home/eioannidis/karamel/krml -skip-compilation -skip-makefiles -skip-linking -warn-error -15-4-2 -tmpdir _c_output -library Pulse.Lib.Vector _krml_output/Pulse_Lib_RangeVec.krml _krml_output/Pulse_Lib_Vector.krml + F* version: + KaRaMeL version: a4caa585 + */ + +#include "Pulse_Lib_RangeVec.h" + +Pulse_Lib_RangeVec_range +Pulse_Lib_RangeVec_default_range = { .start = (size_t)0U, .len = (size_t)1U }; + +Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range +*Pulse_Lib_RangeVec_range_vec_create(void) +{ + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range + *rv = Pulse_Lib_Vector_create(Pulse_Lib_RangeVec_default_range, (size_t)1U); + Pulse_Lib_Vector_pop_back(rv, (void *)0U, (void *)0U); + return rv; +} + +void +Pulse_Lib_RangeVec_range_vec_free( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv +) +{ + Pulse_Lib_Vector_free(rv, (void *)0U, (void *)0U); +} + +size_t +Pulse_Lib_RangeVec_range_vec_contiguous_from( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t base +) +{ + size_t sz = Pulse_Lib_Vector_size(rv, (void *)0U, (void *)0U); + if (sz == (size_t)0U) + return (size_t)0U; + else + { + Pulse_Lib_RangeVec_range first = Pulse_Lib_Vector_at(rv, (size_t)0U, (void *)0U, (void *)0U); + size_t r_high = first.start + first.len; + if (first.start <= base && base < r_high) + return r_high - base; + else + return (size_t)0U; + } +} + +size_t +Pulse_Lib_RangeVec_range_vec_max_endpoint( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv +) +{ + size_t sz = Pulse_Lib_Vector_size(rv, (void *)0U, (void *)0U); + if (sz == (size_t)0U) + return (size_t)0U; + else + { + size_t last_idx = sz - (size_t)1U; + Pulse_Lib_RangeVec_range last = Pulse_Lib_Vector_at(rv, last_idx, (void *)0U, (void *)0U); + return last.start + last.len; + } +} + +void +Pulse_Lib_RangeVec_vec_insert_at( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t i, + Pulse_Lib_RangeVec_range r +) +{ + Pulse_Lib_Vector_push_back(rv, r, (void *)0U, (void *)0U); + size_t sz = Pulse_Lib_Vector_size(rv, (void *)0U, (void *)0U); + if (sz > (size_t)1U && i < sz - (size_t)1U) + { + size_t j = sz - (size_t)1U; + bool cont = true; + while (cont) + { + size_t jv = j; + if (jv > i) + { + Pulse_Lib_RangeVec_range + prev = Pulse_Lib_Vector_at(rv, jv - (size_t)1U, (void *)0U, (void *)0U); + Pulse_Lib_Vector_set(rv, jv, prev, (void *)0U, (void *)0U); + size_t new_j = jv - (size_t)1U; + j = new_j; + if (new_j == i) + cont = false; + } + else + cont = false; + } + Pulse_Lib_Vector_set(rv, i, r, (void *)0U, (void *)0U); + } +} + +void +Pulse_Lib_RangeVec_vec_remove_range( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t i, + size_t count +) +{ + size_t sz = Pulse_Lib_Vector_size(rv, (void *)0U, (void *)0U); + size_t dst_end = sz - count; + size_t j = i; + bool shift_cont = true; + while (shift_cont) + { + size_t jv = j; + if (jv < dst_end) + { + Pulse_Lib_RangeVec_range val_ = Pulse_Lib_Vector_at(rv, jv + count, (void *)0U, (void *)0U); + Pulse_Lib_Vector_set(rv, jv, val_, (void *)0U, (void *)0U); + j = jv + (size_t)1U; + } + else + shift_cont = false; + } + size_t k = (size_t)0U; + bool pop_cont = true; + while (pop_cont) + { + size_t kv = k; + if (kv < count) + { + Pulse_Lib_Vector_pop_back(rv, (void *)0U, (void *)0U); + size_t new_k = kv + (size_t)1U; + k = new_k; + if (new_k == count) + pop_cont = false; + } + else + pop_cont = false; + } +} + +void +Pulse_Lib_RangeVec_range_vec_add( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t offset, + size_t len +) +{ + size_t sz = Pulse_Lib_Vector_size(rv, (void *)0U, (void *)0U); + size_t off_plus_len = offset + len; + size_t idx = (size_t)0U; + bool search = true; + while (search) + { + size_t iv = idx; + if (iv < sz) + { + Pulse_Lib_RangeVec_range r = Pulse_Lib_Vector_at(rv, iv, (void *)0U, (void *)0U); + size_t high = r.start + r.len; + if (high < offset) + idx = iv + (size_t)1U; + else + search = false; + } + else + search = false; + } + size_t iv = idx; + if (sz == (size_t)0U || iv == sz) + { + Pulse_Lib_RangeVec_range r = { .start = offset, .len = len }; + Pulse_Lib_RangeVec_vec_insert_at(rv, iv, r); + } + else + { + Pulse_Lib_RangeVec_range first_r = Pulse_Lib_Vector_at(rv, iv, (void *)0U, (void *)0U); + if (off_plus_len < first_r.start) + Pulse_Lib_RangeVec_vec_insert_at(rv, + iv, + ((Pulse_Lib_RangeVec_range){ .start = offset, .len = len })); + else + { + size_t first_high = first_r.start + first_r.len; + size_t ite; + if (off_plus_len > first_high) + ite = off_plus_len; + else + ite = first_high; + size_t merged_high = ite; + size_t j = iv + (size_t)1U; + bool merge_cont = true; + while (merge_cont) + { + size_t jv = j; + if (jv < sz) + { + Pulse_Lib_RangeVec_range r = Pulse_Lib_Vector_at(rv, jv, (void *)0U, (void *)0U); + size_t mh = merged_high; + if (mh >= r.start) + { + size_t r_high = r.start + r.len; + if (r_high > mh) + merged_high = r_high; + j = jv + (size_t)1U; + } + else + merge_cont = false; + } + else + merge_cont = false; + } + } + } +} + diff --git a/extraction/_c_output/Pulse_Lib_RangeVec.h b/extraction/_c_output/Pulse_Lib_RangeVec.h new file mode 100644 index 000000000..e8d7ff4b3 --- /dev/null +++ b/extraction/_c_output/Pulse_Lib_RangeVec.h @@ -0,0 +1,76 @@ +/* + This file was generated by KaRaMeL + KaRaMeL invocation: /home/eioannidis/karamel/krml -skip-compilation -skip-makefiles -skip-linking -warn-error -15-4-2 -tmpdir _c_output -library Pulse.Lib.Vector _krml_output/Pulse_Lib_RangeVec.krml _krml_output/Pulse_Lib_Vector.krml + F* version: + KaRaMeL version: a4caa585 + */ + +#ifndef Pulse_Lib_RangeVec_H +#define Pulse_Lib_RangeVec_H + +#include "krmllib.h" + +typedef struct Pulse_Lib_RangeVec_range_s +{ + size_t start; + size_t len; +} +Pulse_Lib_RangeVec_range; + +extern Pulse_Lib_RangeVec_range Pulse_Lib_RangeVec_default_range; + +typedef struct Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range_s +{ + Pulse_Lib_RangeVec_range *arr; + size_t sz; + size_t cap; + Pulse_Lib_RangeVec_range default_val; +} +Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range; + +typedef Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range +*Pulse_Lib_RangeVec_range_vec_t; + +Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range +*Pulse_Lib_RangeVec_range_vec_create(void); + +void +Pulse_Lib_RangeVec_range_vec_free( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv +); + +size_t +Pulse_Lib_RangeVec_range_vec_contiguous_from( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t base +); + +size_t +Pulse_Lib_RangeVec_range_vec_max_endpoint( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv +); + +void +Pulse_Lib_RangeVec_vec_insert_at( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t i, + Pulse_Lib_RangeVec_range r +); + +void +Pulse_Lib_RangeVec_vec_remove_range( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t i, + size_t count +); + +void +Pulse_Lib_RangeVec_range_vec_add( + Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range *rv, + size_t offset, + size_t len +); + + +#define Pulse_Lib_RangeVec_H_DEFINED +#endif /* Pulse_Lib_RangeVec_H */ diff --git a/extraction/_c_output/Pulse_Lib_Vector.h b/extraction/_c_output/Pulse_Lib_Vector.h new file mode 100644 index 000000000..849b59844 --- /dev/null +++ b/extraction/_c_output/Pulse_Lib_Vector.h @@ -0,0 +1,16 @@ +#ifndef Pulse_Lib_Vector_H +#define Pulse_Lib_Vector_H +#include "Pulse_Lib_RangeVec.h" + +typedef Pulse_Lib_RangeVec_range range_t_; +typedef Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range vec_t_; + +extern vec_t_ *Pulse_Lib_Vector_create(range_t_ def, size_t n); +extern void Pulse_Lib_Vector_free(vec_t_ *v, void *_s, void *_cap); +extern range_t_ Pulse_Lib_Vector_at(vec_t_ *v, size_t i, void *_s, void *_cap); +extern void Pulse_Lib_Vector_set(vec_t_ *v, size_t i, range_t_ x, void *_s, void *_cap); +extern size_t Pulse_Lib_Vector_size(vec_t_ *v, void *_s, void *_cap); +extern void Pulse_Lib_Vector_push_back(vec_t_ *v, range_t_ x, void *_s, void *_cap); +extern range_t_ Pulse_Lib_Vector_pop_back(vec_t_ *v, void *_s, void *_cap); + +#endif diff --git a/extraction/bench_rangevec.c b/extraction/bench_rangevec.c new file mode 100644 index 000000000..57a43b266 --- /dev/null +++ b/extraction/bench_rangevec.c @@ -0,0 +1,159 @@ +/* + * bench_rangevec.c — Benchmark for extracted Pulse RangeVec (vector-based range tracker) + * + * Compile: + * gcc -O2 -I ~/karamel/include -I ~/karamel/krmllib/dist/minimal \ + * bench_rangevec.c _c_output/Pulse_Lib_RangeVec.c -o bench_rangevec + * + * Run: + * ./bench_rangevec [iterations] + */ + +#include +#include +#include +#include +#include + +#include "_c_output/Pulse_Lib_RangeVec.h" + +/* ─── Timing ─────────────────────────────────────────────────── */ + +static inline uint64_t now_ns(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (uint64_t)ts.tv_sec * 1000000000ULL + (uint64_t)ts.tv_nsec; +} + +/* ─── PRNG (xorshift64) ─────────────────────────────────────── */ + +static uint64_t rng_state = 0x123456789ABCDEF0ULL; +static inline uint64_t xorshift64(void) { + uint64_t x = rng_state; + x ^= x << 13; x ^= x >> 7; x ^= x << 17; + return (rng_state = x); +} + +static void shuffle(uint32_t* arr, uint32_t n) { + for (uint32_t i = n - 1; i > 0; i--) { + uint32_t j = (uint32_t)(xorshift64() % (i + 1)); + uint32_t tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp; + } +} + +/* ─── Scenarios ──────────────────────────────────────────────── */ + +static void bench_sequential_add(uint32_t iters, uint32_t n_ranges, uint32_t chunk) { + uint64_t t0 = now_ns(); + for (uint32_t it = 0; it < iters; it++) { + Pulse_Lib_RangeVec_range_vec_t rv = Pulse_Lib_RangeVec_range_vec_create(); + for (uint32_t i = 0; i < n_ranges; i++) { + Pulse_Lib_RangeVec_range_vec_add(rv, (size_t)i * chunk, (size_t)chunk); + } + size_t cf = Pulse_Lib_RangeVec_range_vec_contiguous_from(rv, 0); + (void)cf; + Pulse_Lib_RangeVec_range_vec_free(rv); + } + uint64_t t1 = now_ns(); + double ms = (double)(t1 - t0) / 1e6; + uint64_t total_ops = (uint64_t)iters * n_ranges; + double ops_s = (double)total_ops / ((double)(t1 - t0) / 1e9); + printf(" Sequential add (%u ranges x %uB): %8.2f ms %10.0f add/s\n", + n_ranges, chunk, ms, ops_s); +} + +static void bench_ooo_add(uint32_t iters, uint32_t n_ranges, uint32_t chunk) { + uint32_t* order = malloc(n_ranges * sizeof(uint32_t)); + for (uint32_t i = 0; i < n_ranges; i++) order[i] = i; + + uint64_t t0 = now_ns(); + for (uint32_t it = 0; it < iters; it++) { + shuffle(order, n_ranges); + Pulse_Lib_RangeVec_range_vec_t rv = Pulse_Lib_RangeVec_range_vec_create(); + for (uint32_t i = 0; i < n_ranges; i++) { + Pulse_Lib_RangeVec_range_vec_add(rv, (size_t)order[i] * chunk, (size_t)chunk); + } + size_t cf = Pulse_Lib_RangeVec_range_vec_contiguous_from(rv, 0); + (void)cf; + Pulse_Lib_RangeVec_range_vec_free(rv); + } + uint64_t t1 = now_ns(); + double ms = (double)(t1 - t0) / 1e6; + uint64_t total_ops = (uint64_t)iters * n_ranges; + double ops_s = (double)total_ops / ((double)(t1 - t0) / 1e9); + printf(" OOO add (%u ranges x %uB): %8.2f ms %10.0f add/s\n", + n_ranges, chunk, ms, ops_s); + free(order); +} + +static void bench_gap_stress(uint32_t iters, uint32_t n_ranges, uint32_t chunk) { + uint64_t t0 = now_ns(); + for (uint32_t it = 0; it < iters; it++) { + Pulse_Lib_RangeVec_range_vec_t rv = Pulse_Lib_RangeVec_range_vec_create(); + /* Add every other range (max gaps) */ + for (uint32_t i = 0; i < n_ranges; i += 2) { + Pulse_Lib_RangeVec_range_vec_add(rv, (size_t)i * chunk, (size_t)chunk); + } + /* Fill gaps */ + for (uint32_t i = 1; i < n_ranges; i += 2) { + Pulse_Lib_RangeVec_range_vec_add(rv, (size_t)i * chunk, (size_t)chunk); + } + size_t cf = Pulse_Lib_RangeVec_range_vec_contiguous_from(rv, 0); + (void)cf; + Pulse_Lib_RangeVec_range_vec_free(rv); + } + uint64_t t1 = now_ns(); + double ms = (double)(t1 - t0) / 1e6; + uint64_t total_ops = (uint64_t)iters * n_ranges; + double ops_s = (double)total_ops / ((double)(t1 - t0) / 1e9); + printf(" Gap-fill add (%u ranges x %uB): %8.2f ms %10.0f add/s\n", + n_ranges, chunk, ms, ops_s); +} + +static void bench_queries(uint32_t iters, uint32_t n_ranges, uint32_t chunk) { + /* Build once, then query many times */ + Pulse_Lib_RangeVec_range_vec_t rv = Pulse_Lib_RangeVec_range_vec_create(); + for (uint32_t i = 0; i < n_ranges; i++) { + Pulse_Lib_RangeVec_range_vec_add(rv, (size_t)i * chunk, (size_t)chunk); + } + + uint64_t t0 = now_ns(); + for (uint32_t it = 0; it < iters; it++) { + size_t cf = Pulse_Lib_RangeVec_range_vec_contiguous_from(rv, 0); + size_t me = Pulse_Lib_RangeVec_range_vec_max_endpoint(rv); + (void)cf; (void)me; + } + uint64_t t1 = now_ns(); + double ms = (double)(t1 - t0) / 1e6; + double ops_s = (double)iters * 2 / ((double)(t1 - t0) / 1e9); + printf(" Queries (cf+maxep, %u ranges): %8.2f ms %10.0f query/s\n", + n_ranges, ms, ops_s); + Pulse_Lib_RangeVec_range_vec_free(rv); +} + +/* ─── Main ───────────────────────────────────────────────────── */ + +int main(int argc, char* argv[]) { + uint32_t iters = 1000; + if (argc > 1) { + iters = (uint32_t)atoi(argv[1]); + if (iters == 0) iters = 1000; + } + + printf("═══════════════════════════════════════════════════════════\n"); + printf(" RangeVec (Vector-based) Benchmark\n"); + printf(" Iterations: %u\n", iters); + printf("═══════════════════════════════════════════════════════════\n\n"); + + bench_sequential_add(iters, 256, 16); + bench_sequential_add(iters, 64, 256); + bench_ooo_add(iters, 256, 16); + bench_ooo_add(iters, 64, 256); + bench_gap_stress(iters, 256, 16); + bench_gap_stress(iters, 64, 256); + bench_queries(iters * 100, 256, 16); + bench_queries(iters * 100, 64, 256); + + printf("\n═══════════════════════════════════════════════════════════\n"); + return 0; +} diff --git a/extraction/krml-bug.md b/extraction/krml-bug.md new file mode 100644 index 000000000..b2566f0b3 --- /dev/null +++ b/extraction/krml-bug.md @@ -0,0 +1,118 @@ +# KaRaMeL Bug: Erased type parameter monomorphized as `any` vs `()` across `.fsti` boundary + +## Minimal reproducer (3 files, 45 lines total) + +### `MyLib.fsti` + +```fstar +module MyLib +#lang-pulse +open Pulse.Lib.Pervasives + +noeq type node (k v: Type0) = { + key: k; + value: v; + left: option (node k v); + right: option (node k v); +} + +type tree (k v: Type0) = option (node k v) + +val is_tree (#k #v: Type0) (ct: tree k v) (ft: Ghost.erased (tree k v)) : slprop + +fn create (k v: Type0) + requires emp + returns x: tree k v + ensures is_tree x (Ghost.hide (None #(node k v))) +``` + +### `MyLib.fst` + +```fstar +module MyLib +#lang-pulse +open Pulse.Lib.Pervasives + +let is_tree #k #v ct ft = pure (ct == Ghost.reveal ft) + +fn create (k v: Type0) + requires emp + returns x: tree k v + ensures is_tree x (Ghost.hide (None #(node k v))) +{ + let r : tree k v = None #(node k v); + fold (is_tree r (Ghost.hide (None #(node k v)))); + r +} +``` + +### `MyCaller.fst` + +```fstar +module MyCaller +#lang-pulse +open Pulse.Lib.Pervasives + +fn test (_u: unit) + requires emp + returns r: MyLib.tree FStar.SizeT.t unit + ensures MyLib.is_tree r (Ghost.hide (None #(MyLib.node FStar.SizeT.t unit))) +{ + MyLib.create FStar.SizeT.t unit +} +``` + +## Steps to reproduce + +```bash +PULSE_HOME=~/pulse # adjust as needed +FO="--include $PULSE_HOME/out/lib/pulse --include $PULSE_HOME/out/lib/pulse/lib \ + --include $PULSE_HOME/build/lib.pulse.checked --include $PULSE_HOME/build/lib.common.checked \ + --include $PULSE_HOME/lib/pulse/lib --include $PULSE_HOME/lib/common \ + --cache_checked_modules --cache_dir _cache --warn_error -321 \ + --ext optimize_let_vc --load_cmxs pulse" +mkdir -p _cache _output + +# 1. Verify +fstar.exe $FO MyLib.fsti && fstar.exe $FO MyLib.fst && fstar.exe $FO MyCaller.fst + +# 2. Extract to .krml +fstar.exe $FO --codegen krml --odir _output --extract_module MyLib MyLib.fst +fstar.exe $FO --codegen krml --odir _output --extract_module MyCaller MyCaller.fst + +# 3. Run KaRaMeL +krml -skip-compilation -skip-makefiles -skip-linking \ + -warn-error -2-4-9-15-17-18-26 \ + $(find $KRML_HOME/krmllib/.extract -name '*.krml') _output/*.krml +``` + +## Observed + +``` +Cannot re-check MyCaller.test as valid Low* and will not extract it. +``` + +With `-d checker`: + +``` +option__MyLib_node__size_t_any <=? option__MyLib_node__size_t_() +``` + +## Root cause + +Pulse extraction replaces the erased type parameter `v` with `any` in `MyLib.fst`'s `.krml` output, but the caller `MyCaller.fst` instantiates `v = unit` which becomes `()`. After monomorphization, two incompatible C struct types are created: + +| Variant | C fields | Source | +|---------|----------|--------| +| `node__size_t_()` | `{ key, left, right }` | Caller: `v = unit` → erased to `()` (3 fields) | +| `node__size_t_any` | `{ key, value, left, right }` | Impl: `v = any` → kept as `void*` (4 fields) | + +`create` returns `node__size_t_any*` but the caller expects `node__size_t_()*`. These have incompatible memory layouts. + +**Note:** This does NOT happen with plain F* (non-Pulse) modules. Plain F* extraction keeps `v` as a proper type variable in the `.krml`, so monomorphization correctly unifies it. The issue is specific to how Pulse `fn` elaboration handles Type-kinded parameters during extraction. + +## Expected + +The erased `v` should either: +1. Be kept as a type variable in the `.krml` (like plain F* does), so monomorphization unifies it, or +2. `any` and `()` should be treated as equivalent during monomorphization for erased parameters diff --git a/extraction/krmlinit_rv.c b/extraction/krmlinit_rv.c new file mode 100644 index 000000000..28388ff46 --- /dev/null +++ b/extraction/krmlinit_rv.c @@ -0,0 +1,54 @@ +#include +#include +#include +#include +#include "Pulse_Lib_RangeVec.h" + +/* ---- Vector operations for Pulse_Lib_RangeVec_range ---- */ +/* Ghost parameters (erased #s, #cap) are passed as void* null pointers */ + +typedef Pulse_Lib_RangeVec_range range_t; +typedef Pulse_Lib_Vector_vector_internal__Pulse_Lib_RangeVec_range vec_t; + +vec_t *Pulse_Lib_Vector_create(range_t def, size_t n) { + vec_t *v = (vec_t *)malloc(sizeof(vec_t)); + v->arr = (range_t *)malloc(n * sizeof(range_t)); + for (size_t i = 0; i < n; i++) v->arr[i] = def; + v->sz = n; + v->cap = n; + v->default_val = def; + return v; +} + +void Pulse_Lib_Vector_free(vec_t *v, void *_s, void *_cap) { + free(v->arr); + free(v); +} + +range_t Pulse_Lib_Vector_at(vec_t *v, size_t i, void *_s, void *_cap) { + return v->arr[i]; +} + +void Pulse_Lib_Vector_set(vec_t *v, size_t i, range_t x, void *_s, void *_cap) { + v->arr[i] = x; +} + +size_t Pulse_Lib_Vector_size(vec_t *v, void *_s, void *_cap) { + return v->sz; +} + +void Pulse_Lib_Vector_push_back(vec_t *v, range_t x, void *_s, void *_cap) { + if (v->sz >= v->cap) { + size_t new_cap = v->cap == 0 ? 1 : v->cap * 2; + range_t *new_arr = (range_t *)malloc(new_cap * sizeof(range_t)); + if (v->sz > 0) memcpy(new_arr, v->arr, v->sz * sizeof(range_t)); + free(v->arr); + v->arr = new_arr; + v->cap = new_cap; + } + v->arr[v->sz++] = x; +} + +range_t Pulse_Lib_Vector_pop_back(vec_t *v, void *_s, void *_cap) { + return v->arr[--v->sz]; +} diff --git a/extraction/repro_krml_bundle_bug.sh b/extraction/repro_krml_bundle_bug.sh new file mode 100755 index 000000000..4525221c7 --- /dev/null +++ b/extraction/repro_krml_bundle_bug.sh @@ -0,0 +1,102 @@ +#!/bin/bash +# Reproducer for KaRaMeL issue: Pulse fn functions fail Low* re-check +# due to monomorphization type mismatch between erased type variants. +# +# Bug summary: +# When a Pulse module (AVLTree, with .fsti) defines a polymorphic data type +# `tree_t (k:Type) (v:Type)` where `v` is erased, krml's monomorphizer creates +# two incompatible type instantiations: +# - node (from the caller's erased argument) +# - node (from the callee's internal representation) +# +# The checker then rejects the function because subtype check fails: +# option__node__key_any* <=? option__node__key_()* +# These are different C struct layouts (one has a `void* value` field, the +# other doesn't). +# +# Impact: 5 Pulse `fn` functions in RangeMap.fst are silently dropped from +# the C output. Users must provide manual C implementations. +# +# Expected: krml should unify `()` and `any` as the same erased type during +# monomorphization, producing a single struct type. +# +# Environment: +# - fstar.exe version: (see output) +# - krml version: (see output) +# - Pulse: FStarLang/pulse branch lef/circular_buffer + +set -e + +PULSE_HOME="${PULSE_HOME:-$HOME/pulse}" +KRML="${KRML:-$HOME/karamel/krml}" +KRMLLIB="${KRMLLIB:-$HOME/karamel/krmllib}" + +echo "fstar.exe version: $(fstar.exe --version 2>&1 | head -1)" +echo "krml version: $($KRML -version 2>&1 | head -1)" +echo "" + +FSTAR_OPTS=" + --include $PULSE_HOME/out/lib/pulse + --include $PULSE_HOME/out/lib/pulse/lib + --include $PULSE_HOME/build/lib.pulse.checked + --include $PULSE_HOME/build/lib.common.checked + --include $PULSE_HOME/lib/pulse/lib + --include $PULSE_HOME/lib/common + --cache_checked_modules + --cache_dir /tmp/krml_repro_cache + --warn_error -321 + --ext optimize_let_vc + --load_cmxs pulse +" + +MODULES="Pulse.Lib.RangeMap Pulse.Lib.RangeMap.Spec Pulse.Lib.AVLTree Pulse.Lib.Spec.AVLTree" + +rm -rf /tmp/krml_repro_cache /tmp/krml_repro_output /tmp/krml_repro_c +mkdir -p /tmp/krml_repro_cache /tmp/krml_repro_output /tmp/krml_repro_c + +echo "=== Step 1: Verify & extract Pulse modules to .krml ===" +for mod in $MODULES; do + src="$PULSE_HOME/lib/pulse/lib/${mod}.fst" + echo " Verifying $mod..." + fstar.exe $FSTAR_OPTS "$src" 2>&1 | tail -1 +done + +for mod in $MODULES; do + src="$PULSE_HOME/lib/pulse/lib/${mod}.fst" + echo " Extracting $mod to .krml..." + fstar.exe $FSTAR_OPTS --already_cached ',*' \ + --codegen krml --extract_module "$mod" \ + --odir /tmp/krml_repro_output "$src" 2>&1 | tail -1 +done + +echo "" +echo "=== Step 2: Run krml (expect 5 functions to be dropped) ===" +$KRML -skip-compilation -skip-makefiles -skip-linking \ + -warn-error -2 -warn-error -15 -warn-error -4 -warn-error -26 \ + -warn-error -18 -warn-error -9 -warn-error -17 \ + -tmpdir /tmp/krml_repro_c \ + "$KRMLLIB/.extract/"*.krml \ + /tmp/krml_repro_output/*.krml 2>&1 | grep "Cannot re-check" + +echo "" +echo "=== Step 3: Type mismatch detail ===" +$KRML -skip-compilation -skip-makefiles -skip-linking \ + -warn-error -2 -warn-error -15 -warn-error -4 -warn-error -26 \ + -warn-error -18 -warn-error -9 -warn-error -17 \ + -d checker \ + -tmpdir /tmp/krml_repro_c \ + "$KRMLLIB/.extract/"*.krml \ + /tmp/krml_repro_output/*.krml 2>&1 | grep -B1 "Cannot re-check.*range_map_create" + +echo "" +echo "=== Root Cause ===" +echo "The AVLTree module (has .fsti) defines:" +echo " type tree_t (k:Type) (v:Type) = option (node k v)" +echo "" +echo "After Pulse elaboration + monomorphization, krml creates TWO structs:" +echo " node__range_() = { key: range; left: option*; right: option* }" +echo " node__range_any = { key: range; value: void*; left: option*; right: option* }" +echo "" +echo "RangeMap uses node but AVLTree internally uses node." +echo "The checker fails because these have different C memory layouts." +echo "krml should unify () and any as the same erased type." diff --git a/lib/pulse/lib/Pulse.Lib.AVLTree.fst b/lib/pulse/lib/Pulse.Lib.AVLTree.fst index 91933dce6..b016e0690 100644 --- a/lib/pulse/lib/Pulse.Lib.AVLTree.fst +++ b/lib/pulse/lib/Pulse.Lib.AVLTree.fst @@ -27,54 +27,55 @@ module Box = Pulse.Lib.Box open Pulse.Lib.Box { box, (:=), (!) } noeq -type node (t:Type0) = { - data : t; - left : tree_t t; - right : tree_t t; +type node (k:Type0) (v:Type0) = { + key : k; + value : v; + left : tree_t k v; + right : tree_t k v; } -and node_ptr (t:Type0) = box (node t) +and node_ptr (k:Type0) (v:Type0) = box (node k v) //A nullable pointer to a node -and tree_t (t:Type0) = option (node_ptr t) +and tree_t (k:Type0) (v:Type0) = option (node_ptr k v) -let rec is_tree #t ct ft = match ft with +let rec is_tree #k #v ct ft = match ft with | T.Leaf -> pure (ct == None) - | T.Node data left' right' -> - exists* (p:node_ptr t) (lct:tree_t t) (rct:tree_t t). + | T.Node nd_key nd_val left' right' -> + exists* (p:node_ptr k v) (lct:tree_t k v) (rct:tree_t k v). pure (ct == Some p) ** - (p |-> { data = data ; left = lct ; right = rct}) ** + (p |-> { key = nd_key ; value = nd_val ; left = lct ; right = rct}) ** is_tree lct left' ** is_tree rct right' ghost -fn elim_is_tree_leaf (#t:Type0) (x:tree_t t) +fn elim_is_tree_leaf (#k:Type0) (#v:Type0) (x:tree_t k v) requires is_tree x T.Leaf ensures pure (x == None) { - unfold (is_tree x T.Leaf) + unfold (is_tree x T.Leaf) } ghost -fn intro_is_tree_leaf (#t:Type0) (x:tree_t t) +fn intro_is_tree_leaf (#k:Type0) (#v:Type0) (x:tree_t k v) requires pure (x == None) ensures is_tree x T.Leaf { - fold (is_tree x T.Leaf); + fold (is_tree x T.Leaf); } ghost -fn elim_is_tree_node (#t:Type0) (ct:tree_t t) (data:t) (ltree:T.tree t) (rtree:T.tree t) - requires is_tree ct (T.Node data ltree rtree) +fn elim_is_tree_node (#k:Type0) (#v:Type0) (ct:tree_t k v) (nd_key:k) (nd_val:v) (ltree:T.tree k v) (rtree:T.tree k v) + requires is_tree ct (T.Node nd_key nd_val ltree rtree) ensures ( - exists* (p:node_ptr t) (lct:tree_t t) (rct:tree_t t). + exists* (p:node_ptr k v) (lct:tree_t k v) (rct:tree_t k v). pure (ct == Some p) ** - (p |-> { data; left = lct;right = rct }) ** + (p |-> { key = nd_key; value = nd_val; left = lct; right = rct }) ** is_tree lct ltree ** is_tree rct rtree ) @@ -87,32 +88,32 @@ module G = FStar.Ghost ghost -fn intro_is_tree_node (#t:Type0) (ct:tree_t t) (v:node_ptr t) (#node:node t) (#ltree:T.tree t) (#rtree:T.tree t) +fn intro_is_tree_node (#k:Type0) (#v:Type0) (ct:tree_t k v) (p:node_ptr k v) (#nd:node k v) (#ltree:T.tree k v) (#rtree:T.tree k v) requires - (v |-> node) ** - is_tree node.left ltree ** - is_tree node.right rtree ** - pure (ct == Some v) + (p |-> nd) ** + is_tree nd.left ltree ** + is_tree nd.right rtree ** + pure (ct == Some p) ensures - is_tree ct (T.Node node.data ltree rtree) + is_tree ct (T.Node nd.key nd.value ltree rtree) { - fold (is_tree ct (T.Node node.data ltree rtree)) + fold (is_tree ct (T.Node nd.key nd.value ltree rtree)) } [@@no_mkeys] // internal only -let is_tree_cases #t (x : option (box (node t))) (ft : T.tree t) +let is_tree_cases #k #v (x : option (box (node k v))) (ft : T.tree k v) = match x with | None -> pure (ft == T.Leaf) - | Some v -> - exists* (n:node t) (ltree:T.tree t) (rtree:T.tree t). - (v |-> n) ** - pure (ft == T.Node n.data ltree rtree) ** + | Some p -> + exists* (n:node k v) (ltree:T.tree k v) (rtree:T.tree k v). + (p |-> n) ** + pure (ft == T.Node n.key n.value ltree rtree) ** is_tree n.left ltree ** is_tree n.right rtree ghost -fn cases_of_is_tree #t (x:tree_t t) (ft:T.tree t) +fn cases_of_is_tree #k #v (x:tree_t k v) (ft:T.tree k v) requires is_tree x ft ensures is_tree_cases x ft { @@ -122,8 +123,8 @@ fn cases_of_is_tree #t (x:tree_t t) (ft:T.tree t) fold (is_tree_cases None ft); rewrite is_tree_cases None ft as is_tree_cases x ft; } - T.Node data ltree rtree -> { - unfold (is_tree x (T.Node data ltree rtree)); + T.Node nd_key nd_val ltree rtree -> { + unfold (is_tree x (T.Node nd_key nd_val ltree rtree)); with p lct rct. _; with n. assert p |-> n; with l'. rewrite is_tree lct l' as is_tree n.left l'; @@ -136,9 +137,9 @@ fn cases_of_is_tree #t (x:tree_t t) (ft:T.tree t) - + ghost -fn is_tree_case_none (#t:Type) (x:tree_t t) (#l:T.tree t) +fn is_tree_case_none (#k:Type) (#v:Type) (x:tree_t k v) (#l:T.tree k v) preserves is_tree x l requires pure (x == None) ensures pure (l == T.Leaf) @@ -152,29 +153,29 @@ fn is_tree_case_none (#t:Type) (x:tree_t t) (#l:T.tree t) - + ghost -fn is_tree_case_some (#t:Type) (x:tree_t t) (v:node_ptr t) (#ft:T.tree t) +fn is_tree_case_some (#k:Type) (#v:Type) (x:tree_t k v) (p:node_ptr k v) (#ft:T.tree k v) requires is_tree x ft - requires pure (x == Some v) + requires pure (x == Some p) ensures - exists* (node:node t) (ltree:T.tree t) (rtree:T.tree t). - (v |-> node) ** - is_tree node.left ltree ** - is_tree node.right rtree ** - pure (ft == T.Node node.data ltree rtree) - -{ - rewrite each x as Some v; - cases_of_is_tree (Some v) ft; + exists* (nd:node k v) (ltree:T.tree k v) (rtree:T.tree k v). + (p |-> nd) ** + is_tree nd.left ltree ** + is_tree nd.right rtree ** + pure (ft == T.Node nd.key nd.value ltree rtree) + +{ + rewrite each x as Some p; + cases_of_is_tree (Some p) ft; unfold is_tree_cases; } /////////////////////////////////////////////////////////////////////////////// - -fn rec height (#t:Type0) (x:tree_t t) + +fn rec height (#k:Type0) (#v:Type0) (x:tree_t k v) preserves is_tree x 'l returns n:nat ensures pure (n == T.height 'l) @@ -202,7 +203,7 @@ fn rec height (#t:Type0) (x:tree_t t) -fn is_empty (#t:Type) (x:tree_t t) (#ft:G.erased(T.tree t)) +fn is_empty (#k:Type) (#v:Type) (x:tree_t k v) (#ft:G.erased(T.tree k v)) preserves is_tree x ft returns b:bool ensures pure (b <==> (T.is_empty ft)) @@ -222,54 +223,38 @@ fn is_empty (#t:Type) (x:tree_t t) (#ft:G.erased(T.tree t)) } -let null_tree_t (t:Type0) : tree_t t = None +let null_tree_t (k:Type0) (v:Type0) : tree_t k v = None -fn create (t:Type0) - returns x:tree_t t +fn create (k:Type0) (v:Type0) + returns x:tree_t k v ensures is_tree x T.Leaf { - let tree = null_tree_t t; + let tree = null_tree_t k v; intro_is_tree_leaf tree; tree } -fn node_cons (#t:Type0) (v:t) (ltree:tree_t t) (rtree:tree_t t) (#l:(T.tree t)) (#r:(T.tree t)) +fn node_cons (#k:Type0) (#v:Type0) (nd_key:k) (nd_val:v) (ltree:tree_t k v) (rtree:tree_t k v) (#l:(T.tree k v)) (#r:(T.tree k v)) requires is_tree ltree l ** - is_tree rtree r //functional equivalent of x is 'l; x is the tail of the constructed tree. - returns y:tree_t t - ensures is_tree y (T.Node v l r) + is_tree rtree r + returns y:tree_t k v + ensures is_tree y (T.Node nd_key nd_val l r) ensures (pure (Some? y)) { - let y = Box.alloc { data=v; left =ltree; right = rtree }; + let y = Box.alloc { key=nd_key; value=nd_val; left=ltree; right=rtree }; intro_is_tree_node (Some y) y; Some y } -/// Appends value [v] at the leftmost leaf of the tree that [ptr] points to. - -fn rec append_left_none (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) - preserves is_tree x ft - requires pure (None? x) - returns y:tree_t t - ensures is_tree y (T.Node v T.Leaf T.Leaf) -{ - let left = create t; - let right = create t; - let y = node_cons v left right; - y -} - - - -fn rec append_left (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) +fn rec append_left (#k:Type0) (#v:Type0) (x:tree_t k v) (ak:k) (av:v) (#ft:G.erased (T.tree k v)) requires is_tree x ft - returns y:tree_t t - ensures is_tree y (T.append_left ft v) + returns y:tree_t k v + ensures is_tree y (T.append_left ft ak av) { match x { None -> { @@ -279,18 +264,18 @@ fn rec append_left (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) elim_is_tree_leaf None; - let left = create t; - let right = create t; - - - let y = node_cons v left right; - + let left = create k v; + let right = create k v; + + + let y = node_cons ak av left right; + let np = Some?.v y; - + is_tree_case_some y np; intro_is_tree_node y np; - y + y } Some vl -> { @@ -298,7 +283,7 @@ fn rec append_left (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) let node = !vl; - let left_new = append_left node.left v; + let left_new = append_left node.left ak av; vl := {node with left = left_new}; @@ -307,15 +292,15 @@ fn rec append_left (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) x } } -} +} -fn rec append_right (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) +fn rec append_right (#k:Type0) (#v:Type0) (x:tree_t k v) (ak:k) (av:v) (#ft:G.erased (T.tree k v)) requires is_tree x ft - returns y:tree_t t - ensures is_tree y (T.append_right ft v) + returns y:tree_t k v + ensures is_tree y (T.append_right ft ak av) { match x { None -> { @@ -324,62 +309,43 @@ fn rec append_right (#t:Type0) (x:tree_t t) (v:t) (#ft:G.erased (T.tree t)) elim_is_tree_leaf None; - let left = create t; - let right = create t; - - - let y = node_cons v left right; - + let left = create k v; + let right = create k v; + + + let y = node_cons ak av left right; + let np = Some?.v y; - + is_tree_case_some y np; intro_is_tree_node y np; - - y + + y } Some np -> { is_tree_case_some (Some np) np; let node = !np; - let right_new = append_right node.right v; - + let right_new = append_right node.right ak av; + np := {node with right = right_new}; - + intro_is_tree_node x np; - + x } } -} - - - - -fn node_data (#t:Type) (x:tree_t t) (#ft:G.erased (T.tree t)) - preserves is_tree x ft - requires (pure (Some? x)) - returns v:t -{ - let np = Some?.v x; - - is_tree_case_some x np; - - let node = !np; - - let v = node.data; - - intro_is_tree_node x np; - v } -fn rec mem (#t:eqtype) (x:tree_t t) (v: t) (#ft:G.erased (T.tree t)) + +fn rec mem (#k:eqtype) (#v:Type0) (x:tree_t k v) (search_key: k) (#ft:G.erased (T.tree k v)) preserves is_tree x ft returns b:bool - ensures pure (b <==> (T.mem ft v)) + ensures pure (b <==> (T.mem ft search_key)) { match x { None -> { @@ -391,16 +357,16 @@ fn rec mem (#t:eqtype) (x:tree_t t) (v: t) (#ft:G.erased (T.tree t)) is_tree_case_some (Some vl) vl; let n = !vl; - let dat = n.data; + let dat = n.key; - if (dat = v) + if (dat = search_key) { intro_is_tree_node x vl; true } else{ - let b1 = mem n.left v; - let b2 = mem n.right v; + let b1 = mem n.left search_key; + let b2 = mem n.right search_key; let b3 = b1 || b2; intro_is_tree_node x vl; @@ -414,126 +380,127 @@ fn rec mem (#t:eqtype) (x:tree_t t) (v: t) (#ft:G.erased (T.tree t)) -fn get_some_ref (#t:Type) (x:tree_t t) +fn get_some_ref (#k:Type) (#v:Type) (x:tree_t k v) requires is_tree x 'l requires pure (T.Node? 'l) - returns v:node_ptr t -ensures - exists* (node:node t) (ltree:T.tree t) (rtree:T.tree t). - pure (x == Some v) ** - pure ('l == T.Node node.data ltree rtree) ** - (v |-> node) ** - is_tree node.left ltree ** - is_tree node.right rtree + returns p:node_ptr k v +ensures + exists* (nd:node k v) (ltree:T.tree k v) (rtree:T.tree k v). + pure (x == Some p) ** + pure ('l == T.Node nd.key nd.value ltree rtree) ** + (p |-> nd) ** + is_tree nd.left ltree ** + is_tree nd.right rtree { match x { None -> { is_tree_case_none None; unreachable () } - Some v -> { - is_tree_case_some (Some v) v; - v + Some p -> { + is_tree_case_some (Some p) p; + p } } } -[@@pulse_unfold] let _left (t:T.tree 'a{T.Node? t}) = T.Node?.left t -[@@pulse_unfold] let _right (t:T.tree 'a{T.Node? t}) = T.Node?.right t -[@@pulse_unfold] let _data (t:T.tree 'a{T.Node? t}) = T.Node?.data t +[@@pulse_unfold] let _left (#k:Type) (#v:Type) (t:T.tree k v{T.Node? t}) = T.Node?.left t +[@@pulse_unfold] let _right (#k:Type) (#v:Type) (t:T.tree k v{T.Node? t}) = T.Node?.right t +[@@pulse_unfold] let _key (#k:Type) (#v:Type) (t:T.tree k v{T.Node? t}) = T.Node?.key t +[@@pulse_unfold] let _val (#k:Type) (#v:Type) (t:T.tree k v{T.Node? t}) = T.Node?.value t fn read_node - (#a:Type0) - (tree : tree_t a) - (#t : erased (T.tree a){T.Node? t}) + (#k:Type0) (#v:Type0) + (tree : tree_t k v) + (#t : erased (T.tree k v){T.Node? t}) requires is_tree tree t - (* ^ Some? p should be trivial, but just kick the ball to the caller *) - returns res : tree_t a & a & tree_t a & squash (Some? tree) - (* ^ squash to help with spec well-formedness *) + returns res : tree_t k v & k & v & tree_t k v & squash (Some? tree) ensures ( - let (l, x, r, _) = res in - (Some?.v tree |-> {left = l; data = x; right = r}) + let (l, xk, xv, r, _) = res in + (Some?.v tree |-> {left = l; key = xk; value = xv; right = r}) ** is_tree l (_left t) ** is_tree r (_right t) - ** pure (x == _data t) + ** pure (xk == _key t) + ** pure (xv == _val t) ) { let p = get_some_ref tree; - with node. assert p |-> node; + with nd. assert p |-> nd; let n = !p; rewrite p |-> n as Some?.v tree |-> n; - (n.left, n.data, n.right, ()) + (n.left, n.key, n.value, n.right, ()) } fn write_node - (#a:Type0) - (tree : tree_t a{Some? tree}) - (lp : tree_t a) - (data : a) - (rp : tree_t a) - (#lt #rt : erased (T.tree a)) + (#k:Type0) (#v:Type0) + (tree : tree_t k v{Some? tree}) + (lp : tree_t k v) + (nd_key : k) + (nd_val : v) + (rp : tree_t k v) + (#lt #rt : erased (T.tree k v)) requires (Some?.v tree |-> 'n0) ** is_tree lp lt ** is_tree rp rt ensures - is_tree tree (T.Node data lt rt) + is_tree tree (T.Node nd_key nd_val lt rt) { - let n = { data = data; left = lp; right = rp }; + let n = { key = nd_key; value = nd_val; left = lp; right = rp }; let Some p = tree; p := n; - fold (is_tree tree (T.Node data lt rt)) + fold (is_tree tree (T.Node nd_key nd_val lt rt)) } -fn rotate_left (#t:Type0) (tree:tree_t t) (#l: G.erased (T.tree t){ Some? (T.rotate_left l) }) +fn rotate_left (#k:Type0) (#v:Type0) (tree:tree_t k v) (#l: G.erased (T.tree k v){ Some? (T.rotate_left l) }) requires is_tree tree l - returns y : tree_t t + returns y : tree_t k v ensures is_tree y (Some?.v (T.rotate_left l)) { - let a, b, p', _ = read_node tree; - let c, d, e, _ = read_node p'; - write_node p' a b c; - write_node tree p' d e; - tree (* Note: in-place mutation, we could make this return unit instead. *) + let a, bk, bv, p', _ = read_node tree; + let c, dk, dv, e, _ = read_node p'; + write_node p' a bk bv c; + write_node tree p' dk dv e; + tree } -fn rotate_right (#t:Type0) (tree:tree_t t) (#l:G.erased (T.tree t){ Some? (T.rotate_right l) }) +fn rotate_right (#k:Type0) (#v:Type0) (tree:tree_t k v) (#l:G.erased (T.tree k v){ Some? (T.rotate_right l) }) requires is_tree tree l - returns y:tree_t t + returns y:tree_t k v ensures (is_tree y (Some?.v (T.rotate_right l))) { - let p', d, e, _ = read_node tree; - let a, b, c, _ = read_node p'; - write_node p' c d e; - write_node tree a b p'; + let p', dk, dv, e, _ = read_node tree; + let a, bk, bv, c, _ = read_node p'; + write_node p' c dk dv e; + write_node tree a bk bv p'; tree } -fn rotate_right_left (#t:Type0) (tree:tree_t t) (#l:G.erased (T.tree t){ Some? (T.rotate_right_left l) }) +fn rotate_right_left (#k:Type0) (#v:Type0) (tree:tree_t k v) (#l:G.erased (T.tree k v){ Some? (T.rotate_right_left l) }) requires is_tree tree l - returns y : tree_t t + returns y : tree_t k v ensures is_tree y (Some?.v (T.rotate_right_left l)) { - let a, x, zp, _ = read_node tree; - let yp, z, d, _ = read_node zp; - let b, y, c, _ = read_node yp; - write_node zp c z d; - write_node yp a x b; - write_node tree yp y zp; + let a, xk, xv, zp, _ = read_node tree; + let yp, zk, zv, d, _ = read_node zp; + let b, yk, yv, c, _ = read_node yp; + write_node zp c zk zv d; + write_node yp a xk xv b; + write_node tree yp yk yv zp; tree } -fn rotate_left_right (#t:Type0) (tree:tree_t t) (#l:G.erased (T.tree t){ Some? (T.rotate_left_right l) }) +fn rotate_left_right (#k:Type0) (#v:Type0) (tree:tree_t k v) (#l:G.erased (T.tree k v){ Some? (T.rotate_left_right l) }) requires is_tree tree l - returns y :tree_t t + returns y :tree_t k v ensures is_tree y (Some?.v (T.rotate_left_right l)) { - let zp, x, d, _ = read_node tree; - let a, z, yp, _ = read_node zp; - let b, y, c, _ = read_node yp; - write_node zp a z b; - write_node yp c x d; - write_node tree zp y yp; + let zp, xk, xv, d, _ = read_node tree; + let a, zk, zv, yp, _ = read_node zp; + let b, yk, yv, c, _ = read_node yp; + write_node zp a zk zv b; + write_node yp c xk xv d; + write_node tree zp yk yv yp; tree } @@ -541,7 +508,7 @@ fn rotate_left_right (#t:Type0) (tree:tree_t t) (#l:G.erased (T.tree t){ Some? ( module M = FStar.Math.Lib -fn rec is_balanced (#t:Type0) (tree:tree_t t) +fn rec is_balanced (#k:Type0) (#v:Type0) (tree:tree_t k v) preserves is_tree tree 'l returns b:bool ensures pure (b <==> (T.is_balanced 'l)) @@ -566,9 +533,9 @@ fn rec is_balanced (#t:Type0) (tree:tree_t t) let b3 = is_balanced n.left; let b4 = b1 && b2 && b3; - + intro_is_tree_node tree vl; - + b4 } } @@ -577,9 +544,9 @@ fn rec is_balanced (#t:Type0) (tree:tree_t t) -fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) +fn rec rebalance_avl (#k:Type0) (#v:Type0) (tree:tree_t k v) (#l:G.erased(T.tree k v)) requires is_tree tree l - returns y:tree_t t + returns y:tree_t k v ensures (is_tree y (T.rebalance_avl l)) { let b = is_balanced tree; @@ -592,7 +559,7 @@ fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) Some vl -> { rewrite each (Some vl) as tree; is_tree_case_some tree vl; - + if (b) { intro_is_tree_node tree vl; @@ -603,15 +570,15 @@ fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) let n = !vl; let height_l = height n.left; let height_r = height n.right; - + let diff_height = height_l - height_r ; - if (diff_height > 1) + if (diff_height > 1) { let vll = get_some_ref n.left; intro_is_tree_node n.left vll; is_tree_case_some n.left vll; - + let nl = !vll; let height_ll = height nl.left; @@ -621,15 +588,15 @@ fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) { (*Only in this branch, this situation happens, Node x (Node z t1 (Node y t2 t3)) t4*) let vllr = get_some_ref nl.right; - + (*pack tree back in the order it is unpacked*) intro_is_tree_node nl.right vllr; - + intro_is_tree_node n.left vll; - - + + intro_is_tree_node tree vl; - + let y = rotate_left_right tree; y } @@ -656,7 +623,7 @@ fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) { (*Node x t1 (Node z (Node y t2 t3) t4)*) let vlrl = get_some_ref nr.left; - + (*pack tree back in the order it is unpacked*) intro_is_tree_node nr.left vlrl; intro_is_tree_node n.right vlr; @@ -672,14 +639,14 @@ fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) let y = rotate_left tree; y } - + } else { intro_is_tree_node tree vl; tree } - + } } } @@ -687,10 +654,10 @@ fn rec rebalance_avl (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t)) -fn rec insert_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) +fn rec insert_avl (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (nd_key: k) (nd_val: v) requires is_tree tree 'l - returns y:tree_t t - ensures (is_tree y (T.insert_avl cmp 'l key)) + returns y:tree_t k v + ensures (is_tree y (T.insert_avl cmp 'l nd_key nd_val)) { match tree { None -> { @@ -698,43 +665,43 @@ fn rec insert_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) elim_is_tree_leaf None; - let left = create t; - let right = create t; - - - let y = node_cons key left right; - + let left = create k v; + let right = create k v; + + + let y = node_cons nd_key nd_val left right; + let np = Some?.v y; - + is_tree_case_some y np; intro_is_tree_node y np; - + y } Some vl -> { is_tree_case_some (Some vl) vl; - with node. assert vl |-> node; + with nd. assert vl |-> nd; let n = !vl; - let delta = cmp n.data key; + let delta = cmp n.key nd_key; if (delta >= 0) { - let new_left = insert_avl cmp n.left key; - let vl' = {data = n.data; left = new_left; right = n.right}; + let new_left = insert_avl cmp n.left nd_key nd_val; + let vl' = {key = n.key; value = n.value; left = new_left; right = n.right}; vl := vl'; rewrite each new_left as vl'.left; - rewrite each node.right as vl'.right; + rewrite each nd.right as vl'.right; intro_is_tree_node (Some vl) vl #vl'; let new_tree = rebalance_avl (Some vl); new_tree } else { - let new_right = insert_avl cmp n.right key; - let vl' = {data = n.data; left = n.left; right = new_right }; + let new_right = insert_avl cmp n.right nd_key nd_val; + let vl' = {key = n.key; value = n.value; left = n.left; right = new_right }; vl := vl'; rewrite each new_right as vl'.right; - rewrite each node.left as vl'.left; + rewrite each nd.left as vl'.left; intro_is_tree_node (Some vl) vl #vl'; let new_tree = rebalance_avl (Some vl); new_tree @@ -743,23 +710,23 @@ fn rec insert_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) } } - + ghost -fn is_tree_case_some1 (#t:Type) (x:tree_t t) (v:node_ptr t) (#ft:T.tree t) +fn is_tree_case_some1 (#k:Type) (#v:Type) (x:tree_t k v) (p:node_ptr k v) (#ft:T.tree k v) preserves is_tree x ft - requires pure (x == Some v) + requires pure (x == Some p) ensures pure (T.Node? ft) { - rewrite each x as Some v; - cases_of_is_tree (Some v) ft; + rewrite each x as Some p; + cases_of_is_tree (Some p) ft; unfold is_tree_cases; - intro_is_tree_node (Some v) v; - with 't. rewrite is_tree (Some v) 't as is_tree x 't; + intro_is_tree_node (Some p) p; + with 'ft. rewrite is_tree (Some p) 'ft as is_tree x 'ft; } -fn rec tree_max_c (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t){T.Node? l}) +fn rec tree_max_c (#k:Type0) (#v:Type0) (tree:tree_t k v) (#l:G.erased(T.tree k v){T.Node? l}) preserves is_tree tree l - returns y:t + returns y:(k & v) ensures pure (y == T.tree_max l) { match tree { @@ -772,10 +739,11 @@ fn rec tree_max_c (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t){T.Node? l}) let n = !vl; match n.right { None -> { - let d = n.data; + let dk = n.key; + let dv = n.value; is_tree_case_none n.right; intro_is_tree_node tree vl; - d + (dk, dv) } Some vlr -> { is_tree_case_some1 n.right vlr; @@ -784,15 +752,15 @@ fn rec tree_max_c (#t:Type0) (tree:tree_t t) (#l:G.erased(T.tree t){T.Node? l}) max } } - + } } } -fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) +fn rec delete_avl (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (del_key: k) requires is_tree tree 'l - returns y : tree_t t - ensures is_tree y (T.delete_avl cmp 'l key) + returns y : tree_t k v + ensures is_tree y (T.delete_avl cmp 'l del_key) { match tree { None -> { @@ -802,14 +770,14 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) } Some vl -> { is_tree_case_some (Some vl) vl; - with node. assert vl |-> node; + with nd. assert vl |-> nd; let n = !vl; - let delta = cmp n.data key; + let delta = cmp n.key del_key; if (delta = 0) { let left = n.left; let right = n.right; - rewrite each node.left as left; - rewrite each node.right as right; + rewrite each nd.left as left; + rewrite each nd.right as right; //explicit ltree and rtree is needed, to find a proof for the existence of func ltree and rtree with ltree. assert is_tree left ltree; with rtree. assert is_tree right rtree; @@ -819,18 +787,18 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) match right { None -> { (*Leaf,Leaf*) is_tree_case_none None #rtree; - let tr = create t; + let tr = create k v; Box.free vl; - rewrite each rtree as T.Leaf #t; - rewrite each ltree as T.Leaf #t; - unfold is_tree #t None T.Leaf; - unfold is_tree #t None T.Leaf; + rewrite each rtree as T.Leaf #k #v; + rewrite each ltree as T.Leaf #k #v; + unfold is_tree #k #v None T.Leaf; + unfold is_tree #k #v None T.Leaf; tr } Some vlr -> {(*Leaf,Node_*) is_tree_case_some (Some vlr) vlr; let rnode = !vlr; - let vl' = {data = rnode.data; left = rnode.left; right = rnode.right}; + let vl' = {key = rnode.key; value = rnode.value; left = rnode.left; right = rnode.right}; vl := vl'; with ltree. rewrite is_tree rnode.left ltree as is_tree vl'.left ltree; @@ -838,9 +806,9 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) rewrite is_tree rnode.right rtree as is_tree vl'.right rtree; intro_is_tree_node (Some vl) vl #vl'; with ltree. - assert (is_tree #t None ltree); + assert (is_tree #k #v None ltree); Box.free vlr; - elim_is_tree_leaf #t None; + elim_is_tree_leaf #k #v None; (Some vl) } } @@ -852,7 +820,7 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) is_tree_case_some (Some vll) vll; is_tree_case_none None; let lnode = !vll; - let vl' = {data = lnode.data; left = lnode.left; right = lnode.right}; + let vl' = {key = lnode.key; value = lnode.value; left = lnode.left; right = lnode.right}; vl := vl'; with ltree. rewrite is_tree lnode.left ltree as is_tree vl'.left ltree; @@ -867,8 +835,8 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) Some vlr -> {(*Node_,Node_*) is_tree_case_some1 (Some vlr) vlr; let m = tree_max_c (Some vll); - let new_left = delete_avl cmp (Some vll) m; - let vl' = {data = m; left = new_left; right = right}; + let new_left = delete_avl cmp (Some vll) (fst m); + let vl' = {key = fst m; value = snd m; left = new_left; right = right}; vl := vl'; with ltree. rewrite is_tree new_left ltree as is_tree vl'.left ltree; @@ -876,7 +844,7 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) rewrite is_tree (Some vlr) rtree as is_tree vl'.right rtree; intro_is_tree_node (Some vl) vl #vl'; let new_tree = rebalance_avl (Some vl); - assert (is_tree new_tree (T.delete_avl cmp 'l key)); + assert (is_tree new_tree (T.delete_avl cmp 'l del_key)); new_tree } } @@ -885,32 +853,189 @@ fn rec delete_avl (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) } else { if (delta < 0) { assert (pure (delta < 0)); - let new_left = delete_avl cmp n.left key; - let vl' = {data = n.data; left = new_left; right = n.right}; + let new_right = delete_avl cmp n.right del_key; + let vl' = {key = n.key; value = n.value; left = n.left; right = new_right}; vl := vl'; with ltree. - rewrite is_tree new_left ltree as is_tree vl'.left ltree; + rewrite is_tree n.left ltree as is_tree vl'.left ltree; with rtree. - rewrite is_tree n.right rtree as is_tree vl'.right rtree; + rewrite is_tree new_right rtree as is_tree vl'.right rtree; intro_is_tree_node (Some vl) vl #vl'; let new_tree = rebalance_avl (Some vl); new_tree } else { - let new_right = delete_avl cmp n.right key; - let vl' = {data = n.data; left = n.left; right = new_right}; + let new_left = delete_avl cmp n.left del_key; + let vl' = {key = n.key; value = n.value; left = new_left; right = n.right}; vl := vl'; with ltree. - rewrite is_tree n.left ltree as is_tree vl'.left ltree; + rewrite is_tree new_left ltree as is_tree vl'.left ltree; with rtree. - rewrite is_tree new_right rtree as is_tree vl'.right rtree; + rewrite is_tree n.right rtree as is_tree vl'.right rtree; intro_is_tree_node (Some vl) vl #vl'; - + let new_tree = rebalance_avl (Some vl); - assert (is_tree new_tree (T.delete_avl cmp 'l key)); - + assert (is_tree new_tree (T.delete_avl cmp 'l del_key)); + new_tree } } } } } + +fn rec find_min (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (#l:G.erased(T.tree k v){T.Node? l}) + preserves is_tree tree l + returns y:(k & v) + ensures pure (y == T.tree_min l) +{ + match tree { + None -> { + is_tree_case_none None; + unreachable () + } + Some vl -> { + is_tree_case_some (Some vl) vl; + let n = !vl; + match n.left { + None -> { + let dk = n.key; + let dv = n.value; + is_tree_case_none n.left; + intro_is_tree_node tree vl; + (dk, dv) + } + Some vll -> { + is_tree_case_some1 n.left vll; + let min = find_min cmp n.left; + intro_is_tree_node tree vl; + min + } + } + } + } +} + +fn rec find_max (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (#l:G.erased(T.tree k v){T.Node? l}) + preserves is_tree tree l + returns y:(k & v) + ensures pure (y == T.tree_max l) +{ + match tree { + None -> { + is_tree_case_none None; + unreachable () + } + Some vl -> { + is_tree_case_some (Some vl) vl; + let n = !vl; + match n.right { + None -> { + let dk = n.key; + let dv = n.value; + is_tree_case_none n.right; + intro_is_tree_node tree vl; + (dk, dv) + } + Some vr -> { + is_tree_case_some1 n.right vr; + let max = find_max cmp n.right; + intro_is_tree_node tree vl; + max + } + } + } + } +} + +fn rec find_le (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (search_key:k) (#ft:G.erased (T.tree k v)) + preserves is_tree tree ft + returns y:option (k & v) + ensures pure (y == T.find_le cmp ft search_key) +{ + match tree { + None -> { + is_tree_case_none None; + rewrite is_tree None ft as is_tree tree ft; + let r : option (k & v) = None; + r + } + Some vl -> { + is_tree_case_some (Some vl) vl; + let n = !vl; + let delta = cmp n.key search_key; + if (delta > 0) { + let r = find_le cmp n.left search_key; + intro_is_tree_node tree vl; + r + } else if (delta = 0) { + intro_is_tree_node tree vl; + let r : option (k & v) = Some (n.key, n.value); + r + } else { + let r = find_le cmp n.right search_key; + intro_is_tree_node tree vl; + match r { + Some _ -> { r } + None -> { let r2 : option (k & v) = Some (n.key, n.value); r2 } + } + } + } + } +} + +fn rec find_ge (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (search_key:k) (#ft:G.erased (T.tree k v)) + preserves is_tree tree ft + returns y:option (k & v) + ensures pure (y == T.find_ge cmp ft search_key) +{ + match tree { + None -> { + is_tree_case_none None; + rewrite is_tree None ft as is_tree tree ft; + let r : option (k & v) = None; + r + } + Some vl -> { + is_tree_case_some (Some vl) vl; + let n = !vl; + let delta = cmp n.key search_key; + if (delta < 0) { + let r = find_ge cmp n.right search_key; + intro_is_tree_node tree vl; + r + } else if (delta = 0) { + intro_is_tree_node tree vl; + let r : option (k & v) = Some (n.key, n.value); + r + } else { + let r = find_ge cmp n.left search_key; + intro_is_tree_node tree vl; + match r { + Some _ -> { r } + None -> { let r2 : option (k & v) = Some (n.key, n.value); r2 } + } + } + } + } +} + +fn rec free (#k:Type0) (#v:Type0) (x:tree_t k v) (#ft:G.erased (T.tree k v)) + requires is_tree x ft + ensures emp +{ + match x { + None -> { + is_tree_case_none None; + rewrite is_tree None ft as is_tree None (T.Leaf #k #v); + elim_is_tree_leaf (None #(node_ptr k v)); + () + } + Some vl -> { + is_tree_case_some (Some vl) vl; + let n = !vl; + free n.left; + free n.right; + Box.free vl + } + } +} diff --git a/lib/pulse/lib/Pulse.Lib.AVLTree.fsti b/lib/pulse/lib/Pulse.Lib.AVLTree.fsti index 5321797aa..e10bfe1e5 100644 --- a/lib/pulse/lib/Pulse.Lib.AVLTree.fsti +++ b/lib/pulse/lib/Pulse.Lib.AVLTree.fsti @@ -25,40 +25,64 @@ open Pulse.Lib.Pervasives module T = Pulse.Lib.Spec.AVLTree module G = FStar.Ghost -val tree_t (a:Type u#0): Type u#0 +val tree_t (k:Type u#0) (v:Type u#0): Type u#0 -val is_tree #t ([@@@mkey] ct:tree_t t) (ft:T.tree t) +val is_tree #k #v ([@@@mkey] ct:tree_t k v) (ft:T.tree k v) : Tot slprop (decreases ft) -fn height (#t:Type0) (x:tree_t t) (#ft:G.erased (T.tree t)) +fn height (#k:Type0) (#v:Type0) (x:tree_t k v) (#ft:G.erased (T.tree k v)) preserves is_tree x ft returns n : nat ensures pure (n == T.height ft) -fn is_empty (#t:Type) (x:tree_t t) (#ft:G.erased(T.tree t)) +fn is_empty (#k:Type) (#v:Type) (x:tree_t k v) (#ft:G.erased(T.tree k v)) preserves is_tree x ft returns b : bool ensures pure (b <==> (T.is_empty ft)) -fn create (t:Type0) - returns x : tree_t t +fn create (k:Type0) (v:Type0) + returns x : tree_t k v ensures is_tree x T.Leaf -fn mem (#t:eqtype) (x:tree_t t) (v: t) (#ft:G.erased (T.tree t)) +fn mem (#k:eqtype) (#v:Type0) (x:tree_t k v) (key: k) (#ft:G.erased (T.tree k v)) preserves is_tree x ft returns b : bool - ensures pure (b <==> (T.mem ft v)) + ensures pure (b <==> (T.mem ft key)) fn insert_avl - (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) - (#l: G.erased(T.tree t)) + (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (key: k) (val_: v) + (#l: G.erased(T.tree k v)) requires is_tree tree l - returns y : tree_t t - ensures is_tree y (T.insert_avl cmp l key) + returns y : tree_t k v + ensures is_tree y (T.insert_avl cmp l key val_) fn delete_avl - (#t:Type0) (cmp: T.cmp t) (tree:tree_t t) (key: t) - (#l: G.erased(T.tree t)) + (#k:Type0) (#v:Type0) (cmp: T.cmp k) (tree:tree_t k v) (key: k) + (#l: G.erased(T.tree k v)) requires is_tree tree l - returns y : tree_t t + returns y : tree_t k v ensures is_tree y (T.delete_avl cmp l key) + +fn find_min (#k:Type0) (#v:Type0) (cmp: T.cmp k) (x:tree_t k v) (#ft:G.erased (T.tree k v){T.Node? ft}) + requires is_tree x ft + returns y:(k & v) + ensures is_tree x ft ** pure (y == T.tree_min ft) + +fn find_max (#k:Type0) (#v:Type0) (cmp: T.cmp k) (x:tree_t k v) (#ft:G.erased (T.tree k v){T.Node? ft}) + requires is_tree x ft + returns y:(k & v) + ensures is_tree x ft ** pure (y == T.tree_max ft) + +fn find_le (#k:Type0) (#v:Type0) (cmp: T.cmp k) (x:tree_t k v) (key:k) (#ft:G.erased (T.tree k v)) + preserves is_tree x ft + returns y:option (k & v) + ensures pure (y == T.find_le cmp ft key) + +fn find_ge (#k:Type0) (#v:Type0) (cmp: T.cmp k) (x:tree_t k v) (key:k) (#ft:G.erased (T.tree k v)) + preserves is_tree x ft + returns y:option (k & v) + ensures pure (y == T.find_ge cmp ft key) + +fn free (#k:Type0) (#v:Type0) (x:tree_t k v) (#ft:G.erased (T.tree k v)) + requires is_tree x ft + ensures emp diff --git a/lib/pulse/lib/Pulse.Lib.CircularBuffer.GapTrack.fst b/lib/pulse/lib/Pulse.Lib.CircularBuffer.GapTrack.fst new file mode 100644 index 000000000..0f03fb5bd --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.CircularBuffer.GapTrack.fst @@ -0,0 +1,342 @@ +module Pulse.Lib.CircularBuffer.GapTrack + +/// Gap tracking spec for circular buffer. +/// Defines contiguous_prefix_length on seq (option byte) and proves +/// monotonicity/update properties. + +module Seq = FStar.Seq + +type byte = FStar.UInt8.t + +/// Length of the longest contiguous prefix of Some values +let rec contiguous_prefix_length (s:Seq.seq (option byte)) + : Tot nat (decreases (Seq.length s)) + = if Seq.length s = 0 then 0 + else match Seq.index s 0 with + | None -> 0 + | Some _ -> 1 + contiguous_prefix_length (Seq.tail s) + +/// Prefix length is bounded by sequence length +let rec prefix_length_bounded (s:Seq.seq (option byte)) + : Lemma (ensures contiguous_prefix_length s <= Seq.length s) + (decreases (Seq.length s)) + = if Seq.length s = 0 then () + else match Seq.index s 0 with + | None -> () + | Some _ -> prefix_length_bounded (Seq.tail s) + +/// All elements before the prefix length are Some +let rec prefix_elements_are_some (s:Seq.seq (option byte)) (i:nat) + : Lemma (requires i < contiguous_prefix_length s /\ i < Seq.length s) + (ensures Some? (Seq.index s i)) + (decreases (Seq.length s)) + = prefix_length_bounded s; + if i = 0 then () + else begin + assert (Some? (Seq.index s 0)); + prefix_elements_are_some (Seq.tail s) (i - 1) + end + +/// Element at the prefix length (if it exists) is None +let rec prefix_boundary_is_none_aux (s:Seq.seq (option byte)) (pl:nat) + : Lemma (requires pl == contiguous_prefix_length s /\ pl < Seq.length s /\ pl > 0) + (ensures None? (Seq.index s pl)) + (decreases pl) + = assert (Some? (Seq.index s 0)); + let s' = Seq.tail s in + let pl' = contiguous_prefix_length s' in + if pl' = 0 then () + else prefix_boundary_is_none_aux s' pl' + +let prefix_boundary_is_none (s:Seq.seq (option byte)) + : Lemma (requires contiguous_prefix_length s < Seq.length s) + (ensures None? (Seq.index s (contiguous_prefix_length s))) + = let pl = contiguous_prefix_length s in + prefix_length_bounded s; + if pl = 0 then () + else prefix_boundary_is_none_aux s pl + +/// Converse of prefix_elements_are_some: +/// If all positions 0..n-1 are Some, then contiguous_prefix_length >= n. +let rec all_some_prefix_ge (s:Seq.seq (option byte)) (n:nat) + : Lemma (requires n <= Seq.length s /\ + (forall (i:nat{i < n}). Some? (Seq.index s i))) + (ensures contiguous_prefix_length s >= n) + (decreases n) + = if n = 0 then () + else ( + assert (Some? (Seq.index s 0)); + // cpl s = 1 + cpl (tail s) + // All positions 0..n-2 of tail are Some (shifted from 1..n-1 of s) + let s' = Seq.tail s in + let aux (i:nat{i < n - 1}) + : Lemma (Some? (Seq.index s' i)) + = assert (Seq.index s' i == Seq.index s (i + 1)); + assert (Some? (Seq.index s (i + 1))) + in + FStar.Classical.forall_intro aux; + all_some_prefix_ge s' (n - 1) + ) + +/// Writing Some at an index strictly beyond the prefix doesn't change the prefix +let rec write_beyond_prefix_preserves (s:Seq.seq (option byte)) (i:nat) (b:byte) + : Lemma (requires i < Seq.length s /\ i > contiguous_prefix_length s) + (ensures contiguous_prefix_length (Seq.upd s i (Some b)) == contiguous_prefix_length s) + (decreases (Seq.length s)) + = if Seq.length s = 0 then () + else match Seq.index s 0 with + | None -> () + | Some _ -> + let s' = Seq.tail s in + assert (Seq.upd s i (Some b) `Seq.equal` + Seq.cons (Seq.index s 0) (Seq.upd s' (i - 1) (Some b))); + write_beyond_prefix_preserves s' (i - 1) b + +/// Writing Some at exactly the prefix length extends the prefix by ≥ 1 +let rec write_at_prefix_extends (s:Seq.seq (option byte)) (b:byte) + : Lemma (requires + contiguous_prefix_length s < Seq.length s /\ + None? (Seq.index s (contiguous_prefix_length s))) + (ensures + contiguous_prefix_length (Seq.upd s (contiguous_prefix_length s) (Some b)) >= + contiguous_prefix_length s + 1) + (decreases (Seq.length s)) + = let pl = contiguous_prefix_length s in + if pl = 0 then () + else begin + let s_tail = Seq.tail s in + let pl' = contiguous_prefix_length s_tail in + let s' = Seq.upd s pl (Some b) in + assert (s' `Seq.equal` Seq.cons (Seq.index s 0) (Seq.upd s_tail (pl - 1) (Some b))); + write_at_prefix_extends s_tail b + end + +/// Overwriting an existing Some preserves the prefix +let rec write_some_at_existing_some (s:Seq.seq (option byte)) (i:nat) (b:byte) + : Lemma (requires i < Seq.length s /\ Some? (Seq.index s i) /\ i < contiguous_prefix_length s) + (ensures contiguous_prefix_length (Seq.upd s i (Some b)) >= contiguous_prefix_length s) + (decreases (Seq.length s)) + = if i = 0 then begin + let s' = Seq.upd s 0 (Some b) in + assert (Some? (Seq.index s' 0)); + assert (Seq.tail s' `Seq.equal` Seq.tail s) + end + else begin + let s' = Seq.upd s i (Some b) in + assert (s' `Seq.equal` Seq.cons (Seq.index s 0) (Seq.upd (Seq.tail s) (i - 1) (Some b))); + write_some_at_existing_some (Seq.tail s) (i - 1) b + end + +/// Monotonicity: writing Some never decreases the prefix length +let write_some_monotone (s:Seq.seq (option byte)) (i:nat) (b:byte) + : Lemma (requires i < Seq.length s) + (ensures contiguous_prefix_length (Seq.upd s i (Some b)) >= contiguous_prefix_length s) + = let pl = contiguous_prefix_length s in + prefix_length_bounded s; + if i > pl then + write_beyond_prefix_preserves s i b + else if i < pl then begin + prefix_elements_are_some s i; + write_some_at_existing_some s i b + end + else if pl < Seq.length s then begin + prefix_boundary_is_none s; + write_at_prefix_extends s b + end + else () + +/// Creating a sequence of Nones +let rec create_nones (n:nat) : Tot (s:Seq.seq (option byte){Seq.length s == n}) (decreases n) = + if n = 0 then Seq.empty + else Seq.cons None (create_nones (n - 1)) + +/// Prefix of all-Nones is 0 +let prefix_of_nones (n:nat) + : Lemma (ensures contiguous_prefix_length (create_nones n) == 0) + = if n = 0 then () else () + +/// All elements of create_nones are None +let rec create_nones_all_none (n:nat) (i:nat{i < n}) + : Lemma (ensures None? (Seq.index (create_nones n) i)) + (decreases n) + = if i = 0 then () + else create_nones_all_none (n - 1) (i - 1) + +/// Characterization: if all [0,p) are Some and (p==len or s[p] is None), +/// then cpl(s) == p. +let rec cpl_characterization (s: Seq.seq (option byte)) (p: nat) + : Lemma + (requires + p <= Seq.length s /\ + (forall (k:nat). k < p ==> Some? (Seq.index s k)) /\ + (p == Seq.length s \/ (p < Seq.length s /\ None? (Seq.index s p)))) + (ensures contiguous_prefix_length s == p) + (decreases p) + = if p = 0 then () + else begin + assert (Some? (Seq.index s 0)); + let ts = Seq.tail s in + assert (forall (k:nat). k < p - 1 ==> Seq.index ts k == Seq.index s (k + 1)); + if p - 1 < Seq.length ts then + assert (Seq.index ts (p - 1) == Seq.index s p) + else (); + cpl_characterization ts (p - 1) + end + +/// Write a range of bytes at consecutive offsets +let rec write_range_contents + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + : Pure (Seq.seq (option byte)) + (requires offset + Seq.length data <= Seq.length contents) + (ensures fun r -> Seq.length r == Seq.length contents) + (decreases (Seq.length data)) + = if Seq.length data = 0 then contents + else + let c' = Seq.upd contents offset (Some (Seq.index data 0)) in + write_range_contents c' (offset + 1) (Seq.tail data) + +/// Writing a range of bytes never decreases the contiguous prefix length +let rec write_range_monotone + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + : Lemma + (requires offset + Seq.length data <= Seq.length contents) + (ensures contiguous_prefix_length (write_range_contents contents offset data) + >= contiguous_prefix_length contents) + (decreases (Seq.length data)) + = if Seq.length data = 0 then () + else begin + let c' = Seq.upd contents offset (Some (Seq.index data 0)) in + write_some_monotone contents offset (Seq.index data 0); + write_range_monotone c' (offset + 1) (Seq.tail data) + end + +/// Stepping lemma: writing one more byte = upd of the previous range result +let rec write_range_snoc + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + (i: nat) + : Lemma + (requires offset + Seq.length data <= Seq.length contents /\ + i < Seq.length data /\ + offset + i + 1 <= Seq.length contents) + (ensures + write_range_contents contents offset (Seq.slice data 0 (i + 1)) `Seq.equal` + Seq.upd (write_range_contents contents offset (Seq.slice data 0 i)) + (offset + i) (Some (Seq.index data i))) + (decreases i) + = if i = 0 then () + else begin + let d_ip1 = Seq.slice data 0 (i + 1) in + let d_i = Seq.slice data 0 i in + let c' = Seq.upd contents offset (Some (Seq.index data 0)) in + assert (Seq.length d_ip1 > 0); + assert (Seq.index d_ip1 0 == Seq.index data 0); + let tail_ip1 = Seq.tail d_ip1 in + let tail_i = Seq.tail d_i in + assert (tail_ip1 `Seq.equal` Seq.slice data 1 (i + 1)); + assert (tail_i `Seq.equal` Seq.slice data 1 i); + assert (Seq.length tail_ip1 == i); + assert (Seq.length tail_i == i - 1); + // Shift to tail data + let d' = Seq.tail data in + assert (tail_ip1 `Seq.equal` Seq.slice d' 0 i); + assert (tail_i `Seq.equal` Seq.slice d' 0 (i - 1)); + assert (Seq.index d' (i - 1) == Seq.index data i); + write_range_snoc c' (offset + 1) d' (i - 1) + end + +/// Total wrapper for write_range_contents (no precondition; identity when out of bounds) +let write_range_contents_t + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + : Seq.seq (option byte) + = if offset + Seq.length data <= Seq.length contents + then write_range_contents contents offset data + else contents + +/// Equivalence: when precondition holds, total version equals partial version +let write_range_contents_t_eq + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + : Lemma + (requires offset + Seq.length data <= Seq.length contents) + (ensures write_range_contents_t contents offset data == + write_range_contents contents offset data) + = () + +/// Length preservation for total version +let write_range_contents_t_length + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + : Lemma (Seq.length (write_range_contents_t contents offset data) == + Seq.length contents) + = if offset + Seq.length data <= Seq.length contents + then () + else () + +/// Pointwise characterization of write_range_contents +let rec write_range_index + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + (i: nat) + : Lemma + (requires offset + Seq.length data <= Seq.length contents /\ + i < Seq.length contents) + (ensures + Seq.index (write_range_contents contents offset data) i == + (if offset <= i && i < offset + Seq.length data + then Some (Seq.index data (i - offset)) + else Seq.index contents i)) + (decreases (Seq.length data)) + = if Seq.length data = 0 then () + else begin + let c' = Seq.upd contents offset (Some (Seq.index data 0)) in + if i = offset then begin + // First byte written at offset — show it's overwritten + write_range_index c' (offset + 1) (Seq.tail data) i + end + else if i > offset && i < offset + Seq.length data then begin + // In the written range but not the first position + write_range_index c' (offset + 1) (Seq.tail data) i; + assert (Seq.index c' i == Seq.index contents i); + assert (i - offset >= 1); + assert (Seq.index (Seq.tail data) (i - (offset + 1)) == Seq.index data (i - offset)) + end + else begin + // Outside the written range — unchanged + write_range_index c' (offset + 1) (Seq.tail data) i; + if i = offset then () + else assert (Seq.index c' i == Seq.index contents i) + end + end + +/// Forall version: characterize every index of write_range_contents +let write_range_forall_index + (contents: Seq.seq (option byte)) + (offset: nat) + (data: Seq.seq byte) + : Lemma + (requires offset + Seq.length data <= Seq.length contents) + (ensures + forall (i:nat{i < Seq.length contents}). + Seq.index (write_range_contents contents offset data) i == + (if offset <= i && i < offset + Seq.length data + then Some (Seq.index data (i - offset)) + else Seq.index contents i)) + = let aux (i:nat{i < Seq.length contents}) + : Lemma (Seq.index (write_range_contents contents offset data) i == + (if offset <= i && i < offset + Seq.length data + then Some (Seq.index data (i - offset)) + else Seq.index contents i)) + = write_range_index contents offset data i + in + FStar.Classical.forall_intro aux diff --git a/lib/pulse/lib/Pulse.Lib.CircularBuffer.Modular.fst b/lib/pulse/lib/Pulse.Lib.CircularBuffer.Modular.fst new file mode 100644 index 000000000..59a451d21 --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.CircularBuffer.Modular.fst @@ -0,0 +1,40 @@ +module Pulse.Lib.CircularBuffer.Modular + +/// Modular/circular indexing lemmas for circular buffer operations. +/// All lemmas are pure, no Pulse dependency. + +module ML = FStar.Math.Lemmas + +/// Circular index is always in bounds +let circular_index_in_bounds (read_start:nat) (offset:nat) (cap:pos) + : Lemma (ensures (read_start + offset) % cap < cap) + = ML.lemma_mod_lt (read_start + offset) cap + +/// Two different offsets within capacity map to different circular indices +let circular_index_injective (read_start:nat) (o1 o2:nat) (cap:pos) + : Lemma (requires o1 < cap /\ o2 < cap /\ o1 <> o2) + (ensures (read_start + o1) % cap <> (read_start + o2) % cap) + = // (read_start + o1) % cap = (read_start % cap + o1) % cap (and same for o2) + ML.lemma_mod_plus_distr_l read_start o1 cap; + ML.lemma_mod_plus_distr_l read_start o2 cap; + let r = read_start % cap in + // Now we need (r + o1) % cap <> (r + o2) % cap + // Since |o1 - o2| < cap, and r < cap, the difference (r+o1) - (r+o2) = o1 - o2 + // has absolute value < cap, so they can't be equal mod cap unless o1 = o2 + assert (r < cap); + if (r + o1) % cap = (r + o2) % cap then begin + // Derive contradiction: supposing equal, then o1 - o2 is divisible by cap + // But |o1 - o2| < cap, so o1 = o2. Contradiction. + ML.lemma_mod_plus_distr_l r o1 cap; + ML.lemma_mod_plus_distr_l r o2 cap + end else () + +/// Advancing read_start by n (mod cap) is equivalent to adding n to circular index +let advance_read_start (read_start:nat) (n:nat) (offset:nat) (cap:pos) + : Lemma (requires read_start < cap) + (ensures ( + let new_start = (read_start + n) % cap in + (new_start + offset) % cap == (read_start + n + offset) % cap)) + = let new_start = (read_start + n) % cap in + ML.lemma_mod_plus_distr_l (read_start + n) offset cap; + ML.lemma_mod_plus_distr_l new_start offset cap diff --git a/lib/pulse/lib/Pulse.Lib.CircularBuffer.Pow2.fst b/lib/pulse/lib/Pulse.Lib.CircularBuffer.Pow2.fst new file mode 100644 index 000000000..5deee487f --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.CircularBuffer.Pow2.fst @@ -0,0 +1,77 @@ +module Pulse.Lib.CircularBuffer.Pow2 + +/// Power-of-2 arithmetic and doubling reachability lemmas. +/// Used by the circular buffer resize logic. + +module ML = FStar.Math.Lemmas + +/// Recursive definition: n is a power of 2 +let rec is_pow2 (n:pos) : Tot bool (decreases n) = + if n = 1 then true + else if n % 2 <> 0 then false + else is_pow2 (n / 2) + +/// Doubling a power of 2 yields a power of 2 +let rec doubling_stays_pow2 (n:pos) + : Lemma (requires is_pow2 n) + (ensures is_pow2 (n + n)) + (decreases n) + = if n = 1 then () + else begin + doubling_stays_pow2 (n / 2) + end + +/// Helper: if pow2 a < pow2 b, then 2*a ≤ b +let rec pow2_double_le (a:pos) (b:pos) + : Lemma (requires is_pow2 a /\ is_pow2 b /\ a < b) + (ensures a + a <= b) + (decreases b) + = if a = 1 then () + else begin + // a >= 2 so a % 2 = 0, and b > a >= 2 so b % 2 = 0 + pow2_double_le (a / 2) (b / 2) + end + +/// Full reachability: repeated doubling from start reaches some pow2 in [target, vlen] +let rec doubling_reaches_in_range (start:pos) (vlen:pos) (target:pos) + : Lemma (requires + is_pow2 start /\ + is_pow2 vlen /\ + start <= vlen /\ + target <= vlen /\ + target > 0) + (ensures (exists (reached:pos). + is_pow2 reached /\ + reached >= target /\ + reached <= vlen)) + (decreases (vlen - start)) + = if start >= target then () + else begin + doubling_stays_pow2 start; + pow2_double_le start vlen; + doubling_reaches_in_range (start + start) vlen target + end + +/// Compute the smallest power-of-2 >= needed, by doubling base +let rec next_pow2_ge (base: pos) (needed: pos) + : Pure pos + (requires is_pow2 base) + (ensures fun r -> is_pow2 r /\ r >= needed /\ r >= base) + (decreases (if base >= needed then 0 else needed - base)) + = if base >= needed then base + else begin + doubling_stays_pow2 base; + next_pow2_ge (base + base) needed + end + +/// next_pow2_ge never exceeds a power-of-2 bound that is >= both base and needed +let rec next_pow2_ge_le_bound (al: pos) (needed: pos) (bound: pos) + : Lemma (requires is_pow2 al /\ is_pow2 bound /\ al <= bound /\ needed <= bound) + (ensures next_pow2_ge al needed <= bound) + (decreases (if al >= needed then 0 else needed - al)) + = if al >= needed then () + else begin + doubling_stays_pow2 al; + pow2_double_le al bound; + next_pow2_ge_le_bound (al + al) needed bound + end diff --git a/lib/pulse/lib/Pulse.Lib.CircularBuffer.Spec.fst b/lib/pulse/lib/Pulse.Lib.CircularBuffer.Spec.fst new file mode 100644 index 000000000..a77ec812f --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.CircularBuffer.Spec.fst @@ -0,0 +1,1119 @@ +module Pulse.Lib.CircularBuffer.Spec + +/// Specification of the circular buffer (Circular mode, MsQuic recv_buffer.c). +/// Defines state, coherence, operation specs, and the no-overcommit theorem. + +module Seq = FStar.Seq +module ML = FStar.Math.Lemmas +module Pow2 = Pulse.Lib.CircularBuffer.Pow2 +module Mod = Pulse.Lib.CircularBuffer.Modular +module GT = Pulse.Lib.CircularBuffer.GapTrack + +type byte = FStar.UInt8.t + +/// --- Physical Index --- + +/// Compute the physical array index for logical offset i (always in bounds) +let phys_index (read_start: nat) (i: nat) (al: pos) : Tot (j:nat{j < al}) = + Mod.circular_index_in_bounds read_start i al; + (read_start + i) % al + +/// --- Buffer State --- + +noeq type cb_state = { + base_offset: nat; + read_start: nat; + alloc_length: pos; + virtual_length: pos; + contents: Seq.seq (option byte); +} + +/// Platform bound on maximum allocatable buffer size (simulates finite memory). +assume val cb_max_length : n:pos{ Pow2.is_pow2 n /\ n <= 0x8000000000000000 } + +let cb_wf (st: cb_state) : prop = + Pow2.is_pow2 st.alloc_length /\ + Pow2.is_pow2 st.virtual_length /\ + st.alloc_length <= st.virtual_length /\ + st.alloc_length <= cb_max_length /\ + st.read_start < st.alloc_length /\ + Seq.length st.contents == st.alloc_length + +/// --- Physical-Logical Coherence --- + +/// Coherence at a single position +let coherent_at + (al: pos) + (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (read_start: nat{read_start < al}) + (i: nat{i < al}) + : prop + = Some? (Seq.index contents i) ==> + Seq.index phys (phys_index read_start i al) == Some?.v (Seq.index contents i) + +/// Full coherence: all positions agree +let phys_log_coherent + (al: pos) + (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (read_start: nat{read_start < al}) + : prop + = forall (i:nat{i < al}). coherent_at al phys contents read_start i + +/// --- Write Spec --- + +/// Writing byte b at logical offset i preserves coherence +let write_preserves_coherence + (al: pos) + (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (read_start: nat{read_start < al}) + (i: nat{i < al}) + (b: byte) + : Lemma (requires phys_log_coherent al phys contents read_start) + (ensures phys_log_coherent al + (Seq.upd phys (phys_index read_start i al) b) + (Seq.upd contents i (Some b)) + read_start) + = let pidx = phys_index read_start i al in + let new_phys = Seq.upd phys pidx b in + let new_contents = Seq.upd contents i (Some b) in + let aux (j:nat{j < al}) + : Lemma (coherent_at al new_phys new_contents read_start j) + = if j = i then () + else Mod.circular_index_injective read_start i j al + in + FStar.Classical.forall_intro aux + +/// --- Linearize (Resize) Spec --- + +/// Construct the linearized physical buffer after resize +let linearized_phys + (old_al: pos) (new_al: pos) + (old_phys: Seq.seq byte{Seq.length old_phys == old_al}) + (old_read_start: nat{old_read_start < old_al}) + : Pure (Seq.seq byte) + (requires new_al >= old_al) + (ensures fun r -> Seq.length r == new_al) + = Seq.init new_al (fun k -> + if k < old_al then Seq.index old_phys (phys_index old_read_start k old_al) + else 0uy) + +/// Extend contents with Nones for new capacity +let resized_contents + (old_al: pos) (new_al: pos) + (old_contents: Seq.seq (option byte){Seq.length old_contents == old_al}) + : Pure (Seq.seq (option byte)) + (requires new_al >= old_al) + (ensures fun r -> Seq.length r == new_al) + = Seq.init new_al (fun k -> + if k < old_al then Seq.index old_contents k + else None) + +/// Linearization preserves coherence (read_start resets to 0) +let linearize_preserves_coherence + (old_al: pos) (new_al: pos) + (old_phys: Seq.seq byte{Seq.length old_phys == old_al}) + (old_contents: Seq.seq (option byte){Seq.length old_contents == old_al}) + (old_read_start: nat{old_read_start < old_al}) + : Lemma + (requires + new_al >= old_al /\ + phys_log_coherent old_al old_phys old_contents old_read_start) + (ensures + phys_log_coherent new_al + (linearized_phys old_al new_al old_phys old_read_start) + (resized_contents old_al new_al old_contents) + 0) + = let np = linearized_phys old_al new_al old_phys old_read_start in + let nc = resized_contents old_al new_al old_contents in + let aux (j:nat{j < new_al}) + : Lemma (coherent_at new_al np nc 0 j) + = if j >= old_al then () + else begin + ML.small_mod j new_al; + assert (coherent_at old_al old_phys old_contents old_read_start j) + end + in + FStar.Classical.forall_intro aux + +/// --- Drain Spec --- + +/// Drained contents: shift left by n, fill tail with None +let drained_contents + (al: pos) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (n: nat{n <= al}) + : Tot (s:Seq.seq (option byte){Seq.length s == al}) + = Seq.init al (fun k -> + if k + n < al then Seq.index contents (k + n) + else None) + +/// Drain preserves coherence with updated read_start +let drain_preserves_coherence + (al: pos) + (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (read_start: nat{read_start < al}) + (n: nat{n <= al}) + : Lemma + (requires phys_log_coherent al phys contents read_start) + (ensures + phys_log_coherent al phys + (drained_contents al contents n) + (phys_index read_start n al)) + = let new_rs = phys_index read_start n al in + let nc = drained_contents al contents n in + let aux (j:nat{j < al}) + : Lemma (coherent_at al phys nc new_rs j) + = if j + n >= al then () + else begin + Mod.advance_read_start read_start n j al; + assert (coherent_at al phys contents read_start (j + n)) + end + in + FStar.Classical.forall_intro aux + +/// --- Drain Prefix Lemma --- + +/// After draining n from the front (where n <= cpl), the prefix shrinks by exactly n. +let drain_prefix_length + (al: pos) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (n: nat{n <= al}) + : Lemma + (requires n <= GT.contiguous_prefix_length contents) + (ensures GT.contiguous_prefix_length (drained_contents al contents n) + == GT.contiguous_prefix_length contents - n) + = let cpl = GT.contiguous_prefix_length contents in + let p = cpl - n in + let ds = drained_contents al contents n in + GT.prefix_length_bounded contents; + // All positions [0, p) of ds are Some + let aux1 (k:nat{k < p}) : Lemma (Some? (Seq.index ds k)) + = assert (k + n < al); + assert (Seq.index ds k == Seq.index contents (k + n)); + GT.prefix_elements_are_some contents (k + n) + in + FStar.Classical.forall_intro aux1; + // ds[p] is None (or p == al) + if p < al then begin + if cpl < al then begin + assert (Seq.index ds p == Seq.index contents (p + n)); + assert (p + n == cpl); + GT.prefix_boundary_is_none contents + end else begin + // cpl == al, so p = al - n, and p + n = al >= al, so ds[p] = None + assert (p + n >= al) + end + end else (); + GT.cpl_characterization ds p + +/// --- Resize Prefix Lemma --- + +/// After resize (pad with None), prefix is unchanged. +let resize_prefix_length + (old_al: pos) (new_al: pos) + (contents: Seq.seq (option byte){Seq.length contents == old_al}) + : Lemma + (requires new_al >= old_al) + (ensures GT.contiguous_prefix_length (resized_contents old_al new_al contents) + == GT.contiguous_prefix_length contents) + = let cpl = GT.contiguous_prefix_length contents in + let rc = resized_contents old_al new_al contents in + GT.prefix_length_bounded contents; + // All positions [0, cpl) of rc are Some (same as original) + let aux1 (k:nat{k < cpl}) : Lemma (Some? (Seq.index rc k)) + = assert (k < old_al); + assert (Seq.index rc k == Seq.index contents k); + GT.prefix_elements_are_some contents k + in + FStar.Classical.forall_intro aux1; + // rc[cpl] is None (or cpl == new_al) + if cpl < new_al then begin + if cpl < old_al then begin + assert (Seq.index rc cpl == Seq.index contents cpl); + GT.prefix_boundary_is_none contents + end else begin + // cpl == old_al, so rc[cpl] = None (padding) + assert (Seq.index rc cpl == None) + end + end else (); + GT.cpl_characterization rc cpl + +/// --- No-Overcommit Theorem --- + +/// For any in-bounds write, there exists a sufficient power-of-2 buffer size +/// that accommodates the write and is at most virtual_length. +/// This is the top-level safety property from recv_buffer.c: +/// "We must always be willing/able to allocate the buffer length advertised to the peer." +let no_overcommit (st: cb_state) (write_end: nat) + : Lemma + (requires + cb_wf st /\ + write_end > st.base_offset /\ + write_end <= st.base_offset + st.virtual_length) + (ensures + exists (new_al: pos). + Pow2.is_pow2 new_al /\ + new_al >= st.alloc_length /\ + new_al <= st.virtual_length /\ + write_end <= st.base_offset + new_al) + = if write_end <= st.base_offset + st.alloc_length then () + else begin + let needed : pos = write_end - st.base_offset in + Pow2.doubling_reaches_in_range st.alloc_length st.virtual_length needed + end + +/// --- Total helpers for Pulse interface (no preconditions) --- + +/// State after writing a byte (total: no-op if out of bounds) +let write_byte_result (st: cb_state) (i: nat) (b: byte) : cb_state = + if i < Seq.length st.contents then + { st with contents = Seq.upd st.contents i (Some b) } + else st + +/// State after draining n bytes (total: no-op if out of bounds) +let drain_result (st: cb_state) (n: nat) : cb_state = + if n <= st.alloc_length + && Seq.length st.contents = st.alloc_length + && st.read_start < st.alloc_length then + { st with + base_offset = st.base_offset + n; + read_start = phys_index st.read_start n st.alloc_length; + contents = drained_contents st.alloc_length st.contents n } + else st + +/// State after resize (total: no-op if invalid) +let resize_result (st: cb_state) (new_al: pos) : cb_state = + if new_al >= st.alloc_length && Seq.length st.contents = st.alloc_length then + { st with + read_start = 0; + alloc_length = new_al; + contents = resized_contents st.alloc_length new_al st.contents } + else st + +/// Transfer coherence across Seq.equal contents +let phys_log_coherent_seq_equal + (al: pos) (phys: Seq.seq byte{Seq.length phys == al}) + (c1 c2: Seq.seq (option byte)) + (rs: nat{rs < al}) + : Lemma + (requires Seq.length c1 == al /\ Seq.length c2 == al /\ + phys_log_coherent al phys c1 rs /\ c1 `Seq.equal` c2) + (ensures phys_log_coherent al phys c2 rs) + = () + +/// Combined step: write a byte and maintain coherence with write_range_contents +let write_step_coherence + (al: pos) + (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (rs: nat{rs < al}) + (offset: nat) + (data: Seq.seq byte) + (i: nat) + : Lemma + (requires offset + Seq.length data <= al /\ + i < Seq.length data /\ + offset + i < al /\ + phys_log_coherent al phys + (GT.write_range_contents contents offset (Seq.slice data 0 i)) rs) + (ensures phys_log_coherent al + (Seq.upd phys (phys_index rs (offset + i) al) (Seq.index data i)) + (GT.write_range_contents contents offset (Seq.slice data 0 (i + 1))) rs) + = let old_c = GT.write_range_contents contents offset (Seq.slice data 0 i) in + let b = Seq.index data i in + write_preserves_coherence al phys old_c rs (offset + i) b; + GT.write_range_snoc contents offset data i; + phys_log_coherent_seq_equal al + (Seq.upd phys (phys_index rs (offset + i) al) b) + (Seq.upd old_c (offset + i) (Some b)) + (GT.write_range_contents contents offset (Seq.slice data 0 (i + 1))) + rs + +/// --- Read step helper --- +/// Extends the read_buffer loop invariant from k + segs.off1 == rs /\ + segs.len1 + segs.len2 == n /\ + segs.off1 + segs.len1 <= al /\ + segs.off2 + segs.len2 <= al /\ + (segs.len2 > 0 ==> segs.off2 == 0) /\ + (segs.len2 == 0 ==> segs.off1 + segs.len1 == rs + n)) + = if rs + n <= al then + { off1 = rs; len1 = n; off2 = 0; len2 = 0 } + else + { off1 = rs; len1 = al - rs; off2 = 0; len2 = n - (al - rs) } + +/// The physical bytes for segment 1 match the logical contents. +/// phys[off1..off1+len1) corresponds to contents[0..len1) via coherence. +let read_segments_seg1_correct + (al: pos) (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (rs: nat{rs < al}) (n: nat{n <= al}) + (cpl: nat{cpl >= n}) + : Lemma + (requires + phys_log_coherent al phys contents rs /\ + cpl <= GT.contiguous_prefix_length contents) + (ensures ( + let segs = compute_read_segments rs n al in + forall (i:nat{i < segs.len1}). + Some? (Seq.index contents i) /\ + Seq.index phys (segs.off1 + i) == Some?.v (Seq.index contents i))) + = let segs = compute_read_segments rs n al in + let aux (i:nat{i < segs.len1}) + : Lemma (Some? (Seq.index contents i) /\ + Seq.index phys (segs.off1 + i) == Some?.v (Seq.index contents i)) + = GT.prefix_elements_are_some contents i; + assert (coherent_at al phys contents rs i); + Mod.circular_index_in_bounds rs i al; + // phys_index rs i al == (rs + i) % al + // Since i < len1 and off1 = rs, off1 + i = rs + i + // No wrap case: rs + i < al, so (rs + i) % al = rs + i = off1 + i + // Wrap case: i < al - rs, so rs + i < al, so (rs + i) % al = rs + i = off1 + i + ML.small_mod (rs + i) al + in + FStar.Classical.forall_intro aux + +/// The physical bytes for segment 2 match the logical contents. +/// phys[0..len2) corresponds to contents[len1..len1+len2) via coherence. +let read_segments_seg2_correct + (al: pos) (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (rs: nat{rs < al}) (n: nat{n <= al}) + (cpl: nat{cpl >= n}) + : Lemma + (requires + phys_log_coherent al phys contents rs /\ + cpl <= GT.contiguous_prefix_length contents) + (ensures ( + let segs = compute_read_segments rs n al in + forall (i:nat{i < segs.len2}). + Some? (Seq.index contents (segs.len1 + i)) /\ + Seq.index phys (segs.off2 + i) == Some?.v (Seq.index contents (segs.len1 + i)))) + = let segs = compute_read_segments rs n al in + if segs.len2 = 0 then () + else + let aux (i:nat{i < segs.len2}) + : Lemma (Some? (Seq.index contents (segs.len1 + i)) /\ + Seq.index phys (segs.off2 + i) == Some?.v (Seq.index contents (segs.len1 + i))) + = let li = segs.len1 + i in + GT.prefix_elements_are_some contents li; + assert (coherent_at al phys contents rs li); + Mod.circular_index_in_bounds rs li al; + // phys_index rs li al == (rs + li) % al + // li = (al - rs) + i, so rs + li = al + i + // (al + i) % al = i = off2 + i (since off2 = 0) + ML.lemma_mod_plus i 1 al; + assert ((rs + li) % al == i) + in + FStar.Classical.forall_intro aux + +/// Combined: both segments are correct +let read_segments_correct + (al: pos) (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (rs: nat{rs < al}) (n: nat{n <= al}) + (cpl: nat{cpl >= n}) + : Lemma + (requires + phys_log_coherent al phys contents rs /\ + cpl <= GT.contiguous_prefix_length contents) + (ensures ( + let segs = compute_read_segments rs n al in + // Segment 1 data matches + (forall (i:nat{i < segs.len1}). + Some? (Seq.index contents i) /\ + Seq.index phys (segs.off1 + i) == Some?.v (Seq.index contents i)) /\ + // Segment 2 data matches + (forall (i:nat{i < segs.len2}). + Some? (Seq.index contents (segs.len1 + i)) /\ + Seq.index phys (segs.off2 + i) == Some?.v (Seq.index contents (segs.len1 + i))))) + = read_segments_seg1_correct al phys contents rs n cpl; + read_segments_seg2_correct al phys contents rs n cpl + +/// Slice equality: phys[off1..off1+len1) == the logical bytes for [0..len1) +let read_segments_slice_eq + (al: pos) (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (rs: nat{rs < al}) (n: nat{n <= al}) + (cpl: nat{cpl >= n}) + : Lemma + (requires + phys_log_coherent al phys contents rs /\ + cpl <= GT.contiguous_prefix_length contents) + (ensures ( + let segs = compute_read_segments rs n al in + let s1 = Seq.slice phys segs.off1 (segs.off1 + segs.len1) in + let s2 = Seq.slice phys segs.off2 (segs.off2 + segs.len2) in + // Each byte in s1 matches the logical contents + (forall (i:nat{i < segs.len1}). + Some? (Seq.index contents i) /\ + Seq.index s1 i == Some?.v (Seq.index contents i)) /\ + // Each byte in s2 matches the logical contents + (forall (i:nat{i < segs.len2}). + Some? (Seq.index contents (segs.len1 + i)) /\ + Seq.index s2 i == Some?.v (Seq.index contents (segs.len1 + i))))) + = read_segments_correct al phys contents rs n cpl; + let segs = compute_read_segments rs n al in + let aux1 (i:nat{i < segs.len1}) + : Lemma (Seq.index (Seq.slice phys segs.off1 (segs.off1 + segs.len1)) i + == Seq.index phys (segs.off1 + i)) + = Seq.lemma_index_slice phys segs.off1 (segs.off1 + segs.len1) i + in + FStar.Classical.forall_intro aux1; + let aux2 (i:nat{i < segs.len2}) + : Lemma (Seq.index (Seq.slice phys segs.off2 (segs.off2 + segs.len2)) i + == Seq.index phys (segs.off2 + i)) + = Seq.lemma_index_slice phys segs.off2 (segs.off2 + segs.len2) i + in + FStar.Classical.forall_intro aux2 + +/// --- Out-of-order write helpers --- + +/// cb_wf is preserved by write_range_contents (contents length unchanged) +let write_range_preserves_wf + (st: cb_state) (offset: nat) (data: Seq.seq byte) + : Lemma + (requires cb_wf st /\ offset + Seq.length data <= st.alloc_length) + (ensures cb_wf { st with contents = GT.write_range_contents st.contents offset data }) + = () + +/// Transfer coherence from Seq.slice to full data for OOO write (no-resize case) +let write_ooo_coherence_transfer + (al: pos) (phys: Seq.seq byte{Seq.length phys == al}) + (contents: Seq.seq (option byte){Seq.length contents == al}) + (rs: nat{rs < al}) + (offset: nat) (data: Seq.seq byte) (n: nat) (write_len: nat) + : Lemma + (requires + n <= write_len /\ + write_len == Seq.length data /\ + false == (n < write_len) /\ + offset + write_len <= al /\ + phys_log_coherent al phys + (GT.write_range_contents contents offset (Seq.slice data 0 n)) + rs) + (ensures + phys_log_coherent al phys + (GT.write_range_contents contents offset data) + rs) + = assert (n == Seq.length data); + Seq.lemma_eq_intro (Seq.slice data 0 n) data; + Seq.lemma_eq_elim (Seq.slice data 0 n) data + +/// --- RangeMap ↔ Contents Bridge --- + +module RMSpec = Pulse.Lib.RangeMap.Spec + +/// Bridge: RangeMap intervals (absolute offsets) match the option-byte contents. +/// Intervals use absolute stream positions; contents is indexed relative to base_offset. +/// For every position i, mem repr (base_offset + i) <==> Some? contents[i]. +/// All interval endpoints are bounded by base_offset + Seq.length contents. +let ranges_match_contents + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) : prop = + Seq.length contents > 0 /\ + (forall (i:nat{i < Seq.length contents}). + RMSpec.mem repr (base_offset + i) <==> Some? (Seq.index contents i)) /\ + RMSpec.range_map_bounded repr (base_offset + Seq.length contents) + +/// base_offset is within the first interval (or repr is empty). +/// Invariant of the CircularBuffer: first interval starts at 0 and base_offset +/// only advances by drain (contiguous_from), so it stays within the first interval. +let base_aligned (repr: Seq.seq RMSpec.interval) (base_offset: nat) : prop = + Seq.length repr = 0 \/ + (let first = Seq.index repr 0 in first.low <= base_offset /\ base_offset <= RMSpec.high first) + +/// Empty repr matches all-None contents (base_offset = 0), and is trivially base_aligned. +let ranges_match_empty (al: pos) + : Lemma (ranges_match_contents Seq.empty (Seq.create al (None #byte)) 0 /\ + base_aligned Seq.empty 0) + = let contents : Seq.seq (option byte) = Seq.create al None in + let aux (i:nat{i < Seq.length contents}) + : Lemma (RMSpec.mem Seq.empty (0 + i) <==> Some? (Seq.index contents i)) + = () + in + FStar.Classical.forall_intro aux + +/// Writing data preserves the bridge. +/// add_range uses absolute offset (base_offset + rel_offset). +let ranges_match_write + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) (rel_offset: nat) (data: Seq.seq byte) + : Lemma + (requires + ranges_match_contents repr contents base_offset /\ + Seq.length data > 0 /\ + rel_offset + Seq.length data <= Seq.length contents) + (ensures + ranges_match_contents + (RMSpec.add_range repr (base_offset + rel_offset) (Seq.length data)) + (GT.write_range_contents contents rel_offset data) + base_offset) + = let len = Seq.length data in + let abs_offset = base_offset + rel_offset in + let new_repr = RMSpec.add_range repr abs_offset len in + let new_contents = GT.write_range_contents contents rel_offset data in + let aux (i:nat{i < Seq.length new_contents}) + : Lemma (RMSpec.mem new_repr (base_offset + i) <==> Some? (Seq.index new_contents i)) + = GT.write_range_index contents rel_offset data i; + let abs_i = base_offset + i in + if rel_offset <= i && i < rel_offset + len then ( + assert (abs_offset <= abs_i && abs_i < abs_offset + len); + RMSpec.add_range_mem_new repr abs_offset len abs_i + ) else ( + assert (Seq.index new_contents i == Seq.index contents i); + assert (not (abs_offset <= abs_i && abs_i < abs_offset + len)); + if Some? (Seq.index contents i) then ( + assert (RMSpec.mem repr abs_i); + RMSpec.add_range_mem_old repr abs_offset len abs_i + ) else (); + if RMSpec.mem new_repr abs_i then ( + RMSpec.add_range_mem_inv repr abs_offset len abs_i; + assert (RMSpec.mem repr abs_i) + ) else () + ) + in + FStar.Classical.forall_intro aux; + RMSpec.add_range_bounded repr abs_offset len (base_offset + Seq.length contents) + +/// Resize preserves the bridge: extending contents with Nones doesn't break the correspondence. +let ranges_match_resize + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) + (old_al: pos) (new_al: pos) + : Lemma + (requires + ranges_match_contents repr contents base_offset /\ + Seq.length contents == old_al /\ + new_al >= old_al) + (ensures + ranges_match_contents repr (resized_contents old_al new_al contents) base_offset) + = let new_c = resized_contents old_al new_al contents in + let aux (i:nat{i < Seq.length new_c}) + : Lemma (RMSpec.mem repr (base_offset + i) <==> Some? (Seq.index new_c i)) + = if i < old_al then () + else RMSpec.mem_bounded repr (base_offset + old_al) (base_offset + i) + in + FStar.Classical.forall_intro aux; + RMSpec.range_map_bounded_monotone repr (base_offset + old_al) (base_offset + new_al) + +/// Lower bound: contiguous_from is always <= contiguous_prefix_length. +/// Does NOT require base_aligned. When base_aligned holds, use ranges_match_prefix for equality. +let ranges_match_prefix_lower + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) + : Lemma + (requires ranges_match_contents repr contents base_offset /\ + RMSpec.range_map_wf repr) + (ensures RMSpec.contiguous_from repr base_offset <= GT.contiguous_prefix_length contents) + = RMSpec.cf_bounded repr base_offset (base_offset + Seq.length contents); + let cf = RMSpec.contiguous_from repr base_offset in + GT.prefix_length_bounded contents; + if cf > 0 then ( + let first = Seq.index repr 0 in + assert (first.low <= base_offset /\ base_offset < RMSpec.high first); + assert (cf == RMSpec.high first - base_offset); + assert (cf <= Seq.length contents); + let aux (i:nat{i < cf}) + : Lemma (Some? (Seq.index contents i)) + = assert (i < Seq.length contents); + assert (base_offset + i < RMSpec.high first); + assert (first.low <= base_offset + i); + assert (RMSpec.in_interval first (base_offset + i)); + assert (RMSpec.mem repr (base_offset + i)) + in + FStar.Classical.forall_intro aux; + GT.all_some_prefix_ge contents cf + ) else () + +/// Prefix equivalence: under the bridge + base_aligned, +/// contiguous_from(repr, base_offset) == contiguous_prefix_length(contents). +#push-options "--z3rlimit_factor 4 --fuel 2 --ifuel 1" +let ranges_match_prefix + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) + : Lemma + (requires ranges_match_contents repr contents base_offset /\ + RMSpec.range_map_wf repr /\ + base_aligned repr base_offset) + (ensures RMSpec.contiguous_from repr base_offset == GT.contiguous_prefix_length contents) + = RMSpec.cf_bounded repr base_offset (base_offset + Seq.length contents); + let cf = RMSpec.contiguous_from repr base_offset in + let cpl = GT.contiguous_prefix_length contents in + GT.prefix_length_bounded contents; + // Direction 1: cf <= cpl + if cf > 0 then ( + let first = Seq.index repr 0 in + assert (first.low <= base_offset /\ base_offset < RMSpec.high first); + assert (cf == RMSpec.high first - base_offset); + assert (cf <= Seq.length contents); + let aux (i:nat{i < cf}) + : Lemma (Some? (Seq.index contents i)) + = assert (i < Seq.length contents); + assert (base_offset + i < RMSpec.high first); + assert (first.low <= base_offset + i); + assert (RMSpec.in_interval first (base_offset + i)); + assert (RMSpec.mem repr (base_offset + i)) + in + FStar.Classical.forall_intro aux; + GT.all_some_prefix_ge contents cf + ) else (); + // Direction 2: cpl <= cf (by contradiction: assume cpl > cf, derive false) + if cpl > cf then ( + GT.prefix_elements_are_some contents cf; + assert (Some? (Seq.index contents cf)); + assert (RMSpec.mem repr (base_offset + cf)); + if Seq.length repr = 0 then () + else ( + let first = Seq.index repr 0 in + // From base_aligned: first.low <= base_offset <= high first + if first.low <= base_offset && base_offset < RMSpec.high first then ( + // cf = high first - base_offset, so base_offset + cf = high first + // high first is NOT in the first interval (boundary), so must be in tail + assert (not (RMSpec.in_interval first (base_offset + cf))); + RMSpec.mem_tail repr (base_offset + cf); + if Seq.length (Seq.tail repr) > 0 then + // tail membership implies position > high first, but position = high first + RMSpec.mem_wf_tail_gt repr (base_offset + cf) + else () + ) else ( + // first.low <= base_offset (from base_aligned) AND NOT (base_offset < high first) + // AND base_offset <= high first (from base_aligned) + // Therefore base_offset = high first, and cf = 0 + assert (base_offset == RMSpec.high first); + assert (not (RMSpec.in_interval first base_offset)); + RMSpec.mem_tail repr base_offset; + if Seq.length (Seq.tail repr) > 0 then + // tail membership implies position > high first = base_offset, contradiction + RMSpec.mem_wf_tail_gt repr base_offset + else () + ) + ) + ) else () +#pop-options + +/// Drain preservation: the bridge is automatically preserved by index substitution. +let ranges_match_drain + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) (n: nat) + : Lemma + (requires ranges_match_contents repr contents base_offset /\ + n <= Seq.length contents /\ + Seq.length contents - n > 0) + (ensures ranges_match_contents repr (Seq.slice contents n (Seq.length contents)) (base_offset + n)) + = let new_contents = Seq.slice contents n (Seq.length contents) in + let new_base = base_offset + n in + let aux (i:nat{i < Seq.length new_contents}) + : Lemma (RMSpec.mem repr (new_base + i) <==> Some? (Seq.index new_contents i)) + = assert (new_base + i == base_offset + (n + i)); + assert (Seq.index new_contents i == Seq.index contents (n + i)) + in + FStar.Classical.forall_intro aux; + assert (base_offset + Seq.length contents == new_base + Seq.length new_contents) + +/// Drain with padding: bridge preserved for drained_contents (slice + None padding). +/// This is what the actual CircularBuffer drain uses (keeps length = alloc_length). +#push-options "--z3rlimit_factor 2" +let ranges_match_drain_padded + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) (n: nat) + : Lemma + (requires ranges_match_contents repr contents base_offset /\ + n <= Seq.length contents) + (ensures ranges_match_contents repr + (drained_contents (Seq.length contents) contents n) + (base_offset + n)) + = let al = Seq.length contents in + let new_contents = drained_contents al contents n in + let new_base = base_offset + n in + assert (Seq.length new_contents == al); + assert (al > 0); + let aux (i:nat{i < al}) + : Lemma (RMSpec.mem repr (new_base + i) <==> Some? (Seq.index new_contents i)) + = if i < al - n then ( + // Position in sliced region: new_contents[i] = contents[n + i] + assert (Seq.index new_contents i == Seq.index contents (n + i)); + assert (new_base + i == base_offset + (n + i)); + assert (n + i < al) + ) else ( + // Position in padding region: new_contents[i] = None + assert (Seq.index new_contents i == None #byte); + // base_offset + (n + i) >= base_offset + al, beyond all intervals + assert (new_base + i == base_offset + (n + i)); + assert (n + i >= al); + RMSpec.mem_bounded repr (base_offset + al) (base_offset + (n + i)) + ) + in + FStar.Classical.forall_intro aux; + // Bounded: old bound (base_offset + al) <= new bound (base_offset + n + al) + RMSpec.range_map_bounded_monotone repr (base_offset + al) (base_offset + n + al) +#pop-options + +/// Drain repr bridge: drain_repr preserves ranges_match_contents with new base +let ranges_match_drain_repr + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) (n: nat) + : Lemma + (requires ranges_match_contents repr contents base_offset /\ + RMSpec.range_map_wf repr /\ + base_aligned repr base_offset /\ + n <= RMSpec.contiguous_from repr base_offset /\ + n <= Seq.length contents) + (ensures ranges_match_contents + (RMSpec.drain_repr repr (base_offset + n)) + (drained_contents (Seq.length contents) contents n) + (base_offset + n)) + = let al = Seq.length contents in + let new_bo = base_offset + n in + let new_contents = drained_contents al contents n in + let new_repr = RMSpec.drain_repr repr new_bo in + // First: original bridge after drain (using unchanged repr) + ranges_match_drain_padded repr contents base_offset n; + // So: ranges_match_contents repr new_contents new_bo + // Now show: drain_repr preserves mem for positions >= new_bo + if Seq.length repr = 0 then () + else begin + let first = Seq.index repr 0 in + assert (first.low <= base_offset); + assert (base_offset <= RMSpec.high first); + assert (new_bo <= RMSpec.high first); + // drain_repr_mem_above: for x >= new_bo, mem new_repr x == mem repr x + let aux (i:nat{i < al}) + : Lemma (RMSpec.mem new_repr (new_bo + i) <==> Some? (Seq.index new_contents i)) + = RMSpec.drain_repr_mem_above repr new_bo (new_bo + i) + in + FStar.Classical.forall_intro aux; + // Bounded + RMSpec.drain_repr_bounded repr new_bo (base_offset + al); + RMSpec.range_map_bounded_monotone new_repr (base_offset + al) (new_bo + al) + end + +/// Drain preserves base_aligned when draining at most contiguous_from bytes. +let drain_preserves_base_aligned + (repr: Seq.seq RMSpec.interval) + (base_offset: nat) (n: nat) + : Lemma + (requires base_aligned repr base_offset /\ + n <= RMSpec.contiguous_from repr base_offset) + (ensures base_aligned repr (base_offset + n)) + = if Seq.length repr = 0 then () + else ( + let first = Seq.index repr 0 in + if first.low <= base_offset && base_offset < RMSpec.high first then ( + assert (RMSpec.contiguous_from repr base_offset == RMSpec.high first - base_offset); + assert (base_offset + n <= RMSpec.high first) + ) else ( + assert (base_offset + n == base_offset) + ) + ) + + +/// 3-way invariant: empty, gap (first starts after base), or base_aligned. +/// Excludes the unreachable case where first starts at/before base but base is past the end. +let repr_well_positioned (repr: Seq.seq RMSpec.interval) (base_offset: nat) : prop = + Seq.length repr = 0 \/ + (Seq.index repr 0).low > base_offset \/ + ((Seq.index repr 0).low <= base_offset /\ base_offset <= RMSpec.high (Seq.index repr 0)) + +/// repr_well_positioned subsumes base_aligned +let base_aligned_implies_rwp (repr: Seq.seq RMSpec.interval) (base_offset: nat) + : Lemma (requires base_aligned repr base_offset) + (ensures repr_well_positioned repr base_offset) = () + +/// Empty repr matches create_nones contents (base_offset = 0), with all invariants. +let ranges_match_create_nones (al: pos) + : Lemma (ranges_match_contents Seq.empty (GT.create_nones al) 0 /\ + RMSpec.range_map_wf Seq.empty /\ + repr_well_positioned Seq.empty 0) + = let contents = GT.create_nones al in + let aux (i:nat{i < Seq.length contents}) + : Lemma (RMSpec.mem Seq.empty (0 + i) <==> Some? (Seq.index contents i)) + = GT.create_nones_all_none al i + in + FStar.Classical.forall_intro aux + +/// repr_well_positioned implies cf == cpl (the key property for drain_rm postconditions) +let rwp_cf_eq_cpl + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) + : Lemma + (requires ranges_match_contents repr contents base_offset /\ + RMSpec.range_map_wf repr /\ + repr_well_positioned repr base_offset) + (ensures RMSpec.contiguous_from repr base_offset == GT.contiguous_prefix_length contents) + = if Seq.length repr = 0 then ( + // Empty repr: no members, all None + ranges_match_prefix repr contents base_offset + ) else if (Seq.index repr 0).low > base_offset then ( + // Gap state: first starts after base + // cf = 0 (first doesn't cover base) + assert (RMSpec.contiguous_from repr base_offset == 0); + // position base_offset is not a member (below first.low, which is the lowest) + // So contents[0] = None, hence cpl = 0 + let first = Seq.index repr 0 in + assert (not (RMSpec.in_interval first base_offset)); + RMSpec.mem_not_below_first repr base_offset; + assert (not (RMSpec.mem repr base_offset)); + // contents[0] = None since mem repr (base_offset + 0) = false + assert (Seq.length contents > 0); + // cpl = 0 since contents[0] = None (from ranges_match_contents + non-membership) + assert (not (Some? (Seq.index contents 0))); + assert (GT.contiguous_prefix_length contents == 0) + ) else ( + // base_aligned: first.low <= base_offset <= high first + ranges_match_prefix repr contents base_offset + ) + +/// Write preserves repr_well_positioned +let write_preserves_rwp + (repr: Seq.seq RMSpec.interval) (base_offset: nat) (rel_offset: nat) (len: pos) + : Lemma + (requires RMSpec.range_map_wf repr /\ + repr_well_positioned repr base_offset) + (ensures repr_well_positioned (RMSpec.add_range repr (base_offset + rel_offset) len) base_offset) + = let offset = base_offset + rel_offset in + let r = RMSpec.add_range repr offset len in + if Seq.length repr = 0 then ( + // Empty repr: add creates [{offset, len}] + if rel_offset = 0 then + // Write at base: new first at base, base_aligned + RMSpec.add_range_at_base_establishes_aligned repr base_offset len + else + // Gap write: new first at offset > base + RMSpec.add_range_preserves_gap repr base_offset offset len + ) else if (Seq.index repr 0).low > base_offset then ( + // Gap state + if rel_offset = 0 then + // Write at base into gap state: establishes base_aligned + RMSpec.add_range_at_base_establishes_aligned repr base_offset len + else + // Gap write: offset > base_offset, preserves gap + RMSpec.add_range_preserves_gap repr base_offset offset len + ) else ( + // base_aligned: first.low <= base_offset <= high first + // offset = base_offset + rel_offset >= base_offset >= first.low + RMSpec.add_range_base_aligned repr base_offset offset len + ) + +/// Drain preserves repr_well_positioned +let drain_preserves_rwp + (repr: Seq.seq RMSpec.interval) (base_offset: nat) (n: nat) + : Lemma + (requires repr_well_positioned repr base_offset /\ + n <= RMSpec.contiguous_from repr base_offset) + (ensures repr_well_positioned repr (base_offset + n)) + = if Seq.length repr = 0 then () + else if (Seq.index repr 0).low > base_offset then ( + // Gap state: cf = 0, n = 0 + assert (RMSpec.contiguous_from repr base_offset == 0); + assert (n == 0) + ) else ( + // base_aligned: drain preserves it + drain_preserves_base_aligned repr base_offset n + ) + +/// drain_repr preserves repr_well_positioned with new base +let drain_repr_preserves_rwp + (repr: Seq.seq RMSpec.interval) (base_offset: nat) (n: nat) + : Lemma + (requires repr_well_positioned repr base_offset /\ + RMSpec.range_map_wf repr /\ + base_aligned repr base_offset /\ + n <= RMSpec.contiguous_from repr base_offset) + (ensures repr_well_positioned (RMSpec.drain_repr repr (base_offset + n)) (base_offset + n)) + = if Seq.length repr = 0 then () + else + let first = Seq.index repr 0 in + let new_bo = base_offset + n in + let result = RMSpec.drain_repr repr new_bo in + if RMSpec.high first <= new_bo then begin + let tl = Seq.tail repr in + if Seq.length tl = 0 then () + else begin + let next = Seq.index tl 0 in + assert (Seq.index repr 1 == next); + assert (RMSpec.separated first next); + assert (next.low > RMSpec.high first); + assert (next.low > new_bo); + assert (Seq.index result 0 == next) + end + end else if first.low < new_bo then begin + let trimmed = { RMSpec.low = new_bo; RMSpec.count = RMSpec.high first - new_bo } in + assert (Seq.index result 0 == trimmed); + assert (trimmed.low == new_bo); + assert (new_bo <= RMSpec.high trimmed) + end else + assert (new_bo == base_offset) + +/// Master lemma: write preserves cf == cpl under the 3-way invariant +#push-options "--z3rlimit_factor 2" +let write_preserves_cf_eq_cpl + (repr: Seq.seq RMSpec.interval) + (contents: Seq.seq (option byte)) + (base_offset: nat) + (rel_offset: nat) + (data: Seq.seq byte) + : Lemma + (requires + ranges_match_contents repr contents base_offset /\ + RMSpec.range_map_wf repr /\ + repr_well_positioned repr base_offset /\ + RMSpec.contiguous_from repr base_offset == GT.contiguous_prefix_length contents /\ + Seq.length data > 0 /\ + rel_offset + Seq.length data <= Seq.length contents) + (ensures ( + let new_repr = RMSpec.add_range repr (base_offset + rel_offset) (Seq.length data) in + let new_contents = GT.write_range_contents contents rel_offset data in + RMSpec.contiguous_from new_repr base_offset == + GT.contiguous_prefix_length new_contents)) + = let len = Seq.length data in + let new_repr = RMSpec.add_range repr (base_offset + rel_offset) len in + let new_contents = GT.write_range_contents contents rel_offset data in + // Prove preservation of ranges_match_contents and wf for new state + ranges_match_write repr contents base_offset rel_offset data; + RMSpec.add_range_wf repr (base_offset + rel_offset) len; + // Prove repr_well_positioned for new state + write_preserves_rwp repr base_offset rel_offset len; + // Use rwp_cf_eq_cpl on new state + rwp_cf_eq_cpl new_repr new_contents base_offset +#pop-options + +/// --- Trim Write (for absolute-offset API) --- + +/// Trim a write to remove bytes before base_offset (already consumed). +/// Returns (relative_offset, trimmed_data_length, skip_count). +/// skip_count is the number of leading bytes to skip from src. +let trim_write_params (abs_offset: nat) (write_len: nat) (base_offset: nat) + : (nat & nat & nat) // (rel_offset, trimmed_len, skip) + = let abs_end = abs_offset + write_len in + if abs_end <= base_offset then (0, 0, 0) // fully stale — no bytes to write + else if abs_offset >= base_offset then + (abs_offset - base_offset, write_len, 0) // no overlap — all bytes valid + else + let skip = base_offset - abs_offset in // partial overlap — skip consumed prefix + (0, write_len - skip, skip) + +/// Stale check: true if the entire write is before base_offset +let is_stale_write (abs_offset: nat) (write_len: nat) (base_offset: nat) : bool = + abs_offset + write_len <= base_offset + +/// After trimming, the relative offset + trimmed length fits in alloc_length +/// if the original absolute write fits in base_offset + alloc_length. +let trim_write_in_bounds + (abs_offset: nat) (write_len: nat) (base_offset: nat) (alloc_length: nat) + : Lemma + (requires + write_len > 0 /\ + abs_offset + write_len <= base_offset + alloc_length /\ + not (is_stale_write abs_offset write_len base_offset)) + (ensures ( + let (rel_off, tlen, skip) = trim_write_params abs_offset write_len base_offset in + rel_off + tlen <= alloc_length /\ + skip + tlen == write_len /\ + tlen > 0)) + = () + +/// The trimmed write produces the same logical result as writing only the +/// non-stale portion: write_range_contents at rel_offset with data[skip..]. +let trim_write_equiv + (abs_offset: nat) (write_len: nat) (base_offset: nat) + (contents: Seq.seq (option byte)) (data: Seq.seq byte) + : Lemma + (requires + Seq.length data == write_len /\ + not (is_stale_write abs_offset write_len base_offset) /\ + (let (rel_off, tlen, skip) = trim_write_params abs_offset write_len base_offset in + rel_off + tlen <= Seq.length contents)) + (ensures ( + let (rel_off, tlen, skip) = trim_write_params abs_offset write_len base_offset in + GT.write_range_contents_t contents rel_off (Seq.slice data skip (skip + tlen)) == + GT.write_range_contents contents rel_off (Seq.slice data skip (skip + tlen)))) + = let (rel_off, tlen, skip) = trim_write_params abs_offset write_len base_offset in + GT.write_range_contents_t_eq contents rel_off (Seq.slice data skip (skip + tlen)) + +/// Needed resize size: smallest pow2 >= abs_end - base_offset +let needed_alloc_for_write (abs_offset: nat) (write_len: nat) (base_offset: nat) : nat = + if abs_offset + write_len <= base_offset then 0 + else abs_offset + write_len - base_offset + +/// Count bound: when repr starts at/after base_offset, the count is bounded. +/// In gap case (first.low > bo): 2n <= al +/// In base_aligned case (first.low <= bo): 2n <= bo + al - first.low + 1 +/// For the write fold site, we only use the gap case + the empty case. +let repr_count_bound_gap + (repr: Seq.seq RMSpec.interval) (base_offset: nat) (al: pos) + : Lemma + (requires RMSpec.range_map_wf repr /\ + RMSpec.range_map_bounded repr (base_offset + al) /\ + Seq.length repr > 0 /\ + (Seq.index repr 0).low >= base_offset) + (ensures Seq.length repr + Seq.length repr <= al + 1) + = RMSpec.wf_count_bound repr base_offset (base_offset + al) diff --git a/lib/pulse/lib/Pulse.Lib.CircularBuffer.fst b/lib/pulse/lib/Pulse.Lib.CircularBuffer.fst new file mode 100644 index 000000000..4b0e50e2f --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.CircularBuffer.fst @@ -0,0 +1,1202 @@ +module Pulse.Lib.CircularBuffer + +#lang-pulse +open Pulse.Lib.Pervasives +open Pulse.Lib.Vec +open FStar.SizeT +module Seq = FStar.Seq +module SZ = FStar.SizeT +module U64 = FStar.UInt64 +module B = Pulse.Lib.Box +open Pulse.Lib.Box { box, (:=), (!) } +module R = Pulse.Lib.Reference +module Spec = Pulse.Lib.CircularBuffer.Spec +open Pulse.Lib.CircularBuffer.Spec +module Pow2 = Pulse.Lib.CircularBuffer.Pow2 +module GT = Pulse.Lib.CircularBuffer.GapTrack +module Mod = Pulse.Lib.CircularBuffer.Modular +module A = Pulse.Lib.Array +module RM = Pulse.Lib.RangeVec +module RMSpec = Pulse.Lib.RangeMap.Spec +module PTR = Pulse.Lib.Array.PtsToRange +open Pulse.Lib.Trade + +type byte = Spec.byte + +/// Prove pow2_63 equals Prims.pow2 63 (checked by normalization, not Z3) +private let _pow2_63_eq : squash (pow2_63 = Prims.pow2 63) = + assert_norm (pow2_63 = Prims.pow2 63) + +/// Pre-compute pow2 64 so Z3 doesn't evaluate it for SZ.fits_u64_implies_fits +private let _pow2_64_val : squash (Prims.pow2 64 = 0x10000000000000000) = + assert_norm (Prims.pow2 64 = 0x10000000000000000) + +let lemma_nones_coherent (al:pos) (phys:Seq.seq byte{Seq.length phys == al}) (rs:nat{rs < al}) + : Lemma (Spec.phys_log_coherent al phys (GT.create_nones al) rs) + = let aux (i:nat{i < al}) + : Lemma (Spec.coherent_at al phys (GT.create_nones al) rs i) + = GT.create_nones_all_none al i + in + FStar.Classical.forall_intro aux + +/// Platform assumption: SZ.t is at least 64 bits (true on all MsQuic targets). +assume val platform_fits_u64 : squash SZ.fits_u64 + +/// cb_max_length fits in SZ.t (follows from cb_max_length <= pow2_63 and fits_u64) +let cb_max_length_sz : SZ.t = + SZ.fits_u64_implies_fits Spec.cb_max_length; + SZ.uint_to_t Spec.cb_max_length + +let lemma_idx_sum_fits (al: SZ.t) (a b: SZ.t) + : Lemma (requires SZ.v a < SZ.v al /\ SZ.v b <= SZ.v al /\ + SZ.v al > 0 /\ SZ.v al <= pow2_63) + (ensures SZ.fits (SZ.v a + SZ.v b)) + = SZ.fits_u64_implies_fits (SZ.v a + SZ.v b) + +let lemma_inc_fits (x: SZ.t) (bound: SZ.t) + : Lemma (requires SZ.v x < SZ.v bound) + (ensures SZ.fits (SZ.v x + 1)) + = SZ.fits_lte (SZ.v x + 1) (SZ.v bound) + +/// Bridge: SZ.mod_spec equals Prims.op_Modulus for non-negative values +let lemma_mod_spec_eq (a: nat{SZ.fits a}) (b: pos{SZ.fits b}) + : Lemma (SZ.mod_spec a b == a % b) + = FStar.Math.Lemmas.euclidean_division_definition a b + +/// Prove that the copy loop produces exactly linearized_phys +let lemma_loop_is_linearized + (old_al: pos) (new_al: pos) + (old_phys: Seq.seq byte{Seq.length old_phys == old_al}) + (old_rs: nat{old_rs < old_al}) + (new_data: Seq.seq byte{Seq.length new_data == new_al}) + : Lemma + (requires + new_al >= old_al /\ + (forall (j:nat). j < old_al ==> + Seq.index new_data j == Seq.index old_phys ((old_rs + j) % old_al)) /\ + (forall (j:nat). (old_al <= j /\ j < new_al) ==> + Seq.index new_data j == 0uy)) + (ensures new_data == Spec.linearized_phys old_al new_al old_phys old_rs) + = let expected = Spec.linearized_phys old_al new_al old_phys old_rs in + let aux (j:nat{j < new_al}) + : Lemma (Seq.index new_data j == Seq.index expected j) + = Mod.circular_index_in_bounds old_rs j old_al + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro new_data expected + +/// Helper lemma: prove that Seq.upd maintains the resize loop invariant +let lemma_resize_invariant_step + (old_al: pos) (new_al: pos) + (old_phys: Seq.seq byte{Seq.length old_phys == old_al}) + (old_rs: nat{old_rs < old_al}) + (new_seq: Seq.seq byte{Seq.length new_seq == new_al}) + (vi: nat{vi < old_al /\ vi < new_al}) + (byte_val: byte) + : Lemma + (requires + (forall (j:nat). j < vi ==> + Seq.index new_seq j == Seq.index old_phys ((old_rs + j) % old_al)) /\ + (forall (j:nat). (vi <= j /\ j < new_al) ==> + Seq.index new_seq j == 0uy) /\ + byte_val == Seq.index old_phys ((old_rs + vi) % old_al)) + (ensures ( + let new_seq' = Seq.upd new_seq vi byte_val in + (forall (j:nat). j < vi + 1 ==> + Seq.index new_seq' j == Seq.index old_phys ((old_rs + j) % old_al)) /\ + (forall (j:nat). (vi + 1 <= j /\ j < new_al) ==> + Seq.index new_seq' j == 0uy))) + = let new_seq' = Seq.upd new_seq vi byte_val in + let aux (j:nat{j < new_al}) + : Lemma ( + (j < vi + 1 ==> Seq.index new_seq' j == Seq.index old_phys ((old_rs + j) % old_al)) /\ + (vi + 1 <= j ==> Seq.index new_seq' j == 0uy)) + = if j = vi then + Seq.lemma_index_upd1 new_seq vi byte_val + else + Seq.lemma_index_upd2 new_seq vi byte_val j + in + FStar.Classical.forall_intro aux + +noeq +type cb_internal = { + buf: vec byte; // Physical array (mutable via box for resize) + rs: SZ.t; // read_start (mutable) + al: SZ.t; // alloc_length (mutable, changes on resize) + pl: SZ.t; // prefix_length (mutable, tracks contiguous readable data) + vl: SZ.t; // virtual_length (constant) + bo: SZ.t; // base_offset (absolute stream position of read_start) +} + +type circular_buffer = box cb_internal + +let is_circular_buffer + ([@@@mkey]cb: circular_buffer) + (rm: RM.range_vec_t) + (st: Spec.cb_state) : slprop = + exists* (cbi: cb_internal) (buf_data: Seq.seq byte) (repr: Seq.seq RMSpec.interval). + B.pts_to cb cbi ** + Vec.pts_to cbi.buf buf_data ** + RM.is_range_vec rm repr ** + pure ( + SZ.v cbi.al > 0 /\ + SZ.v cbi.al == st.alloc_length /\ + SZ.v cbi.vl == st.virtual_length /\ + SZ.v cbi.rs == st.read_start /\ + SZ.v cbi.bo == st.base_offset /\ + SZ.v cbi.pl == RMSpec.contiguous_from repr (SZ.v cbi.bo) /\ + SZ.v cbi.pl == GT.contiguous_prefix_length st.contents /\ + Seq.length buf_data == SZ.v cbi.al /\ + is_full_vec cbi.buf /\ + Spec.cb_wf st /\ + SZ.v cbi.al <= pow2_63 /\ + st.virtual_length <= pow2_63 /\ + Spec.phys_log_coherent st.alloc_length buf_data st.contents st.read_start /\ + Spec.ranges_match_contents repr st.contents (SZ.v cbi.bo) /\ + RMSpec.range_map_wf repr /\ + Spec.repr_well_positioned repr (SZ.v cbi.bo) /\ + (Seq.length repr = 0 \/ (Seq.index repr 0).low >= SZ.v cbi.bo) /\ + Seq.length repr < RM.max_range_vec_entries + ) + +/// Get the length of contiguous readable data +fn read_length + (cb: circular_buffer) (rm: RM.range_vec_t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st + returns n : SZ.t + ensures is_circular_buffer cb rm st ** + pure (SZ.v n == GT.contiguous_prefix_length st.contents) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + let n = cb_val.pl; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); + n +} + +fn get_total_length + (cb: circular_buffer) (rm: RM.range_vec_t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st + returns n: SZ.t + ensures is_circular_buffer cb rm st ** + pure (SZ.v n <= st.base_offset + st.alloc_length) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + let n = RM.range_vec_max_endpoint rm; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + RMSpec.range_map_max_endpoint_bounded repr (SZ.v cbi.bo + SZ.v cbi.al); + fold (is_circular_buffer cb rm st); + n +} + +fn set_virtual_length + (cb: circular_buffer) (rm: RM.range_vec_t) (new_vl: SZ.t{SZ.v new_vl > 0}) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ + Pow2.is_pow2 (SZ.v new_vl) /\ + SZ.v new_vl >= st.virtual_length /\ + SZ.v new_vl <= pow2_63) + ensures is_circular_buffer cb rm ({ st with virtual_length = SZ.v new_vl }) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + let new_cbi = Mkcb_internal cb_val.buf cb_val.rs cb_val.al cb_val.pl new_vl cb_val.bo; + ( := ) cb new_cbi; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to new_cbi.buf buf_data); + fold (is_circular_buffer cb rm ({ st with virtual_length = SZ.v new_vl })); + () +} + +#push-options "--fuel 1 --ifuel 1 --z3rlimit_factor 4" +fn create + (alloc_len: SZ.t{SZ.v alloc_len > 0}) + (virt_len: SZ.t{SZ.v virt_len > 0}) + requires pure ( + Pow2.is_pow2 (SZ.v alloc_len) /\ + Pow2.is_pow2 (SZ.v virt_len) /\ + SZ.v alloc_len <= SZ.v virt_len /\ + SZ.v alloc_len <= Spec.cb_max_length /\ + SZ.v virt_len <= pow2_63) + returns res : (circular_buffer & RM.range_vec_t) + ensures exists* st. + is_circular_buffer (fst res) (snd res) st ** + pure (Spec.cb_wf st /\ + st.base_offset == 0 /\ + st.alloc_length == SZ.v alloc_len /\ + st.virtual_length == SZ.v virt_len /\ + GT.contiguous_prefix_length st.contents == 0) +{ + let buf_vec : vec byte = alloc #byte 0uy alloc_len; + let al_v : SZ.t = alloc_len; + let vl_v : SZ.t = virt_len; + + let vi = Mkcb_internal buf_vec 0sz al_v 0sz vl_v 0sz; + let cb = B.alloc vi; + let rm = RM.range_vec_create (); + + with buf_data. assert (Vec.pts_to buf_vec buf_data); + lemma_nones_coherent (SZ.v alloc_len) buf_data 0; + GT.prefix_of_nones (SZ.v alloc_len); + Spec.ranges_match_create_nones (SZ.v alloc_len); + + rewrite (Vec.pts_to buf_vec buf_data) as (Vec.pts_to vi.buf buf_data); + + fold (is_circular_buffer cb rm ({ + base_offset = 0; read_start = 0; + alloc_length = SZ.v alloc_len; virtual_length = SZ.v virt_len; + contents = GT.create_nones (SZ.v alloc_len) + })); + (cb, rm) +} +#pop-options + +/// Resize: grow buffer while preserving range map bridge. +fn resize + (cb: circular_buffer) (rm: RM.range_vec_t) (new_al: SZ.t{SZ.v new_al > 0}) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ Pow2.is_pow2 (SZ.v new_al) /\ + SZ.v new_al >= st.alloc_length /\ + SZ.v new_al <= st.virtual_length /\ + SZ.v new_al <= Spec.cb_max_length /\ + SZ.v new_al <= pow2_63) + ensures is_circular_buffer cb rm (Spec.resize_result st (SZ.v new_al)) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + + let new_vec : vec byte = alloc #byte 0uy new_al; + let mut i : SZ.t = 0sz; + while (let vi = R.read i; SZ.lt vi cb_val.al) + invariant exists* vi new_v. + B.pts_to cb cbi ** Vec.pts_to cb_val.buf buf_data ** + RM.is_range_vec rm repr ** + R.pts_to i vi ** Vec.pts_to new_vec new_v ** + pure (SZ.v vi <= SZ.v cb_val.al /\ + Seq.length new_v == SZ.v new_al /\ + Seq.length buf_data == SZ.v cb_val.al /\ + is_full_vec cb_val.buf /\ + SZ.v cb_val.al <= pow2_63 /\ + SZ.v cb_val.al > 0 /\ + SZ.v cb_val.rs == st.read_start /\ + SZ.v cb_val.al == st.alloc_length /\ + (forall (j:nat). j < SZ.v vi ==> + Seq.index new_v j == Seq.index buf_data ((st.read_start + j) % st.alloc_length)) /\ + (forall (j:nat). (SZ.v vi <= j /\ j < SZ.v new_al) ==> + Seq.index new_v j == 0uy)) + { + let vi = R.read i; + with new_v. assert (Vec.pts_to new_vec new_v); + lemma_idx_sum_fits cb_val.al cb_val.rs vi; + let temp = SZ.add cb_val.rs vi; + let src_idx = SZ.rem temp cb_val.al; + lemma_mod_spec_eq (SZ.v temp) (SZ.v cb_val.al); + + assert (pure (SZ.v src_idx < Seq.length buf_data)); + let byte_val = cb_val.buf.(src_idx); + assert (pure (byte_val == Seq.index buf_data ((st.read_start + SZ.v vi) % st.alloc_length))); + new_vec.(vi) <- byte_val; + with new_v'. assert (Vec.pts_to new_vec new_v'); + lemma_resize_invariant_step st.alloc_length (SZ.v new_al) buf_data st.read_start new_v (SZ.v vi) byte_val; + lemma_inc_fits vi cb_val.al; + R.write i (SZ.add vi 1sz); + }; + with _vi _new_v. _; + lemma_loop_is_linearized st.alloc_length (SZ.v new_al) buf_data st.read_start _new_v; + assert (pure (_new_v == Spec.linearized_phys st.alloc_length (SZ.v new_al) buf_data st.read_start)); + Vec.free cb_val.buf; + + let new_cbi = Mkcb_internal new_vec 0sz new_al cb_val.pl cb_val.vl cb_val.bo; + ( := ) cb new_cbi; + + with new_buf_data. assert (Vec.pts_to new_vec new_buf_data); + assert (pure (new_buf_data == _new_v)); + rewrite (Vec.pts_to new_vec new_buf_data) as (Vec.pts_to new_cbi.buf new_buf_data); + + Spec.linearize_preserves_coherence st.alloc_length (SZ.v new_al) buf_data st.contents st.read_start; + Spec.resize_prefix_length st.alloc_length (SZ.v new_al) st.contents; + Spec.ranges_match_resize repr st.contents (SZ.v cb_val.bo) st.alloc_length (SZ.v new_al); + fold (is_circular_buffer cb rm (Spec.resize_result st (SZ.v new_al))); +} + +fn free + (cb: circular_buffer) (rm: RM.range_vec_t) (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st + ensures emp +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + Vec.free cb_val.buf; + RM.range_vec_free rm; + B.free cb; +} + +fn get_alloc_length + (cb: circular_buffer) (rm: RM.range_vec_t) (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** pure (Spec.cb_wf st) + returns n : SZ.t + ensures is_circular_buffer cb rm st ** pure (SZ.v n == st.alloc_length) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + let n = cb_val.al; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); + n +} + +#push-options "--z3rlimit_factor 4" +fn read_buffer + (cb: circular_buffer) + (rm: RM.range_vec_t) + (dst: A.array byte) + (read_len: SZ.t) + (#dst_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + A.pts_to dst dst_data ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + SZ.v read_len <= st.alloc_length /\ + SZ.v read_len <= A.length dst /\ + A.is_full_array dst) + ensures exists* (dst_data': Seq.seq byte). + is_circular_buffer cb rm st ** + A.pts_to dst dst_data' ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= Seq.length st.contents /\ + SZ.v read_len <= Seq.length dst_data' /\ + Seq.length dst_data' == Seq.length dst_data /\ + (forall (i:nat{i < SZ.v read_len}). + Some? (Seq.index st.contents i) /\ + Seq.index dst_data' i == Some?.v (Seq.index st.contents i))) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + A.pts_to_len dst; + + let mut ri : SZ.t = 0sz; + while (let vi = R.read ri; SZ.lt vi read_len) + invariant exists* (vi: SZ.t) (cur_dst: Seq.seq byte). + B.pts_to cb cbi ** Vec.pts_to cb_val.buf buf_data ** + RM.is_range_vec rm repr ** + A.pts_to dst cur_dst ** + R.pts_to ri vi ** + pure ( + SZ.v vi <= SZ.v read_len /\ + SZ.v cb_val.al > 0 /\ + SZ.v cb_val.al <= pow2_63 /\ + SZ.v cb_val.al == st.alloc_length /\ + SZ.v cb_val.rs == st.read_start /\ + Seq.length buf_data == SZ.v cb_val.al /\ + Seq.length cur_dst == Seq.length dst_data /\ + is_full_vec cb_val.buf /\ + A.is_full_array dst /\ + SZ.v read_len <= SZ.v cb_val.al /\ + SZ.v read_len <= A.length dst /\ + SZ.v read_len <= Seq.length cur_dst /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + Spec.phys_log_coherent st.alloc_length buf_data st.contents st.read_start /\ + (forall (k:nat{k < SZ.v vi}). + Some? (Seq.index st.contents k) /\ + Seq.index cur_dst k == Some?.v (Seq.index st.contents k))) + { + let vi = R.read ri; + with cur_dst. assert (A.pts_to dst cur_dst); + lemma_idx_sum_fits cb_val.al cb_val.rs vi; + let pidx = SZ.rem (SZ.add cb_val.rs vi) cb_val.al; + lemma_mod_spec_eq (SZ.v (SZ.add cb_val.rs vi)) (SZ.v cb_val.al); + GT.prefix_elements_are_some st.contents (SZ.v vi); + let byte_val = cb_val.buf.(pidx); + A.op_Array_Assignment dst vi byte_val; + with cur_dst'. assert (A.pts_to dst cur_dst'); + Spec.read_step_invariant (SZ.v cb_val.al) buf_data st.contents st.read_start cur_dst (SZ.v vi) byte_val; + lemma_inc_fits vi read_len; + R.write ri (SZ.add vi 1sz); + }; + + with _vi _cur_dst. _; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); +} +#pop-options + +/// Internal helper: out-of-order write at a relative offset, +/// updates both the physical buffer and the range map, and computes exact new prefix. +#push-options "--z3rlimit_factor 32 --fuel 2 --ifuel 1" +fn write_buffer_core + (cb: circular_buffer) + (rm: RM.range_vec_t) + (offset: SZ.t) + (src: A.array byte) + (write_len: SZ.t) + (#p: perm) + (#src_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v write_len > 0 /\ + SZ.v offset + SZ.v write_len <= st.alloc_length /\ + SZ.fits (st.base_offset + SZ.v offset + SZ.v write_len)) + ensures exists* st'. + is_circular_buffer cb rm st' ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st' /\ + st'.base_offset == st.base_offset /\ + st'.virtual_length == st.virtual_length /\ + st'.alloc_length == st.alloc_length /\ + st'.read_start == st.read_start /\ + st'.contents == GT.write_range_contents_t st.contents (SZ.v offset) (reveal src_data) /\ + GT.contiguous_prefix_length st'.contents >= + GT.contiguous_prefix_length st.contents) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + A.pts_to_len src; + + // Write loop: copy src into physical array at (rs + offset + i) % al + let mut wi : SZ.t = 0sz; + while (let vi = R.read wi; SZ.lt vi write_len) + invariant exists* (vi: SZ.t) (cur_phys: Seq.seq byte). + B.pts_to cb cbi ** Vec.pts_to cb_val.buf cur_phys ** + A.pts_to src #p src_data ** + RM.is_range_vec rm repr ** + R.pts_to wi vi ** + pure ( + SZ.v vi <= SZ.v write_len /\ + Seq.length cur_phys == SZ.v cb_val.al /\ + is_full_vec cb_val.buf /\ + SZ.v cb_val.al > 0 /\ + SZ.v cb_val.al <= pow2_63 /\ + SZ.v cb_val.rs == st.read_start /\ + SZ.v cb_val.al == st.alloc_length /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v offset + SZ.v write_len <= SZ.v cb_val.al /\ + st.read_start < st.alloc_length /\ + Spec.phys_log_coherent st.alloc_length cur_phys + (GT.write_range_contents st.contents (SZ.v offset) + (Seq.slice (reveal src_data) 0 (SZ.v vi))) + st.read_start) + { + let vi = R.read wi; + with cur_phys. assert (Vec.pts_to cb_val.buf cur_phys); + A.pts_to_len src; + let byte_val = A.op_Array_Access src vi; + let off = SZ.add offset vi; + lemma_idx_sum_fits cb_val.al cb_val.rs off; + let pidx = SZ.rem (SZ.add cb_val.rs off) cb_val.al; + lemma_mod_spec_eq (SZ.v (SZ.add cb_val.rs off)) (SZ.v cb_val.al); + cb_val.buf.(pidx) <- byte_val; + Spec.write_step_coherence (SZ.v cb_val.al) cur_phys st.contents + st.read_start (SZ.v offset) (reveal src_data) (SZ.v vi); + lemma_inc_fits vi write_len; + R.write wi (SZ.add vi 1sz); + }; + + with _vi _cur_phys. _; + // Bridge: Seq.slice data 0 write_len == data + Seq.lemma_eq_intro (Seq.slice (reveal src_data) 0 (SZ.v write_len)) (reveal src_data); + + // Coherence transfer + Spec.write_ooo_coherence_transfer (SZ.v cb_val.al) _cur_phys st.contents + st.read_start (SZ.v offset) (reveal src_data) (SZ.v _vi) (SZ.v write_len); + + // Bridge: total version equals partial version (precondition holds) + GT.write_range_contents_t_eq st.contents (SZ.v offset) (reveal src_data); + + // Prefix monotonicity + GT.write_range_monotone st.contents (SZ.v offset) (reveal src_data); + + // cb_wf preserved + Spec.write_range_preserves_wf st (SZ.v offset) (reveal src_data); + + // Update range map with absolute offset (bo + offset) + let abs_offset = SZ.add cb_val.bo offset; + RM.range_vec_add rm abs_offset write_len; + + // Bridge preservation (using absolute offsets) + RMSpec.add_range_wf repr (SZ.v abs_offset) (SZ.v write_len); + Spec.ranges_match_write repr st.contents (SZ.v cb_val.bo) (SZ.v offset) (reveal src_data); + + // Compute new prefix length from range map using base_offset + let new_pl = RM.range_vec_contiguous_from rm cb_val.bo; + + // Update cb with new pl + let new_cbi = Mkcb_internal cb_val.buf cb_val.rs cb_val.al new_pl cb_val.vl cb_val.bo; + ( := ) cb new_cbi; + rewrite (Vec.pts_to cb_val.buf _cur_phys) as (Vec.pts_to new_cbi.buf _cur_phys); + + // repr_well_positioned preservation + Spec.write_preserves_rwp repr (SZ.v cb_val.bo) (SZ.v offset) (SZ.v write_len); + + // cf == cpl after write + Spec.write_preserves_cf_eq_cpl repr st.contents (SZ.v cb_val.bo) (SZ.v offset) (reveal src_data); + + // Bounded: add_range preserves boundedness + RMSpec.add_range_bounded repr (SZ.v abs_offset) (SZ.v write_len) (SZ.v cb_val.bo + SZ.v cb_val.al); + + // Count bound: first.low >= bo preserved by add_range + RMSpec.add_range_first_low repr (SZ.v abs_offset) (SZ.v write_len) (SZ.v cb_val.bo); + // now: |add_range repr ...| > 0 /\ first'.low >= bo + // so repr_count_bound_gap applies + Spec.repr_count_bound_gap (RMSpec.add_range repr (SZ.v abs_offset) (SZ.v write_len)) + (SZ.v cb_val.bo) (SZ.v cb_val.al); + // gives: 2 * |repr'| <= al + 1 <= pow2_63 + 1, so |repr'| <= pow2_62 < max + + fold (is_circular_buffer cb rm + { st with contents = + GT.write_range_contents_t st.contents (SZ.v offset) (reveal src_data) }); +} +#pop-options + +/// Absolute-offset write with trim, auto-resize, and failure handling. +/// Handles stale writes (no-op), partial overlap trim, auto-resize up to cb_max_length. +/// Returns write_result with wrote/new_data_ready/resize_failed flags. +#push-options "--z3rlimit_factor 32 --fuel 2 --ifuel 1" +fn write_buffer + (cb: circular_buffer) (rm: RM.range_vec_t) + (abs_offset: SZ.t) (src: A.array byte) (write_len: SZ.t) + (#p: perm) + (#src_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v write_len > 0 /\ + SZ.v abs_offset + SZ.v write_len <= st.base_offset + st.virtual_length /\ + SZ.fits (SZ.v abs_offset + SZ.v write_len)) + returns wr: write_result + ensures exists* st'. + is_circular_buffer cb rm st' ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st' /\ + st'.base_offset == st.base_offset /\ + st'.virtual_length == st.virtual_length /\ + (not wr.wrote ==> st'.alloc_length == st.alloc_length /\ + st'.read_start == st.read_start /\ + st'.contents == st.contents) /\ + (wr.wrote ==> st'.alloc_length >= st.alloc_length /\ + GT.contiguous_prefix_length st'.contents >= + GT.contiguous_prefix_length st.contents)) +{ + // Step 1: Read base_offset and alloc_length + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + let bo = cb_val.bo; + let al = cb_val.al; + let old_pl = cb_val.pl; + + // Step 2: Check stale (abs_end <= base_offset) + let abs_end = SZ.add abs_offset write_len; + if SZ.lte abs_end bo + { + // Fully stale — no-op + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); + { wrote = false; new_data_ready = false; resize_failed = false } + } + else + { + // Step 3: Compute trim params + // rel_offset: how far into buffer to start writing + // skip: how many src bytes to skip (already consumed) + let rel_offset : SZ.t = + (if SZ.gte abs_offset bo then SZ.sub abs_offset bo + else 0sz); + let skip : SZ.t = + (if SZ.lt abs_offset bo then SZ.sub bo abs_offset + else 0sz); + let trimmed_len = SZ.sub write_len skip; + + // The furthest point from base_offset the write reaches + let needed = SZ.add rel_offset trimmed_len; + + // Step 4: Check if resize needed + if SZ.gt needed al + { + // Need to resize — check if it fits within cb_max_length + // Compute the needed new_al by doubling + Pow2.next_pow2_ge_le_bound (SZ.v al) (SZ.v needed) st.virtual_length; + // Check if doubling can reach needed within cb_max_length + if SZ.gt needed cb_max_length_sz + { + // Resize would exceed max — return failure + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); + { wrote = false; new_data_ready = false; resize_failed = true } + } + else + { + // Compute new_al by doubling loop + let mut nal_ref : SZ.t = al; + while ( + let cur = R.read nal_ref; + SZ.lt cur needed + ) + invariant exists* (nal_v: SZ.t). + B.pts_to cb cbi ** Vec.pts_to cb_val.buf buf_data ** + A.pts_to src #p src_data ** + RM.is_range_vec rm repr ** + R.pts_to nal_ref nal_v ** + pure ( + SZ.v nal_v >= SZ.v al /\ + Pow2.is_pow2 (SZ.v nal_v) /\ + SZ.v nal_v <= st.virtual_length /\ + SZ.v nal_v <= Spec.cb_max_length /\ + SZ.v al > 0 /\ + SZ.v al == st.alloc_length /\ + SZ.v cb_val.rs == st.read_start /\ + Seq.length buf_data == SZ.v al /\ + is_full_vec cb_val.buf /\ + SZ.v al <= pow2_63 /\ + Pow2.is_pow2 st.virtual_length /\ + SZ.v needed <= st.virtual_length /\ + SZ.v needed <= Spec.cb_max_length) + { + let cur = R.read nal_ref; + Pow2.pow2_double_le (SZ.v cur) st.virtual_length; + Pow2.pow2_double_le (SZ.v cur) Spec.cb_max_length; + SZ.fits_lte (SZ.v cur + SZ.v cur) st.virtual_length; + Pow2.doubling_stays_pow2 (SZ.v cur); + R.write nal_ref (SZ.add cur cur); + }; + let new_al = R.read nal_ref; + + // Fold back to call resize + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); + resize cb rm new_al; + + // After resize, unfold to write inline + A.pts_to_len src; + Spec.trim_write_in_bounds (SZ.v abs_offset) (SZ.v write_len) st.base_offset (SZ.v new_al); + + unfold (is_circular_buffer cb rm (Spec.resize_result st (SZ.v new_al))); + with cbi2 buf_data2 repr2. _; + let cb_val2 = !cb; + rewrite (Vec.pts_to cbi2.buf buf_data2) as (Vec.pts_to cb_val2.buf buf_data2); + A.pts_to_len src; + + // Write loop: copy src[skip..] into physical array + let mut wi : SZ.t = 0sz; + while (let vi = R.read wi; SZ.lt vi trimmed_len) + invariant exists* (vi: SZ.t) (cur_phys: Seq.seq byte). + B.pts_to cb cbi2 ** Vec.pts_to cb_val2.buf cur_phys ** + A.pts_to src #p src_data ** + RM.is_range_vec rm repr2 ** + R.pts_to wi vi ** + pure ( + SZ.v vi <= SZ.v trimmed_len /\ + Seq.length cur_phys == SZ.v cb_val2.al /\ + is_full_vec cb_val2.buf /\ + SZ.v cb_val2.al > 0 /\ + SZ.v cb_val2.al <= pow2_63 /\ + SZ.v cb_val2.rs < SZ.v cb_val2.al /\ + SZ.v trimmed_len + SZ.v skip == SZ.v write_len /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v rel_offset + SZ.v trimmed_len <= SZ.v cb_val2.al /\ + Spec.phys_log_coherent (SZ.v cb_val2.al) cur_phys + (GT.write_range_contents (Spec.resize_result st (SZ.v new_al)).contents + (SZ.v rel_offset) + (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v skip + SZ.v vi))) + (SZ.v cb_val2.rs)) + { + let vi = R.read wi; + with cur_phys. assert (Vec.pts_to cb_val2.buf cur_phys); + A.pts_to_len src; + let src_idx = SZ.add skip vi; + let byte_val = A.op_Array_Access src src_idx; + let off = SZ.add rel_offset vi; + lemma_idx_sum_fits cb_val2.al cb_val2.rs off; + let pidx = SZ.rem (SZ.add cb_val2.rs off) cb_val2.al; + lemma_mod_spec_eq (SZ.v (SZ.add cb_val2.rs off)) (SZ.v cb_val2.al); + cb_val2.buf.(pidx) <- byte_val; + Spec.write_step_coherence (SZ.v cb_val2.al) cur_phys + (Spec.resize_result st (SZ.v new_al)).contents + (SZ.v cb_val2.rs) (SZ.v rel_offset) + (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v write_len)) (SZ.v vi); + lemma_inc_fits vi trimmed_len; + R.write wi (SZ.add vi 1sz); + }; + + with _vi _cur_phys. _; + let trimmed_data = Ghost.hide (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v write_len)); + Seq.lemma_eq_intro + (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v skip + SZ.v trimmed_len)) + (reveal trimmed_data); + Seq.lemma_eq_intro + (Seq.slice (reveal trimmed_data) 0 (SZ.v trimmed_len)) + (reveal trimmed_data); + + let rs_contents = Ghost.hide (Spec.resize_result st (SZ.v new_al)).contents; + Spec.write_ooo_coherence_transfer (SZ.v cb_val2.al) _cur_phys + rs_contents (SZ.v cb_val2.rs) (SZ.v rel_offset) + (reveal trimmed_data) (SZ.v _vi) (SZ.v trimmed_len); + GT.write_range_contents_t_eq rs_contents (SZ.v rel_offset) (reveal trimmed_data); + GT.write_range_monotone rs_contents (SZ.v rel_offset) (reveal trimmed_data); + Spec.resize_prefix_length st.alloc_length (SZ.v new_al) st.contents; + + let new_st_contents = Ghost.hide ( + GT.write_range_contents_t rs_contents (SZ.v rel_offset) (reveal trimmed_data)); + Spec.write_range_preserves_wf (Spec.resize_result st (SZ.v new_al)) + (SZ.v rel_offset) (reveal trimmed_data); + + // Update range map with absolute offset + let rm_abs = SZ.add cb_val2.bo rel_offset; + RM.range_vec_add rm rm_abs trimmed_len; + RMSpec.add_range_wf repr2 (SZ.v rm_abs) (SZ.v trimmed_len); + Spec.ranges_match_write repr2 rs_contents (SZ.v cb_val2.bo) (SZ.v rel_offset) (reveal trimmed_data); + + let new_pl = RM.range_vec_contiguous_from rm cb_val2.bo; + let ndr = SZ.gt new_pl 0sz && SZ.eq old_pl 0sz; + + let new_cbi = Mkcb_internal cb_val2.buf cb_val2.rs cb_val2.al new_pl cb_val2.vl cb_val2.bo; + ( := ) cb new_cbi; + rewrite (Vec.pts_to cb_val2.buf _cur_phys) as (Vec.pts_to new_cbi.buf _cur_phys); + + let rs_st = Ghost.hide (Spec.resize_result st (SZ.v new_al)); + Spec.write_preserves_rwp repr2 (SZ.v cb_val2.bo) (SZ.v rel_offset) (SZ.v trimmed_len); + Spec.write_preserves_cf_eq_cpl repr2 (reveal rs_st).contents (SZ.v cb_val2.bo) (SZ.v rel_offset) (reveal trimmed_data); + + // Bounded: add_range preserves boundedness + RMSpec.add_range_bounded repr2 (SZ.v rm_abs) (SZ.v trimmed_len) (SZ.v cb_val2.bo + SZ.v cb_val2.al); + + // Count bound: first.low >= bo preserved, then derive count < max + RMSpec.add_range_first_low repr2 (SZ.v rm_abs) (SZ.v trimmed_len) (SZ.v cb_val2.bo); + Spec.repr_count_bound_gap (RMSpec.add_range repr2 (SZ.v rm_abs) (SZ.v trimmed_len)) + (SZ.v cb_val2.bo) (SZ.v cb_val2.al); + + fold (is_circular_buffer cb rm + { Spec.resize_result st (SZ.v new_al) with contents = reveal new_st_contents }); + { wrote = true; new_data_ready = ndr; resize_failed = false } + } + } + else + { + // No resize needed — write directly + A.pts_to_len src; + Spec.trim_write_in_bounds (SZ.v abs_offset) (SZ.v write_len) st.base_offset st.alloc_length; + + // Write loop: copy src[skip..] into buffer at (rs + rel_offset + i) % al + let mut wi : SZ.t = 0sz; + while (let vi = R.read wi; SZ.lt vi trimmed_len) + invariant exists* (vi: SZ.t) (cur_phys: Seq.seq byte). + B.pts_to cb cbi ** Vec.pts_to cb_val.buf cur_phys ** + A.pts_to src #p src_data ** + RM.is_range_vec rm repr ** + R.pts_to wi vi ** + pure ( + SZ.v vi <= SZ.v trimmed_len /\ + Seq.length cur_phys == SZ.v cb_val.al /\ + is_full_vec cb_val.buf /\ + SZ.v cb_val.al > 0 /\ + SZ.v cb_val.al <= pow2_63 /\ + SZ.v cb_val.rs == st.read_start /\ + SZ.v cb_val.al == st.alloc_length /\ + SZ.v trimmed_len + SZ.v skip == SZ.v write_len /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v rel_offset + SZ.v trimmed_len <= SZ.v cb_val.al /\ + st.read_start < st.alloc_length /\ + Spec.phys_log_coherent st.alloc_length cur_phys + (GT.write_range_contents st.contents (SZ.v rel_offset) + (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v skip + SZ.v vi))) + st.read_start) + { + let vi = R.read wi; + with cur_phys. assert (Vec.pts_to cb_val.buf cur_phys); + A.pts_to_len src; + let src_idx = SZ.add skip vi; + let byte_val = A.op_Array_Access src src_idx; + let off = SZ.add rel_offset vi; + lemma_idx_sum_fits cb_val.al cb_val.rs off; + let pidx = SZ.rem (SZ.add cb_val.rs off) cb_val.al; + lemma_mod_spec_eq (SZ.v (SZ.add cb_val.rs off)) (SZ.v cb_val.al); + cb_val.buf.(pidx) <- byte_val; + Spec.write_step_coherence (SZ.v cb_val.al) cur_phys st.contents + st.read_start (SZ.v rel_offset) (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v write_len)) (SZ.v vi); + lemma_inc_fits vi trimmed_len; + R.write wi (SZ.add vi 1sz); + }; + + with _vi _cur_phys. _; + let trimmed_data = Ghost.hide (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v write_len)); + Seq.lemma_eq_intro + (Seq.slice (reveal src_data) (SZ.v skip) (SZ.v skip + SZ.v trimmed_len)) + (reveal trimmed_data); + Seq.lemma_eq_intro + (Seq.slice (reveal trimmed_data) 0 (SZ.v trimmed_len)) + (reveal trimmed_data); + + Spec.write_ooo_coherence_transfer (SZ.v cb_val.al) _cur_phys st.contents + st.read_start (SZ.v rel_offset) (reveal trimmed_data) (SZ.v _vi) (SZ.v trimmed_len); + GT.write_range_contents_t_eq st.contents (SZ.v rel_offset) (reveal trimmed_data); + GT.write_range_monotone st.contents (SZ.v rel_offset) (reveal trimmed_data); + Spec.write_range_preserves_wf st (SZ.v rel_offset) (reveal trimmed_data); + + let rm_abs = SZ.add cb_val.bo rel_offset; + RM.range_vec_add rm rm_abs trimmed_len; + RMSpec.add_range_wf repr (SZ.v rm_abs) (SZ.v trimmed_len); + Spec.ranges_match_write repr st.contents (SZ.v cb_val.bo) (SZ.v rel_offset) (reveal trimmed_data); + + let new_pl = RM.range_vec_contiguous_from rm cb_val.bo; + let ndr = SZ.gt new_pl 0sz && SZ.eq old_pl 0sz; + + let new_cbi = Mkcb_internal cb_val.buf cb_val.rs cb_val.al new_pl cb_val.vl cb_val.bo; + ( := ) cb new_cbi; + rewrite (Vec.pts_to cb_val.buf _cur_phys) as (Vec.pts_to new_cbi.buf _cur_phys); + + Spec.write_preserves_rwp repr (SZ.v cb_val.bo) (SZ.v rel_offset) (SZ.v trimmed_len); + Spec.write_preserves_cf_eq_cpl repr st.contents (SZ.v cb_val.bo) (SZ.v rel_offset) (reveal trimmed_data); + + // Bounded: add_range preserves boundedness + RMSpec.add_range_bounded repr (SZ.v rm_abs) (SZ.v trimmed_len) (SZ.v cb_val.bo + SZ.v cb_val.al); + + // Count bound: first.low >= bo preserved, then derive count < max + RMSpec.add_range_first_low repr (SZ.v rm_abs) (SZ.v trimmed_len) (SZ.v cb_val.bo); + Spec.repr_count_bound_gap (RMSpec.add_range repr (SZ.v rm_abs) (SZ.v trimmed_len)) + (SZ.v cb_val.bo) (SZ.v cb_val.al); + + fold (is_circular_buffer cb rm + { st with contents = + GT.write_range_contents_t st.contents (SZ.v rel_offset) (reveal trimmed_data) }); + { wrote = true; new_data_ready = ndr; resize_failed = false } + } + } +} +#pop-options + +/// Drain: advance read_start and base_offset, slice contents. +/// The range map is UNCHANGED — this is the key advantage of absolute offsets. +#push-options "--z3rlimit_factor 8 --fuel 2 --ifuel 1" +fn drain + (cb: circular_buffer) (rm: RM.range_vec_t) (n: SZ.t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ SZ.v n <= st.alloc_length /\ + SZ.v n <= GT.contiguous_prefix_length st.contents /\ + SZ.fits (st.base_offset + SZ.v n)) + returns no_more_data: bool + ensures is_circular_buffer cb rm (Spec.drain_result st (SZ.v n)) ** + pure (no_more_data == (GT.contiguous_prefix_length st.contents = SZ.v n)) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + + // Advance read_start and base_offset + lemma_idx_sum_fits cb_val.al cb_val.rs n; + let temp = SZ.add cb_val.rs n; + let new_rs = SZ.rem temp cb_val.al; + let new_bo = SZ.add cb_val.bo n; + + if (SZ.gt n 0sz) { + // n > 0: drain range vec + fold with drain_repr + RM.range_vec_drain rm new_bo; + + let new_pl = RM.range_vec_contiguous_from rm new_bo; + let new_cbi = Mkcb_internal cb_val.buf new_rs cb_val.al new_pl cb_val.vl new_bo; + ( := ) cb new_cbi; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to new_cbi.buf buf_data); + + Spec.drain_preserves_coherence st.alloc_length buf_data st.contents st.read_start (SZ.v n); + Spec.ranges_match_drain_repr repr st.contents (SZ.v cb_val.bo) (SZ.v n); + RMSpec.drain_repr_wf repr (SZ.v new_bo); + Spec.drain_repr_preserves_rwp repr (SZ.v cb_val.bo) (SZ.v n); + Spec.rwp_cf_eq_cpl (RMSpec.drain_repr repr (SZ.v new_bo)) + (Spec.drained_contents st.alloc_length st.contents (SZ.v n)) + (SZ.v new_bo); + Spec.drain_prefix_length st.alloc_length st.contents (SZ.v n); + RMSpec.drain_repr_length repr (SZ.v new_bo); + + fold (is_circular_buffer cb rm (Spec.drain_result st (SZ.v n))); + SZ.eq new_pl 0sz + } else { + // n = 0: no drain, fold with original repr + let new_pl = RM.range_vec_contiguous_from rm new_bo; + let new_cbi = Mkcb_internal cb_val.buf new_rs cb_val.al new_pl cb_val.vl new_bo; + ( := ) cb new_cbi; + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to new_cbi.buf buf_data); + + Spec.drain_preserves_coherence st.alloc_length buf_data st.contents st.read_start (SZ.v n); + Spec.ranges_match_drain_padded repr st.contents (SZ.v cb_val.bo) (SZ.v n); + Spec.drain_preserves_rwp repr (SZ.v cb_val.bo) (SZ.v n); + Spec.rwp_cf_eq_cpl repr + (Spec.drained_contents st.alloc_length st.contents (SZ.v n)) + (SZ.v new_bo); + Spec.drain_prefix_length st.alloc_length st.contents (SZ.v n); + + fold (is_circular_buffer cb rm (Spec.drain_result st (SZ.v n))); + SZ.eq new_pl 0sz + } +} +#pop-options + +/// --- Zero-copy Read --- + +/// Core: split the buffer array into read segments, return trade back to raw resources. +/// Shared by all mode wrappers (non-RM, RM, OOO, ...). +#push-options "--z3rlimit_factor 32 --fuel 1 --ifuel 1" +fn read_zerocopy_core + (cb: circular_buffer) + (read_len: SZ.t) + (cbi: cb_internal) + (#buf_data: erased (Seq.seq byte)) + requires + B.pts_to cb cbi ** Vec.pts_to cbi.buf buf_data ** + pure (SZ.v cbi.al > 0 /\ SZ.v cbi.rs < SZ.v cbi.al /\ + SZ.v read_len <= SZ.v cbi.al /\ SZ.v read_len > 0 /\ + SZ.v cbi.al <= pow2_63 /\ is_full_vec cbi.buf /\ + Seq.length buf_data == SZ.v cbi.al /\ + SZ.fits (SZ.v cbi.rs + SZ.v read_len)) + returns rv: read_view + ensures exists* (s1 s2: Seq.seq byte). + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> (B.pts_to cb cbi ** Vec.pts_to cbi.buf buf_data)) ** + pure ( + SZ.v rv.len1 + SZ.v rv.len2 == SZ.v read_len /\ + SZ.v rv.off1 + SZ.v rv.len1 <= SZ.v cbi.al /\ + SZ.v rv.off2 + SZ.v rv.len2 <= SZ.v cbi.al) +{ + // Convert Vec -> Array -> pts_to_range + to_array_pts_to cbi.buf; + A.pts_to_len (vec_to_array cbi.buf); + PTR.pts_to_range_intro (vec_to_array cbi.buf) 1.0R buf_data; + + // Compute segment boundaries + let rs = cbi.rs; + let al = cbi.al; + let wraps = SZ.gt (SZ.add rs read_len) al; + + if wraps { + // Wrap case: seg1 = [rs, al), seg2 = [0, read_len - (al - rs)) + let len1 = SZ.sub al rs; + let len2 = SZ.sub read_len len1; + + // Split: [0, rs) + [rs, al) + PTR.pts_to_range_split (vec_to_array cbi.buf) 0 (SZ.v rs) (A.length (vec_to_array cbi.buf)); + with s_pre s_post. _; + + // Split [0, rs) into [0, len2) + [len2, rs) + PTR.pts_to_range_split (vec_to_array cbi.buf) 0 (SZ.v len2) (SZ.v rs); + with s2 s_mid. _; + + let rv = Mkread_view (vec_to_array cbi.buf) rs len1 0sz len2; + + // Package trade: segments → raw resources + intro (trade (zc_segs rv s_post s2) + (B.pts_to cb cbi ** Vec.pts_to cbi.buf buf_data)) + #(A.pts_to_range (vec_to_array cbi.buf) (SZ.v len2) (SZ.v rs) s_mid ** + B.pts_to cb cbi) fn _ + { + // Rewrite hyp from rv.* to concrete values + unfold zc_segs; + rewrite + (A.pts_to_range rv.arr (SZ.v rv.off1) (SZ.v rv.off1 + SZ.v rv.len1) s_post) + as (A.pts_to_range (vec_to_array cbi.buf) (SZ.v rs) (SZ.v rs + SZ.v len1) s_post); + rewrite + (A.pts_to_range rv.arr (SZ.v rv.off2) (SZ.v rv.off2 + SZ.v rv.len2) s2) + as (A.pts_to_range (vec_to_array cbi.buf) 0 (SZ.v len2) s2); + // Rejoin [0, len2) + [len2, rs) + PTR.pts_to_range_join (vec_to_array cbi.buf) 0 (SZ.v len2) (SZ.v rs); + // Rejoin [0, rs) + [rs, al) + PTR.pts_to_range_join (vec_to_array cbi.buf) 0 (SZ.v rs) (A.length (vec_to_array cbi.buf)); + // pts_to_range -> pts_to -> Vec + PTR.pts_to_range_elim (vec_to_array cbi.buf) 1.0R (Seq.append (Seq.append s2 s_mid) s_post); + to_vec_pts_to cbi.buf; + with s'. assert (Vec.pts_to cbi.buf s'); + assert (pure (s' `Seq.equal` buf_data)); + rewrite (Vec.pts_to cbi.buf s') as (Vec.pts_to cbi.buf buf_data); + }; + // Rewrite from concrete array to rv.arr for postcondition + rewrite + (A.pts_to_range (vec_to_array cbi.buf) (SZ.v rs) (A.length (vec_to_array cbi.buf)) s_post) + as (A.pts_to_range rv.arr (SZ.v rv.off1) (SZ.v rv.off1 + SZ.v rv.len1) s_post); + rewrite + (A.pts_to_range (vec_to_array cbi.buf) 0 (SZ.v len2) s2) + as (A.pts_to_range rv.arr (SZ.v rv.off2) (SZ.v rv.off2 + SZ.v rv.len2) s2); + fold (zc_segs rv s_post s2); + rv + } else { + // No-wrap case: seg1 = [rs, rs+read_len), seg2 = empty + // Split: [0, rs) + [rs, al) + PTR.pts_to_range_split (vec_to_array cbi.buf) 0 (SZ.v rs) (A.length (vec_to_array cbi.buf)); + with s_pre s_post. _; + + // Split [rs, al) into [rs, rs+read_len) + [rs+read_len, al) + PTR.pts_to_range_split (vec_to_array cbi.buf) (SZ.v rs) (SZ.v rs + SZ.v read_len) (A.length (vec_to_array cbi.buf)); + with s1 s_tail. _; + + let rv = Mkread_view (vec_to_array cbi.buf) rs read_len 0sz 0sz; + + // Create empty pts_to_range for segment 2 + PTR.pts_to_range_split (vec_to_array cbi.buf) 0 0 (SZ.v rs); + with s_empty s_pre'. _; + + // Package trade: segments → raw resources + intro (trade (zc_segs rv s1 s_empty) + (B.pts_to cb cbi ** Vec.pts_to cbi.buf buf_data)) + #(A.pts_to_range (vec_to_array cbi.buf) 0 (SZ.v rs) s_pre' ** + A.pts_to_range (vec_to_array cbi.buf) (SZ.v rs + SZ.v read_len) (A.length (vec_to_array cbi.buf)) s_tail ** + B.pts_to cb cbi) fn _ + { + unfold zc_segs; + rewrite + (A.pts_to_range rv.arr (SZ.v rv.off1) (SZ.v rv.off1 + SZ.v rv.len1) s1) + as (A.pts_to_range (vec_to_array cbi.buf) (SZ.v rs) (SZ.v rs + SZ.v read_len) s1); + rewrite + (A.pts_to_range rv.arr (SZ.v rv.off2) (SZ.v rv.off2 + SZ.v rv.len2) s_empty) + as (A.pts_to_range (vec_to_array cbi.buf) 0 0 s_empty); + // Rejoin [0,0) + [0,rs) + PTR.pts_to_range_join (vec_to_array cbi.buf) 0 0 (SZ.v rs); + // Rejoin [rs, rs+rl) + [rs+rl, al) + PTR.pts_to_range_join (vec_to_array cbi.buf) (SZ.v rs) (SZ.v rs + SZ.v read_len) (A.length (vec_to_array cbi.buf)); + // Rejoin [0, rs) + [rs, al) + PTR.pts_to_range_join (vec_to_array cbi.buf) 0 (SZ.v rs) (A.length (vec_to_array cbi.buf)); + // pts_to_range -> pts_to -> Vec + PTR.pts_to_range_elim (vec_to_array cbi.buf) 1.0R + (Seq.append (Seq.append s_empty s_pre') (Seq.append s1 s_tail)); + to_vec_pts_to cbi.buf; + with s'. assert (Vec.pts_to cbi.buf s'); + assert (pure (s' `Seq.equal` buf_data)); + rewrite (Vec.pts_to cbi.buf s') as (Vec.pts_to cbi.buf buf_data); + }; + // Rewrite from concrete array to rv.arr for postcondition + rewrite + (A.pts_to_range (vec_to_array cbi.buf) (SZ.v rs) (SZ.v rs + SZ.v read_len) s1) + as (A.pts_to_range rv.arr (SZ.v rv.off1) (SZ.v rv.off1 + SZ.v rv.len1) s1); + rewrite + (A.pts_to_range (vec_to_array cbi.buf) 0 0 s_empty) + as (A.pts_to_range rv.arr (SZ.v rv.off2) (SZ.v rv.off2 + SZ.v rv.len2) s_empty); + fold (zc_segs rv s1 s_empty); + rv + } +} +#pop-options + +/// Zero-copy read: unfold is_circular_buffer, call core, compose trade +fn read_zerocopy + (cb: circular_buffer) + (rm: RM.range_vec_t) + (read_len: SZ.t) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + SZ.v read_len <= st.alloc_length /\ + SZ.v read_len > 0) + returns rv: read_view + ensures exists* (s1 s2: Seq.seq byte). + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> is_circular_buffer cb rm st) ** + pure ( + SZ.v rv.len1 + SZ.v rv.len2 == SZ.v read_len /\ + SZ.v rv.off1 + SZ.v rv.len1 <= st.alloc_length /\ + SZ.v rv.off2 + SZ.v rv.len2 <= st.alloc_length) +{ + unfold (is_circular_buffer cb rm st); + with cbi buf_data repr. _; + let cb_val = !cb; + rewrite (Vec.pts_to cbi.buf buf_data) as (Vec.pts_to cb_val.buf buf_data); + rewrite (B.pts_to cb cbi) as (B.pts_to cb cb_val); + Spec.ranges_match_prefix_lower repr st.contents (SZ.v cbi.bo); + lemma_idx_sum_fits cb_val.al cb_val.rs read_len; + + let rv = read_zerocopy_core cb read_len cb_val; + with s1 s2. _; + + // Fold trade: raw resources → is_circular_buffer (captures RM as extra) + intro (trade (B.pts_to cb cb_val ** Vec.pts_to cb_val.buf buf_data) + (is_circular_buffer cb rm st)) + #(RM.is_range_vec rm repr) fn _ { + rewrite (B.pts_to cb cb_val) as (B.pts_to cb cbi); + rewrite (Vec.pts_to cb_val.buf buf_data) as (Vec.pts_to cbi.buf buf_data); + fold (is_circular_buffer cb rm st); + }; + + // Compose: (segs @==> raw) ** (raw @==> is_circular_buffer) → (segs @==> is_circular_buffer) + trade_compose + (zc_segs rv s1 s2) + (B.pts_to cb cb_val ** Vec.pts_to cb_val.buf buf_data) + (is_circular_buffer cb rm st); + + rv +} + +/// Release zero-copy read without draining (cancel) +fn release_read + (cb: circular_buffer) + (rm: RM.range_vec_t) + (rv: read_view) + (#st: erased Spec.cb_state) + (#s1 #s2: erased (Seq.seq byte)) + requires + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> is_circular_buffer cb rm st) + ensures + is_circular_buffer cb rm st +{ + elim_trade (zc_segs rv s1 s2) (is_circular_buffer cb rm st); +} + +/// Release zero-copy read AND drain +#push-options "--z3rlimit_factor 8 --fuel 1 --ifuel 1" +fn drain_after_read + (cb: circular_buffer) + (rm: RM.range_vec_t) + (rv: read_view) + (drain_len: SZ.t) + (#st: erased Spec.cb_state) + (#s1 #s2: erased (Seq.seq byte)) + requires + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> is_circular_buffer cb rm st) ** + pure (Spec.cb_wf st /\ + SZ.v drain_len <= st.alloc_length /\ + SZ.v drain_len <= GT.contiguous_prefix_length st.contents /\ + SZ.fits (st.base_offset + SZ.v drain_len)) + returns no_more_data: bool + ensures + is_circular_buffer cb rm (Spec.drain_result st (SZ.v drain_len)) ** + pure (no_more_data == (GT.contiguous_prefix_length st.contents = SZ.v drain_len)) +{ + release_read cb rm rv; + drain cb rm drain_len +} +#pop-options diff --git a/lib/pulse/lib/Pulse.Lib.CircularBuffer.fsti b/lib/pulse/lib/Pulse.Lib.CircularBuffer.fsti new file mode 100644 index 000000000..f9ccece07 --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.CircularBuffer.fsti @@ -0,0 +1,261 @@ +module Pulse.Lib.CircularBuffer + +#lang-pulse +open Pulse.Lib.Pervasives +open FStar.SizeT +open Pulse.Lib.CircularBuffer.Spec +module Seq = FStar.Seq +module SZ = FStar.SizeT +module Spec = Pulse.Lib.CircularBuffer.Spec +module Pow2 = Pulse.Lib.CircularBuffer.Pow2 +module GT = Pulse.Lib.CircularBuffer.GapTrack +module A = Pulse.Lib.Array +module RM = Pulse.Lib.RangeVec +module RMSpec = Pulse.Lib.RangeMap.Spec +open Pulse.Lib.Trade + +/// Pre-computed pow2 63 so Z3 never evaluates Prims.pow2 recursively +let pow2_63 : nat = 0x8000000000000000 + +/// Result of a write operation +noeq type write_result = { + wrote: bool; // true if any bytes were actually written + new_data_ready: bool; // true if new contiguous data became available from position 0 + resize_failed: bool; // true if auto-resize was needed but would exceed cb_max_length +} + +/// Abstract circular buffer type +val circular_buffer : Type0 + +/// Predicate connecting physical buffer to ghost spec state. +/// Always RM-tracked: exact prefix via RangeMap bridge. +val is_circular_buffer + ([@@@mkey]cb: circular_buffer) + (rm: RM.range_vec_t) + (st: Spec.cb_state) : slprop + +/// Create a new empty circular buffer with an empty range map. +fn create + (alloc_len: SZ.t{SZ.v alloc_len > 0}) + (virt_len: SZ.t{SZ.v virt_len > 0}) + requires pure ( + Pow2.is_pow2 (SZ.v alloc_len) /\ + Pow2.is_pow2 (SZ.v virt_len) /\ + SZ.v alloc_len <= SZ.v virt_len /\ + SZ.v alloc_len <= Spec.cb_max_length /\ + SZ.v virt_len <= pow2_63) + returns res : (circular_buffer & RM.range_vec_t) + ensures exists* st. + is_circular_buffer (fst res) (snd res) st ** + pure (Spec.cb_wf st /\ + st.base_offset == 0 /\ + st.alloc_length == SZ.v alloc_len /\ + st.virtual_length == SZ.v virt_len /\ + GT.contiguous_prefix_length st.contents == 0) + +/// Get the length of contiguous readable data +fn read_length + (cb: circular_buffer) (rm: RM.range_vec_t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st + returns n : SZ.t + ensures is_circular_buffer cb rm st ** + pure (SZ.v n == GT.contiguous_prefix_length st.contents) + +/// Get total length: max absolute offset with data +fn get_total_length + (cb: circular_buffer) (rm: RM.range_vec_t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st + returns n: SZ.t + ensures is_circular_buffer cb rm st ** + pure (SZ.v n <= st.base_offset + st.alloc_length) + +/// Drain n bytes from the front (n must not exceed prefix length). +/// The range map is UNCHANGED — this is the key advantage of absolute offsets. +fn drain + (cb: circular_buffer) + (rm: RM.range_vec_t) + (n: SZ.t) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ SZ.v n <= st.alloc_length /\ + SZ.v n <= GT.contiguous_prefix_length st.contents /\ + SZ.fits (st.base_offset + SZ.v n)) + returns no_more_data: bool + ensures + is_circular_buffer cb rm (Spec.drain_result st (SZ.v n)) ** + pure (no_more_data == (GT.contiguous_prefix_length st.contents = SZ.v n)) + +/// Resize (grow) the buffer while preserving range map bridge. +fn resize + (cb: circular_buffer) (rm: RM.range_vec_t) (new_al: SZ.t{SZ.v new_al > 0}) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ Pow2.is_pow2 (SZ.v new_al) /\ + SZ.v new_al >= st.alloc_length /\ + SZ.v new_al <= st.virtual_length /\ + SZ.v new_al <= Spec.cb_max_length /\ + SZ.v new_al <= pow2_63) + ensures is_circular_buffer cb rm (Spec.resize_result st (SZ.v new_al)) + +/// Free the circular buffer +fn free + (cb: circular_buffer) + (rm: RM.range_vec_t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st + ensures emp + +/// Get the current alloc_length +fn get_alloc_length + (cb: circular_buffer) + (rm: RM.range_vec_t) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** pure (Spec.cb_wf st) + returns n : SZ.t + ensures is_circular_buffer cb rm st ** pure (SZ.v n == st.alloc_length) + +/// Increase virtual buffer length +fn set_virtual_length + (cb: circular_buffer) (rm: RM.range_vec_t) (new_vl: SZ.t{SZ.v new_vl > 0}) + (#st: erased Spec.cb_state) + requires is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ + Pow2.is_pow2 (SZ.v new_vl) /\ + SZ.v new_vl >= st.virtual_length /\ + SZ.v new_vl <= pow2_63) + ensures is_circular_buffer cb rm ({ st with virtual_length = SZ.v new_vl }) + +/// Write data at an absolute stream offset with trim, auto-resize, and failure handling. +/// Handles stale writes (no-op), partial overlap trim, auto-resize up to cb_max_length. +fn write_buffer + (cb: circular_buffer) (rm: RM.range_vec_t) + (abs_offset: SZ.t) (src: A.array byte) (write_len: SZ.t) + (#p: perm) + (#src_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st /\ + SZ.v write_len == Seq.length src_data /\ + SZ.v write_len == A.length src /\ + SZ.v write_len > 0 /\ + SZ.v abs_offset + SZ.v write_len <= st.base_offset + st.virtual_length /\ + SZ.fits (SZ.v abs_offset + SZ.v write_len)) + returns wr: write_result + ensures exists* st'. + is_circular_buffer cb rm st' ** + A.pts_to src #p src_data ** + pure (Spec.cb_wf st' /\ + st'.base_offset == st.base_offset /\ + st'.virtual_length == st.virtual_length /\ + (not wr.wrote ==> st'.alloc_length == st.alloc_length /\ + st'.read_start == st.read_start /\ + st'.contents == st.contents) /\ + (wr.wrote ==> st'.alloc_length >= st.alloc_length /\ + GT.contiguous_prefix_length st'.contents >= + GT.contiguous_prefix_length st.contents)) + +/// Read the contiguous prefix of the buffer into a destination array. +/// Copies read_len bytes from the circular buffer into dst. +/// The circular buffer state is unchanged. +fn read_buffer + (cb: circular_buffer) + (rm: RM.range_vec_t) + (dst: A.array byte) + (read_len: SZ.t) + (#dst_data: erased (Seq.seq byte)) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + A.pts_to dst dst_data ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + SZ.v read_len <= st.alloc_length /\ + SZ.v read_len <= A.length dst /\ + A.is_full_array dst) + ensures exists* (dst_data': Seq.seq byte). + is_circular_buffer cb rm st ** + A.pts_to dst dst_data' ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= Seq.length st.contents /\ + SZ.v read_len <= Seq.length dst_data' /\ + Seq.length dst_data' == Seq.length dst_data /\ + (forall (i:nat{i < SZ.v read_len}). + Some? (Seq.index st.contents i) /\ + Seq.index dst_data' i == Some?.v (Seq.index st.contents i))) + +/// --- Zero-copy Read --- + +/// Return type for zero-copy read: references into the buffer's internal array. +noeq type read_view = { + arr: A.array byte; // The underlying buffer array + off1: SZ.t; // Start of segment 1 + len1: SZ.t; // Length of segment 1 + off2: SZ.t; // Start of segment 2 (0 if no wrap) + len2: SZ.t; // Length of segment 2 (0 if no wrap) +} + +/// Abbreviation for the two read-segment slprops +let zc_segs (rv: read_view) (s1 s2: Seq.seq byte) : slprop = + A.pts_to_range rv.arr (SZ.v rv.off1) (SZ.v rv.off1 + SZ.v rv.len1) s1 ** + A.pts_to_range rv.arr (SZ.v rv.off2) (SZ.v rv.off2 + SZ.v rv.len2) s2 + +/// Zero-copy read: returns segment pointers into the internal buffer, +/// plus a trade that restores the buffer when the segments are returned. +/// Up to 2 segments for wrap-around (like MsQuic's QuicRecvBufferRead). +fn read_zerocopy + (cb: circular_buffer) + (rm: RM.range_vec_t) + (read_len: SZ.t) + (#st: erased Spec.cb_state) + requires + is_circular_buffer cb rm st ** + pure (Spec.cb_wf st /\ + SZ.v read_len <= GT.contiguous_prefix_length st.contents /\ + SZ.v read_len <= st.alloc_length /\ + SZ.v read_len > 0) + returns rv: read_view + ensures exists* (s1 s2: Seq.seq byte). + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> is_circular_buffer cb rm st) ** + pure ( + SZ.v rv.len1 + SZ.v rv.len2 == SZ.v read_len /\ + SZ.v rv.off1 + SZ.v rv.len1 <= st.alloc_length /\ + SZ.v rv.off2 + SZ.v rv.len2 <= st.alloc_length) + +/// Release zero-copy read without draining (cancel) +fn release_read + (cb: circular_buffer) + (rm: RM.range_vec_t) + (rv: read_view) + (#st: erased Spec.cb_state) + (#s1 #s2: erased (Seq.seq byte)) + requires + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> is_circular_buffer cb rm st) + ensures + is_circular_buffer cb rm st + +/// Release zero-copy read AND drain +fn drain_after_read + (cb: circular_buffer) + (rm: RM.range_vec_t) + (rv: read_view) + (drain_len: SZ.t) + (#st: erased Spec.cb_state) + (#s1 #s2: erased (Seq.seq byte)) + requires + zc_segs rv s1 s2 ** + (zc_segs rv s1 s2 @==> is_circular_buffer cb rm st) ** + pure (Spec.cb_wf st /\ + SZ.v drain_len <= st.alloc_length /\ + SZ.v drain_len <= GT.contiguous_prefix_length st.contents /\ + SZ.fits (st.base_offset + SZ.v drain_len)) + returns no_more_data: bool + ensures + is_circular_buffer cb rm (Spec.drain_result st (SZ.v drain_len)) ** + pure (no_more_data == (GT.contiguous_prefix_length st.contents = SZ.v drain_len)) diff --git a/lib/pulse/lib/Pulse.Lib.RangeMap.Spec.fst b/lib/pulse/lib/Pulse.Lib.RangeMap.Spec.fst new file mode 100644 index 000000000..2dd5fb087 --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.RangeMap.Spec.fst @@ -0,0 +1,1671 @@ +module Pulse.Lib.RangeMap.Spec + +/// Spec for a range set — sorted non-overlapping, non-adjacent intervals. +/// Models MsQuic's QUIC_RANGE (WrittenRanges) for tracking received byte offsets. + +module Seq = FStar.Seq + +(*** Types ***) + +/// An interval [low, low+count) +noeq type interval = { low: nat; count: pos } + +/// Upper bound (exclusive) of an interval +let high (iv: interval) : nat = iv.low + iv.count + +(*** Well-formedness ***) + +/// Two intervals are non-overlapping and non-adjacent, and sorted +let separated (a b: interval) : prop = + high a < b.low // gap between a and b (not adjacent, not overlapping) + +/// A range set is a sorted sequence of separated intervals +let rec range_map_wf (s: Seq.seq interval) + : Tot prop (decreases Seq.length s) = + if Seq.length s <= 1 then True + else + separated (Seq.index s 0) (Seq.index s 1) /\ + range_map_wf (Seq.tail s) + +type range_map = s:Seq.seq interval{range_map_wf s} + +(*** Membership ***) + +/// An offset is covered by an interval +let in_interval (iv: interval) (offset: nat) : bool = + iv.low <= offset && offset < high iv + +/// An offset is covered by some interval in the range set +let rec mem (s: Seq.seq interval) (offset: nat) + : Tot bool (decreases Seq.length s) = + if Seq.length s = 0 then false + else in_interval (Seq.index s 0) offset || mem (Seq.tail s) offset + +(*** Core operations ***) + +/// Length of contiguous coverage starting from offset 0. +/// If the first interval starts at 0, returns its count; otherwise 0. +let contiguous_from_zero (s: Seq.seq interval) : nat = + if Seq.length s = 0 then 0 + else + let first = Seq.index s 0 in + if first.low = 0 then first.count + else 0 + +/// Length of contiguous coverage starting from a given base offset. +/// If the first interval covers base, returns high(first) - base; otherwise 0. +let contiguous_from (s: Seq.seq interval) (base: nat) : nat = + if Seq.length s = 0 then 0 + else + let first = Seq.index s 0 in + if first.low <= base && base < high first then high first - base + else 0 + +/// Check if interval overlaps or is adjacent to [offset, offset+len) +let mergeable (iv: interval) (offset: nat) (len: pos) : bool = + not (high iv < offset || offset + len < iv.low) + +/// Merge interval [offset, offset+len) into sorted range set +let rec add_range (s: Seq.seq interval) (offset: nat) (len: pos) + : Tot (Seq.seq interval) (decreases Seq.length s) = + if Seq.length s = 0 then + Seq.create 1 ({ low = offset; count = len }) + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then + // New interval goes entirely before hd (no overlap/adjacency) + Seq.cons ({ low = offset; count = len }) s + else if high hd < offset then + // hd is entirely before new interval, keep hd, recurse on tail + Seq.cons hd (add_range tl offset len) + else + // Overlap or adjacency — merge + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + // Continue merging with tail (the merged interval might overlap more) + add_range tl merged_low (merged_high - merged_low) + +/// Drain n bytes: shift all intervals left by n, drop/trim those below 0 +let rec drain_ranges (s: Seq.seq interval) (n: nat) + : Tot (Seq.seq interval) (decreases Seq.length s) = + if Seq.length s = 0 then Seq.empty + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if high hd <= n then + // Entire interval is drained + drain_ranges tl n + else if hd.low < n then + // Partially drained — trim the front + Seq.cons ({ low = 0; count = high hd - n }) (drain_ranges tl n) + else + // Shift left by n + Seq.cons ({ low = hd.low - n; count = hd.count }) (drain_ranges tl n) + +(*** Lemmas ***) + +/// Helper: range_map_wf is preserved by tail +let range_map_wf_tail (s: Seq.seq interval) + : Lemma (requires range_map_wf s /\ Seq.length s > 0) + (ensures range_map_wf (Seq.tail s)) = + () + +/// Helper: cons preserves range_map_wf if head is separated from new head +let range_map_wf_cons (hd: interval) (tl: Seq.seq interval) + : Lemma (requires (Seq.length tl = 0 \/ (range_map_wf tl /\ separated hd (Seq.index tl 0)))) + (ensures range_map_wf (Seq.cons hd tl)) = + let s = Seq.cons hd tl in + assert (Seq.length s = Seq.length tl + 1); + if Seq.length tl = 0 then + assert (Seq.length s = 1) + else ( + assert (Seq.length s > 1); + assert (Seq.length tl > 0); + // Need to show: separated (Seq.index s 0) (Seq.index s 1) /\ range_map_wf (Seq.tail s) + assert (Seq.index s 0 == hd); + assert (Seq.index s 1 == Seq.index tl 0); + assert (Seq.tail s `Seq.equal` tl) + ) + +/// drain_ranges preserves well-formedness +let rec drain_ranges_wf (s: Seq.seq interval) (n: nat) + : Lemma (requires range_map_wf s) + (ensures range_map_wf (drain_ranges s n)) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + range_map_wf_tail s; + drain_ranges_wf tl n; + if high hd <= n then () + else if hd.low < n then + let drained_tl = drain_ranges tl n in + if Seq.length drained_tl > 0 then ( + let new_hd = { low = 0; count = high hd - n } in + let tl_hd = Seq.index drained_tl 0 in + assert (separated hd (Seq.index tl 0)); + assert (high hd < (Seq.index tl 0).low); + assert (tl_hd.low = (Seq.index tl 0).low - n); + assert (high new_hd = high hd - n); + assert (high new_hd < (Seq.index tl 0).low - n); + assert (high new_hd < tl_hd.low); + range_map_wf_cons new_hd drained_tl + ) else () + else + let drained_tl = drain_ranges tl n in + if Seq.length drained_tl > 0 then ( + let new_hd = { low = hd.low - n; count = hd.count } in + let tl_hd = Seq.index drained_tl 0 in + assert (separated hd (Seq.index tl 0)); + assert (high new_hd = high hd - n); + range_map_wf_cons new_hd drained_tl + ) else () + +/// add_range preserves well-formedness +/// Helper: add_range preserves lower bounds +let rec add_range_first_lower_bound (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (ensures (let r = add_range s offset len in + Seq.length r > 0 ==> + (Seq.length s = 0 ==> (Seq.index r 0).low = offset) /\ + (Seq.length s > 0 ==> (Seq.index r 0).low <= (Seq.index s 0).low /\ + (Seq.index r 0).low <= offset))) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then () + else if high hd < offset then + add_range_first_lower_bound tl offset len + else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + add_range_first_lower_bound tl merged_low (merged_high - merged_low) + +/// Helper: add_range respects lower bound +let rec add_range_respects_lower_bound (s: Seq.seq interval) (offset: nat) (len: pos) (iv: interval) + : Lemma (requires Seq.length s > 0 /\ + range_map_wf s /\ + high iv < (Seq.index s 0).low /\ + high iv < offset) + (ensures (let r = add_range s offset len in + Seq.length r > 0 ==> high iv < (Seq.index r 0).low)) + (decreases Seq.length s) = + add_range_first_lower_bound s offset len; + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then ( + // Result is [new_interval; ...s], so first element has low = offset + () + ) else if high hd < offset then ( + // Result is [hd; ...add_range tl offset len] + // The first element is hd, and we know high iv < hd.low + () + ) else ( + // Merging case: the result comes from recursing on tl with merged interval + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + let result_tl = add_range tl merged_low (merged_high - merged_low) in + // We need to show that if result_tl is non-empty, high iv < (Seq.index result_tl 0).low + // We know: high iv < hd.low and high iv < offset + // Therefore: high iv < merged_low + if Seq.length tl > 0 then ( + // From separated: high hd < (Seq.index tl 0).low + // From transitivity: high iv < hd.low < hd.low + hd.count = high hd < (Seq.index tl 0).low + assert (separated hd (Seq.index tl 0)); + range_map_wf_tail s; + assert (high hd == hd.low + hd.count); + assert (high iv < hd.low); + assert (hd.low < high hd); + assert (high iv < high hd); + assert (high hd < (Seq.index tl 0).low); + assert (high iv < (Seq.index tl 0).low); + assert (high iv < merged_low); + add_range_respects_lower_bound tl merged_low (merged_high - merged_low) iv + ) else ( + // tl is empty, so result_tl is just the merged interval + assert ((Seq.index result_tl 0).low = merged_low); + assert (high iv < merged_low) + ) + ) + +let rec add_range_wf (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (requires range_map_wf s) + (ensures range_map_wf (add_range s offset len)) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + range_map_wf_tail s; + if offset + len < hd.low then + let new_iv = { low = offset; count = len } in + assert (high new_iv = offset + len); + assert (high new_iv < hd.low); + assert (separated new_iv hd); + range_map_wf_cons new_iv s + else if high hd < offset then ( + add_range_wf tl offset len; + let result = add_range tl offset len in + if Seq.length result > 0 then ( + if Seq.length tl > 0 then ( + add_range_respects_lower_bound tl offset len hd; + assert (separated hd (Seq.index result 0)) + ) else (); + range_map_wf_cons hd result + ) else () + ) else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + add_range_wf tl merged_low (merged_high - merged_low) + +/// Helper: membership in tail +let mem_tail (s: Seq.seq interval) (offset: nat) + : Lemma (requires Seq.length s > 0 /\ mem s offset /\ not (in_interval (Seq.index s 0) offset)) + (ensures mem (Seq.tail s) offset) = + () + +/// Helper: membership after cons +let mem_cons (hd: interval) (tl: Seq.seq interval) (offset: nat) + : Lemma (ensures mem (Seq.cons hd tl) offset <==> (in_interval hd offset || mem tl offset)) = + let s = Seq.cons hd tl in + assert (Seq.length s > 0); + assert (Seq.index s 0 == hd); + assert (Seq.tail s `Seq.equal` tl) + +/// add_range includes all offsets in the added range +let rec add_range_mem_new (s: Seq.seq interval) (offset: nat) (len: pos) (x: nat) + : Lemma (requires offset <= x /\ x < offset + len) + (ensures mem (add_range s offset len) x) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then + mem_cons ({ low = offset; count = len }) s x + else if high hd < offset then ( + add_range_mem_new tl offset len x; + mem_cons hd (add_range tl offset len) x + ) else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + assert (merged_low <= offset); + assert (merged_high >= offset + len); + assert (merged_low <= x); + assert (x < merged_high); + add_range_mem_new tl merged_low (merged_high - merged_low) x + +/// add_range preserves existing members +let rec add_range_mem_old (s: Seq.seq interval) (offset: nat) (len: pos) (x: nat) + : Lemma (requires mem s x) + (ensures mem (add_range s offset len) x) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if in_interval hd x then ( + if offset + len < hd.low then + mem_cons ({ low = offset; count = len }) s x + else if high hd < offset then + mem_cons hd (add_range tl offset len) x + else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + assert (merged_low <= hd.low); + assert (merged_high >= high hd); + assert (merged_low <= x); + assert (x < merged_high); + add_range_mem_new tl merged_low (merged_high - merged_low) x + ) else ( + mem_tail s x; + assert (mem tl x); + if offset + len < hd.low then + mem_cons ({ low = offset; count = len }) s x + else if high hd < offset then ( + add_range_mem_old tl offset len x; + mem_cons hd (add_range tl offset len) x + ) else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + add_range_mem_old tl merged_low (merged_high - merged_low) x + ) + +/// add_range converse: if x is in the result, then either x was in [offset, offset+len) or in s +let rec add_range_mem_inv (s: Seq.seq interval) (offset: nat) (len: pos) (x: nat) + : Lemma (requires mem (add_range s offset len) x) + (ensures (offset <= x /\ x < offset + len) \/ mem s x) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then ( + // Result is cons {offset,len} s + mem_cons ({ low = offset; count = len }) s x; + if in_interval ({ low = offset; count = len }) x then () + else () // mem s x + ) else if high hd < offset then ( + // Result is cons hd (add_range tl offset len) + mem_cons hd (add_range tl offset len) x; + if in_interval hd x then ( + mem_cons hd tl x + ) else ( + add_range_mem_inv tl offset len x; + if mem tl x then + mem_cons hd tl x + else () + ) + ) else ( + // Merge case + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + let new_len : pos = merged_high - merged_low in + add_range_mem_inv tl merged_low new_len x; + if merged_low <= x && x < merged_high then ( + // x is in merged range — either in [offset, offset+len) or in hd + if offset <= x && x < offset + len then () + else ( + // x must be in hd + assert (hd.low <= x /\ x < high hd); + mem_cons hd tl x + ) + ) else ( + // x was in tl + mem_cons hd tl x + ) + ) + +/// In a well-formed range map, any member of the tail is >= high of head. +/// (All tail intervals are separated from head, so their low > high head.) +let rec mem_wf_tail_ge (s: Seq.seq interval) (x: nat) + : Lemma (requires range_map_wf s /\ Seq.length s > 0 /\ + mem (Seq.tail s) x) + (ensures x >= high (Seq.index s 0)) + (decreases Seq.length s) = + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if Seq.length tl = 0 then () + else + let tl_hd = Seq.index tl 0 in + assert (separated hd tl_hd); + assert (tl_hd.low > high hd); + if in_interval tl_hd x then + assert (x >= tl_hd.low) + else ( + mem_tail tl x; + range_map_wf_tail s; + mem_wf_tail_ge tl x; + assert (x >= high tl_hd); + assert (high tl_hd > high hd) + ) + +/// Strict version: member of tail is strictly greater than high of head +let mem_wf_tail_gt (s: Seq.seq interval) (x: nat) + : Lemma (requires range_map_wf s /\ Seq.length s > 0 /\ + Seq.length (Seq.tail s) > 0 /\ + mem (Seq.tail s) x) + (ensures x > high (Seq.index s 0)) + = let hd = Seq.index s 0 in + let tl = Seq.tail s in + let tl_hd = Seq.index tl 0 in + assert (separated hd tl_hd); + assert (tl_hd.low > high hd); + range_map_wf_tail s; + if in_interval tl_hd x then + assert (x >= tl_hd.low) + else ( + mem_tail tl x; + mem_wf_tail_ge tl x; + assert (x >= high tl_hd); + assert (high tl_hd > tl_hd.low); + assert (tl_hd.low > high hd) + ) + +/// Positions below the first interval's low are not members (wf ensures sorted order) +let rec mem_not_below_first (s: Seq.seq interval) (x: nat) + : Lemma (requires range_map_wf s /\ Seq.length s > 0 /\ x < (Seq.index s 0).low) + (ensures not (mem s x)) + (decreases Seq.length s) + = let hd = Seq.index s 0 in + assert (not (in_interval hd x)); + let tl = Seq.tail s in + if Seq.length tl = 0 then () + else ( + range_map_wf_tail s; + assert (separated hd (Seq.index tl 0)); + assert ((Seq.index tl 0).low > high hd); + assert ((Seq.index tl 0).low > x); + mem_not_below_first tl x + ) + +/// All interval endpoints are bounded by a given value +let range_map_bounded (s: Seq.seq interval) (bound: nat) : prop = + forall (i:nat{i < Seq.length s}). high (Seq.index s i) <= bound + +/// add_range preserves boundedness +let rec add_range_bounded (s: Seq.seq interval) (offset: nat) (len: pos) (bound: nat) + : Lemma (requires range_map_bounded s bound /\ offset + len <= bound) + (ensures range_map_bounded (add_range s offset len) bound) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + assert (high hd <= bound); + let bounded_tail () : Lemma (range_map_bounded tl bound) = + let aux (i:nat{i < Seq.length tl}) : Lemma (high (Seq.index tl i) <= bound) = + assert (Seq.index tl i == Seq.index s (i + 1)) + in + FStar.Classical.forall_intro aux + in + bounded_tail (); + if offset + len < hd.low then ( + // Inserted before: new first is {low=offset; count=len}, rest is s + let r = add_range s offset len in + let aux (i:nat{i < Seq.length r}) : Lemma (high (Seq.index r i) <= bound) = + if i = 0 then assert (high (Seq.index r 0) == offset + len) + else assert (Seq.index r i == Seq.index s (i - 1)) + in + FStar.Classical.forall_intro aux + ) else if high hd < offset then ( + // Keep hd, recurse on tail + add_range_bounded tl offset len bound; + let r = add_range s offset len in + let r_tl = add_range tl offset len in + let aux (i:nat{i < Seq.length r}) : Lemma (high (Seq.index r i) <= bound) = + if i = 0 then assert (Seq.index r 0 == hd) + else assert (Seq.index r i == Seq.index r_tl (i - 1)) + in + FStar.Classical.forall_intro aux + ) else ( + // Merge: merged_high = max(offset+len, high hd) <= bound + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + assert (merged_high <= bound); + add_range_bounded tl merged_low (merged_high - merged_low) bound + ) + +/// contiguous_from is bounded by range_map_bounded +let cf_bounded (s: Seq.seq interval) (base: nat) (bound: nat) + : Lemma (requires range_map_bounded s bound /\ base <= bound) + (ensures contiguous_from s base <= bound - base) + = if Seq.length s = 0 then () + else ( + let first = Seq.index s 0 in + assert (high first <= bound); + if first.low <= base && base < high first then + assert (contiguous_from s base == high first - base) + else () + ) + +/// Positions beyond bound are not members (when range_map_bounded holds) +let rec mem_bounded (s: Seq.seq interval) (bound: nat) (x: nat) + : Lemma (requires range_map_bounded s bound /\ x >= bound) + (ensures not (mem s x)) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else ( + let first = Seq.index s 0 in + assert (high first <= bound); + assert (x >= bound); + assert (not (in_interval first x)); + let tl = Seq.tail s in + let bounded_tail () : Lemma (range_map_bounded tl bound) = + let aux (i:nat{i < Seq.length tl}) : Lemma (high (Seq.index tl i) <= bound) = + assert (Seq.index tl i == Seq.index s (i + 1)) + in + FStar.Classical.forall_intro aux + in + bounded_tail (); + mem_bounded tl bound x + ) + +/// range_map_bounded is monotone in the bound +let range_map_bounded_monotone (s: Seq.seq interval) (bound1: nat) (bound2: nat) + : Lemma (requires range_map_bounded s bound1 /\ bound1 <= bound2) + (ensures range_map_bounded s bound2) + = () + +/// add_range result is non-empty +let rec add_range_nonempty (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (ensures Seq.length (add_range s offset len) > 0) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let tl = Seq.tail s in + let hd = Seq.index s 0 in + if offset + len < hd.low then () + else if high hd < offset then + add_range_nonempty tl offset len + else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + add_range_nonempty tl merged_low (merged_high - merged_low) + +/// Maximum endpoint of any interval in the range map (0 if empty) +let range_map_max_endpoint (s: Seq.seq interval) : nat = + if Seq.length s = 0 then 0 + else high (Seq.index s (Seq.length s - 1)) + +/// range_map_max_endpoint is bounded by range_map_bounded +let range_map_max_endpoint_bounded (s: Seq.seq interval) (bound: nat) + : Lemma (requires range_map_bounded s bound) + (ensures range_map_max_endpoint s <= bound) = () + +/// contiguous_from > 0 implies base_aligned (first interval covers base) +let contiguous_from_implies_base_aligned (s: Seq.seq interval) (base: nat) + : Lemma (requires contiguous_from s base > 0) + (ensures Seq.length s > 0 /\ + (let first = Seq.index s 0 in first.low <= base /\ base <= high first)) = () + +/// contiguous_from decreases linearly when advancing the base within the first interval +let contiguous_from_after_advance (s: Seq.seq interval) (base: nat) (n: nat) + : Lemma (requires contiguous_from s base > 0 /\ n <= contiguous_from s base) + (ensures contiguous_from s (base + n) == contiguous_from s base - n) = () + +/// add_range preserves base_aligned when existing base_aligned holds and offset >= base +/// (i.e., the add can only merge/extend the first interval or append after it) + +/// Helper: after add_range with offset <= first element's low (or empty seq), +/// result's first has high >= offset + len +#push-options "--z3rlimit_factor 4" +let rec add_range_first_high_bound (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (requires range_map_wf s /\ (Seq.length s = 0 \/ offset <= (Seq.index s 0).low)) + (ensures (let r = add_range s offset len in + Seq.length r > 0 /\ + high (Seq.index r 0) >= offset + len)) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + range_map_wf_tail s; + if offset + len < hd.low then () // insert before: result_first = {offset, len} + else ( + // offset <= hd.low and offset + len >= hd.low, so merge (not "keep first" since offset <= hd.low) + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + assert (merged_low == offset); // since offset <= hd.low + if Seq.length tl = 0 then () + else ( + assert (separated hd (Seq.index tl 0)); + assert ((Seq.index tl 0).low > high hd); + assert (merged_low <= (Seq.index tl 0).low); + add_range_first_high_bound tl merged_low (merged_high - merged_low) + ) + ) +#pop-options + +let rec add_range_base_aligned + (s: Seq.seq interval) (base offset: nat) (len: pos) + : Lemma + (requires range_map_wf s /\ + Seq.length s > 0 /\ + (let first = Seq.index s 0 in + first.low <= base /\ base <= high first) /\ + offset >= (Seq.index s 0).low) + (ensures ( + Seq.length (add_range s offset len) > 0 /\ + (Seq.index (add_range s offset len) 0).low <= base /\ + base <= high (Seq.index (add_range s offset len) 0))) + (decreases Seq.length s) + = let first = Seq.index s 0 in + let tl = Seq.tail s in + range_map_wf_tail s; + if offset + len < first.low then () // can't happen: offset >= first.low + else if high first < offset then ( + // Keep first, recurse on tail — base_aligned trivially preserved (first unchanged) + add_range_wf tl offset len; + () + ) else ( + // Merge: merged_low = min(offset, first.low) = first.low (since offset >= first.low) + let merged_low = if offset < first.low then offset else first.low in + let merged_high = if offset + len > high first then offset + len else high first in + // merged_low = first.low <= base, merged_high >= high first >= base + add_range_first_lower_bound tl merged_low (merged_high - merged_low); + add_range_first_high_bound tl merged_low (merged_high - merged_low) + ) + +/// Gap state: if first.low > base, contiguous_from is 0 +let contiguous_from_gap (s: Seq.seq interval) (base: nat) + : Lemma (requires Seq.length s > 0 /\ (Seq.index s 0).low > base) + (ensures contiguous_from s base == 0) = () + +/// add_range preserves gap state: if all intervals start above base and offset > base, +/// result's first element also starts above base +let rec add_range_preserves_gap + (s: Seq.seq interval) (base offset: nat) (len: pos) + : Lemma (requires range_map_wf s /\ offset > base /\ + (Seq.length s = 0 \/ (Seq.index s 0).low > base)) + (ensures (let r = add_range s offset len in + Seq.length r > 0 /\ + (Seq.index r 0).low > base)) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + range_map_wf_tail s; + if offset + len < hd.low then () // insert before: new first = {offset, len}, offset > base ✓ + else if high hd < offset then ( + // keep first (hd), hd.low > base ✓ + () + ) else ( + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + // merged_low = min(offset, hd.low). Both > base. So merged_low > base. + assert (merged_low > base); + if Seq.length tl = 0 then () + else ( + assert (separated hd (Seq.index tl 0)); + assert ((Seq.index tl 0).low > high hd); + add_range_preserves_gap tl base merged_low (merged_high - merged_low) + ) + ) + +/// add_range at exactly base establishes base_aligned when all existing intervals are above base +let add_range_at_base_establishes_aligned + (s: Seq.seq interval) (base: nat) (len: pos) + : Lemma (requires range_map_wf s /\ + (Seq.length s = 0 \/ (Seq.index s 0).low > base)) + (ensures (let r = add_range s base len in + Seq.length r > 0 /\ + (Seq.index r 0).low <= base /\ + base <= high (Seq.index r 0))) + = add_range_first_lower_bound s base len; + add_range_first_high_bound s base len + +(*** Lemmas bridging add_range to imperative implementation ***) + +/// When all intervals have high < offset, add_range appends at the end +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" +let rec add_range_all_before (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (requires range_map_wf s /\ + (forall (i:nat). i < Seq.length s ==> high (Seq.index s i) < offset)) + (ensures add_range s offset len == Seq.snoc s ({ low = offset; count = len })) + (decreases Seq.length s) + = let iv = { low = offset; count = len } in + if Seq.length s = 0 then ( + // Base case: empty sequence + // add_range s offset len = Seq.create 1 iv + // Seq.snoc s iv = Seq.snoc Seq.empty iv = Seq.create 1 iv + assert (add_range s offset len `Seq.equal` Seq.create 1 iv); + assert (Seq.snoc s iv `Seq.equal` Seq.create 1 iv) + ) else ( + // Inductive case: s is non-empty + let hd = Seq.index s 0 in + let tl = Seq.tail s in + + // From precondition, high hd < offset (using i=0) + assert (high hd < offset); + + // By definition of add_range, since high hd < offset: + // add_range s offset len = Seq.cons hd (add_range tl offset len) + + // Establish precondition for tail: + // forall i < Seq.length tl. high (Seq.index tl i) < offset + let tail_pre () : Lemma (forall (i:nat). i < Seq.length tl ==> high (Seq.index tl i) < offset) = + let aux (i:nat{i < Seq.length tl}) : Lemma (high (Seq.index tl i) < offset) = + assert (Seq.index tl i == Seq.index s (i + 1)) + in + FStar.Classical.forall_intro aux + in + tail_pre (); + + // Apply IH to tail + range_map_wf_tail s; + add_range_all_before tl offset len; + + // Now we have: add_range tl offset len == Seq.snoc tl iv + // So: add_range s offset len = Seq.cons hd (Seq.snoc tl iv) + // Goal: Seq.cons hd (Seq.snoc tl iv) == Seq.snoc s iv + + // We need to show Seq.cons hd (Seq.snoc tl iv) == Seq.snoc (Seq.cons hd tl) iv + // and Seq.cons hd tl == s + + let result_lhs = Seq.cons hd (Seq.snoc tl iv) in + let result_rhs = Seq.snoc s iv in + + // Show sequences are equal by extensionality + let len_eq () : Lemma (Seq.length result_lhs == Seq.length result_rhs) = + assert (Seq.length result_lhs == Seq.length (Seq.snoc tl iv) + 1); + assert (Seq.length (Seq.snoc tl iv) == Seq.length tl + 1); + assert (Seq.length result_lhs == Seq.length tl + 2); + assert (Seq.length s == Seq.length tl + 1); + assert (Seq.length result_rhs == Seq.length s + 1); + assert (Seq.length result_rhs == Seq.length tl + 2) + in + len_eq (); + + let elem_eq (i:nat{i < Seq.length result_lhs}) + : Lemma (Seq.index result_lhs i == Seq.index result_rhs i) = + if i = 0 then ( + assert (Seq.index result_lhs 0 == hd); + assert (Seq.index result_rhs 0 == Seq.index s 0); + assert (Seq.index s 0 == hd) + ) else if i < Seq.length result_lhs - 1 then ( + assert (Seq.index result_lhs i == Seq.index (Seq.snoc tl iv) (i - 1)); + assert (Seq.index (Seq.snoc tl iv) (i - 1) == Seq.index tl (i - 1)); + assert (Seq.index tl (i - 1) == Seq.index s i); + assert (Seq.index result_rhs i == Seq.index s i) + ) else ( + assert (i == Seq.length result_lhs - 1); + assert (Seq.index result_lhs i == Seq.index (Seq.snoc tl iv) (i - 1)); + assert (i - 1 == Seq.length tl); + assert (Seq.index (Seq.snoc tl iv) (Seq.length tl) == iv); + assert (Seq.index result_rhs i == iv) + ) + in + FStar.Classical.forall_intro elem_eq; + Seq.lemma_eq_intro result_lhs result_rhs + ) +#pop-options + +/// When intervals [0..k) have high < offset, and offset+len < s[k].low, +/// add_range inserts at position k +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" +let rec add_range_insert_no_overlap (s: Seq.seq interval) (offset: nat) (len: pos) (k: nat) + : Lemma (requires range_map_wf s /\ k < Seq.length s /\ + (forall (i:nat). i < k ==> high (Seq.index s i) < offset) /\ + offset + len < (Seq.index s k).low) + (ensures add_range s offset len == + Seq.append (Seq.slice s 0 k) + (Seq.cons ({ low = offset; count = len }) (Seq.slice s k (Seq.length s)))) + (decreases k) + = let iv = { low = offset; count = len } in + if k = 0 then ( + // Base case: insert at position 0 + let hd = Seq.index s 0 in + + // From precondition: offset + len < hd.low + assert (offset + len < hd.low); + + // By definition of add_range, this branch returns Seq.cons iv s + assert (add_range s offset len `Seq.equal` Seq.cons iv s); + + // RHS = Seq.append (Seq.slice s 0 0) (Seq.cons iv (Seq.slice s 0 (Seq.length s))) + // = Seq.append Seq.empty (Seq.cons iv s) + // = Seq.cons iv s + assert (Seq.slice s 0 0 `Seq.equal` Seq.empty); + assert (Seq.slice s 0 (Seq.length s) `Seq.equal` s); + assert (Seq.cons iv s `Seq.equal` Seq.append Seq.empty (Seq.cons iv s)) + ) else ( + // Inductive case: k > 0 + let hd = Seq.index s 0 in + let tl = Seq.tail s in + + // From precondition with i=0: high hd < offset + assert (high hd < offset); + + // By definition of add_range, since high hd < offset: + // add_range s offset len = Seq.cons hd (add_range tl offset len) + + // Establish preconditions for tail with k-1: + // 1. range_map_wf tl + range_map_wf_tail s; + + // 2. k - 1 < Seq.length tl + assert (k < Seq.length s); + assert (Seq.length tl == Seq.length s - 1); + assert (k - 1 < Seq.length tl); + + // 3. forall i < k-1. high (Seq.index tl i) < offset + let tail_forall () : Lemma (forall (i:nat). i < k - 1 ==> high (Seq.index tl i) < offset) = + let aux (i:nat{i < k - 1}) : Lemma (high (Seq.index tl i) < offset) = + assert (Seq.index tl i == Seq.index s (i + 1)); + assert (i + 1 < k) + in + FStar.Classical.forall_intro aux + in + tail_forall (); + + // 4. offset + len < (Seq.index tl (k-1)).low + assert (Seq.index tl (k - 1) == Seq.index s k); + assert (offset + len < (Seq.index tl (k - 1)).low); + + // Apply IH to tail with k-1 + add_range_insert_no_overlap tl offset len (k - 1); + + // IH gives us: add_range tl offset len == + // Seq.append (Seq.slice tl 0 (k-1)) + // (Seq.cons iv (Seq.slice tl (k-1) (Seq.length tl))) + + // So: add_range s offset len = Seq.cons hd (add_range tl offset len) + // = Seq.cons hd (Seq.append (Seq.slice tl 0 (k-1)) + // (Seq.cons iv (Seq.slice tl (k-1) (Seq.length tl)))) + + // Goal: This equals Seq.append (Seq.slice s 0 k) (Seq.cons iv (Seq.slice s k (Seq.length s))) + + // Key observations: + // - Seq.slice s 0 k = Seq.cons hd (Seq.slice tl 0 (k-1)) + // - Seq.slice s k (Seq.length s) = Seq.slice tl (k-1) (Seq.length tl) + + let lhs = add_range s offset len in + let rhs = Seq.append (Seq.slice s 0 k) (Seq.cons iv (Seq.slice s k (Seq.length s))) in + + // Prove Seq.slice s 0 k == Seq.cons hd (Seq.slice tl 0 (k-1)) + let slice_s_eq () : Lemma (Seq.slice s 0 k `Seq.equal` Seq.cons hd (Seq.slice tl 0 (k - 1))) = + let s_prefix = Seq.slice s 0 k in + let tl_prefix = Seq.slice tl 0 (k - 1) in + let expected = Seq.cons hd tl_prefix in + + assert (Seq.length s_prefix == k); + assert (Seq.length expected == k); + + let aux (i:nat{i < k}) : Lemma (Seq.index s_prefix i == Seq.index expected i) = + if i = 0 then ( + assert (Seq.index s_prefix 0 == Seq.index s 0); + assert (Seq.index s 0 == hd); + assert (Seq.index expected 0 == hd) + ) else ( + assert (Seq.index s_prefix i == Seq.index s i); + assert (Seq.index s i == Seq.index tl (i - 1)); + assert (Seq.index tl (i - 1) == Seq.index tl_prefix (i - 1)); + assert (Seq.index expected i == Seq.index tl_prefix (i - 1)) + ) + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro s_prefix expected + in + slice_s_eq (); + + // Prove Seq.slice s k (Seq.length s) == Seq.slice tl (k-1) (Seq.length tl) + let slice_s_suffix_eq () : Lemma (Seq.slice s k (Seq.length s) `Seq.equal` Seq.slice tl (k - 1) (Seq.length tl)) = + let s_suffix = Seq.slice s k (Seq.length s) in + let tl_suffix = Seq.slice tl (k - 1) (Seq.length tl) in + + assert (Seq.length s_suffix == Seq.length s - k); + assert (Seq.length tl_suffix == Seq.length tl - (k - 1)); + assert (Seq.length tl == Seq.length s - 1); + assert (Seq.length tl_suffix == Seq.length s - k); + + let aux (i:nat{i < Seq.length s - k}) : Lemma (Seq.index s_suffix i == Seq.index tl_suffix i) = + assert (Seq.index s_suffix i == Seq.index s (k + i)); + assert (Seq.index s (k + i) == Seq.index tl (k + i - 1)); + assert (Seq.index tl_suffix i == Seq.index tl (k - 1 + i)) + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro s_suffix tl_suffix + in + slice_s_suffix_eq (); + + // Now prove the final equality using associativity of append and cons + // lhs = Seq.cons hd (Seq.append (Seq.slice tl 0 (k-1)) (Seq.cons iv (Seq.slice tl (k-1) (Seq.length tl)))) + // rhs = Seq.append (Seq.cons hd (Seq.slice tl 0 (k-1))) (Seq.cons iv (Seq.slice tl (k-1) (Seq.length tl))) + + // Use the property: Seq.cons x (Seq.append a b) == Seq.append (Seq.cons x a) b + let cons_append_assoc (x: interval) (a b: Seq.seq interval) + : Lemma (Seq.cons x (Seq.append a b) `Seq.equal` Seq.append (Seq.cons x a) b) = + let lhs = Seq.cons x (Seq.append a b) in + let rhs = Seq.append (Seq.cons x a) b in + assert (Seq.length lhs == 1 + Seq.length a + Seq.length b); + assert (Seq.length rhs == (1 + Seq.length a) + Seq.length b); + let aux (i:nat{i < Seq.length lhs}) : Lemma (Seq.index lhs i == Seq.index rhs i) = + if i = 0 then () + else if i <= Seq.length a then () + else () + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro lhs rhs + in + cons_append_assoc hd (Seq.slice tl 0 (k - 1)) (Seq.cons iv (Seq.slice tl (k - 1) (Seq.length tl))); + + Seq.lemma_eq_intro lhs rhs + ) +#pop-options + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" +let rec add_range_skip_prefix (s: Seq.seq interval) (offset: nat) (len: pos) (k: nat) + : Lemma (requires range_map_wf s /\ k <= Seq.length s /\ + (forall (i:nat). i < k ==> high (Seq.index s i) < offset)) + (ensures add_range s offset len == + Seq.append (Seq.slice s 0 k) (add_range (Seq.slice s k (Seq.length s)) offset len)) + (decreases k) = + if k = 0 then ( + // Base case: Seq.slice s 0 0 is empty, Seq.slice s 0 n is s + assert (Seq.slice s 0 0 `Seq.equal` Seq.empty); + assert (Seq.slice s 0 (Seq.length s) `Seq.equal` s); + Seq.lemma_eq_intro (Seq.slice s 0 0) Seq.empty; + Seq.lemma_eq_intro (Seq.slice s 0 (Seq.length s)) s; + assert (Seq.append Seq.empty (add_range s offset len) `Seq.equal` add_range s offset len); + Seq.lemma_eq_intro (Seq.append Seq.empty (add_range s offset len)) (add_range s offset len) + ) else ( + // Inductive case: k > 0 + let hd = Seq.index s 0 in + let tl = Seq.tail s in + let n = Seq.length s in + + // hd has high hd < offset (from forall with i=0) + assert (high hd < offset); + + // So add_range s offset len takes the branch: Seq.cons hd (add_range tl offset len) + assert (add_range s offset len == Seq.cons hd (add_range tl offset len)); + + // Apply IH on tl with k-1 + // Need: range_map_wf tl + range_map_wf_tail s; + + // Need: forall for the tail + let forall_tail (i:nat{i < k - 1}) : Lemma (high (Seq.index tl i) < offset) = + assert (Seq.length tl == n - 1); + assert (k <= n); + assert (i < k - 1); + assert (i < n - 1); + assert (i < Seq.length tl); + assert (Seq.index tl i == Seq.index s (i + 1)); + assert (i + 1 < k); + assert (high (Seq.index s (i + 1)) < offset) + in + FStar.Classical.forall_intro forall_tail; + + add_range_skip_prefix tl offset len (k - 1); + + // From IH: add_range tl offset len == Seq.append (Seq.slice tl 0 (k-1)) (add_range (Seq.slice tl (k-1) (Seq.length tl)) offset len) + + // Prove Seq.slice s 0 k == Seq.cons hd (Seq.slice tl 0 (k-1)) + let slice_prefix_eq () : Lemma (Seq.slice s 0 k `Seq.equal` Seq.cons hd (Seq.slice tl 0 (k - 1))) = + let s_prefix = Seq.slice s 0 k in + let expected = Seq.cons hd (Seq.slice tl 0 (k - 1)) in + + assert (Seq.length s_prefix == k); + assert (Seq.length expected == 1 + (k - 1)); + assert (Seq.length expected == k); + + let aux (i:nat{i < k}) : Lemma (Seq.index s_prefix i == Seq.index expected i) = + if i = 0 then ( + assert (Seq.index s_prefix 0 == Seq.index s 0); + assert (Seq.index s 0 == hd); + assert (Seq.index expected 0 == hd) + ) else ( + assert (Seq.index s_prefix i == Seq.index s i); + assert (Seq.index s i == Seq.index tl (i - 1)); + assert (Seq.index (Seq.slice tl 0 (k - 1)) (i - 1) == Seq.index tl (i - 1)); + assert (Seq.index expected i == Seq.index (Seq.slice tl 0 (k - 1)) (i - 1)) + ) + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro s_prefix expected + in + slice_prefix_eq (); + + // Prove Seq.slice s k n == Seq.slice tl (k-1) (Seq.length tl) + let slice_suffix_eq () : Lemma (Seq.slice s k n `Seq.equal` Seq.slice tl (k - 1) (Seq.length tl)) = + let s_suffix = Seq.slice s k n in + let tl_suffix = Seq.slice tl (k - 1) (Seq.length tl) in + + assert (Seq.length s_suffix == n - k); + assert (Seq.length tl == n - 1); + assert (Seq.length tl_suffix == (n - 1) - (k - 1)); + assert (Seq.length tl_suffix == n - k); + + let aux (i:nat{i < n - k}) : Lemma (Seq.index s_suffix i == Seq.index tl_suffix i) = + assert (Seq.index s_suffix i == Seq.index s (k + i)); + assert (Seq.index s (k + i) == Seq.index tl (k + i - 1)); + assert (Seq.index tl_suffix i == Seq.index tl (k - 1 + i)) + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro s_suffix tl_suffix + in + slice_suffix_eq (); + + // Now prove: Seq.cons hd (Seq.append a b) == Seq.append (Seq.cons hd a) b + let cons_append_assoc (#a:Type) (x:a) (s1 s2: Seq.seq a) + : Lemma (Seq.cons x (Seq.append s1 s2) `Seq.equal` Seq.append (Seq.cons x s1) s2) = + let lhs = Seq.cons x (Seq.append s1 s2) in + let rhs = Seq.append (Seq.cons x s1) s2 in + + assert (Seq.length lhs == 1 + Seq.length s1 + Seq.length s2); + assert (Seq.length rhs == (1 + Seq.length s1) + Seq.length s2); + + let aux (i:nat{i < Seq.length lhs}) : Lemma (Seq.index lhs i == Seq.index rhs i) = + if i = 0 then () + else if i <= Seq.length s1 then () + else () + in + FStar.Classical.forall_intro aux; + Seq.lemma_eq_intro lhs rhs + in + + cons_append_assoc hd + (Seq.slice tl 0 (k - 1)) + (add_range (Seq.slice tl (k - 1) (Seq.length tl)) offset len); + + // Final equality follows + Seq.lemma_eq_intro + (add_range s offset len) + (Seq.append (Seq.slice s 0 k) (add_range (Seq.slice s k n) offset len)) + ) +#pop-options + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +/// Helper: prove transitive sortedness from range_map_wf +let rec range_map_wf_sorted (s: Seq.seq interval) (i j: nat) + : Lemma (requires range_map_wf s /\ i < j /\ j < Seq.length s) + (ensures high (Seq.index s i) < (Seq.index s j).low) + (decreases %[Seq.length s; j - i]) = + if i + 1 = j then begin + // Adjacent case: directly from wf + if i = 0 then begin + assert (separated (Seq.index s 0) (Seq.index s 1)) + end else begin + range_map_wf_tail s; + range_map_wf_sorted (Seq.tail s) (i - 1) (j - 1) + end + end else begin + // Transitive case: i < i+1 < j + range_map_wf_sorted s i (i + 1); + range_map_wf_sorted s (i + 1) j; + // Now we have: high s[i] < s[i+1].low < high s[i+1] < s[j].low + assert (high (Seq.index s i) < (Seq.index s (i + 1)).low); + assert ((Seq.index s (i + 1)).low < high (Seq.index s (i + 1))); + assert (high (Seq.index s (i + 1)) < (Seq.index s j).low) + end + +/// Helper: compute final merged high value after absorbing k elements +let rec merge_absorbed_high (s: Seq.seq interval) (mh: nat) (k: nat{k <= Seq.length s}) + : Tot nat (decreases k) = + if k = 0 then mh + else let hd = Seq.index s 0 in + merge_absorbed_high (Seq.tail s) (if mh > high hd then mh else high hd) (k - 1) + +/// Monotonicity of merge_absorbed_high +let rec merge_absorbed_high_mono (s: Seq.seq interval) (mh: nat) (k: nat{k <= Seq.length s}) + : Lemma (ensures merge_absorbed_high s mh k >= mh) + (decreases k) = + if k = 0 then () + else merge_absorbed_high_mono (Seq.tail s) (if mh > high (Seq.index s 0) then mh else high (Seq.index s 0)) (k - 1) + +/// Unfold from the right: merge_absorbed_high(s, mh, k+1) == max(merge_absorbed_high(s, mh, k), high(s[k])) +/// This enables imperative loop invariant tracking where the last element absorbed is s[k] +let rec merge_absorbed_high_unfold_right (s: Seq.seq interval) (mh: nat) (k: nat{k < Seq.length s}) + : Lemma (ensures (let mh_k = merge_absorbed_high s mh k in + merge_absorbed_high s mh (k + 1) == + (if mh_k > high (Seq.index s k) then mh_k else high (Seq.index s k)))) + (decreases k) = + if k = 0 then begin + // Base case: merge_absorbed_high s mh 1 == max(mh, high(s[0])) + // LHS: merge_absorbed_high s mh 1 + // = merge_absorbed_high (Seq.tail s) (max(mh, high(s[0]))) 0 + // = max(mh, high(s[0])) + // RHS: merge_absorbed_high s mh 0 = mh, so max(mh, high(s[0])) + () + end else begin + // Inductive case: use IH on (Seq.tail s) with mh' = max(mh, high(s[0])) and k-1 + let mh' = if mh > high (Seq.index s 0) then mh else high (Seq.index s 0) in + // IH gives: merge_absorbed_high (Seq.tail s) mh' k == + // max(merge_absorbed_high (Seq.tail s) mh' (k-1), high((Seq.tail s)[k-1])) + merge_absorbed_high_unfold_right (Seq.tail s) mh' (k - 1); + // Note: Seq.index (Seq.tail s) (k - 1) == Seq.index s k + assert (Seq.index (Seq.tail s) (k - 1) == Seq.index s k); + // LHS: merge_absorbed_high s mh (k+1) + // = merge_absorbed_high (Seq.tail s) mh' k (by definition) + // = max(merge_absorbed_high (Seq.tail s) mh' (k-1), high(s[k])) (by IH) + // = max(merge_absorbed_high s mh k, high(s[k])) (by definition of mah(s, mh, k)) + () + end + +/// Step lemma: merge_absorbed_high(s, mh, k+1) == merge_absorbed_high(tail s, max(mh, high(s[0])), k) +let merge_absorbed_high_step (s: Seq.seq interval) (mh: nat) (k: nat{k < Seq.length s}) + : Lemma (ensures merge_absorbed_high s mh (k + 1) == + merge_absorbed_high (Seq.tail s) (if mh > high (Seq.index s 0) then mh else high (Seq.index s 0)) k) = () + +/// Shift: merge_absorbed_high on slice (i..n) relates to original seq indexing +let merge_absorbed_high_slice_step (s: Seq.seq interval) (base: nat) (mh: nat) (k: nat) + : Lemma (requires base + k + 1 <= Seq.length s /\ base + 1 <= Seq.length s) + (ensures (let suffix = Seq.slice s base (Seq.length s) in + let mh' = (if mh > high (Seq.index s base) then mh else high (Seq.index s base)) in + Seq.lemma_eq_intro (Seq.tail suffix) (Seq.slice s (base + 1) (Seq.length s)); + merge_absorbed_high suffix mh (k + 1) == + merge_absorbed_high (Seq.slice s (base + 1) (Seq.length s)) mh' k)) = + let suffix = Seq.slice s base (Seq.length s) in + Seq.lemma_eq_intro (Seq.tail suffix) (Seq.slice s (base + 1) (Seq.length s)); + merge_absorbed_high_step suffix mh k + +/// Lemma: With range_map_wf, high values are strictly increasing +let rec high_values_increasing (s: Seq.seq interval) (i j: nat) + : Lemma (requires range_map_wf s /\ i < j /\ j < Seq.length s) + (ensures high (Seq.index s i) < high (Seq.index s j)) + (decreases j - i) = + if i + 1 = j then begin + // Adjacent case: from wf, high(s[i]) < s[j].low <= s[j].low < high(s[j]) + range_map_wf_sorted s i j; + assert (high (Seq.index s i) < (Seq.index s j).low); + assert ((Seq.index s j).low < high (Seq.index s j)) + end else begin + // Transitive case: i < j-1 < j + high_values_increasing s i (j - 1); + high_values_increasing s (j - 1) j; + assert (high (Seq.index s i) < high (Seq.index s (j - 1))); + assert (high (Seq.index s (j - 1)) < high (Seq.index s j)) + end + +/// Lemma: For wf sequences with k > 0, merge_absorbed_high equals max(mh, high(s[k-1])) +/// because high values are strictly increasing, so high(s[k-1]) dominates all earlier highs +let rec merge_absorbed_high_eq_max_last (s: Seq.seq interval) (mh: nat) (k: nat) + : Lemma (requires range_map_wf s /\ 0 < k /\ k <= Seq.length s) + (ensures merge_absorbed_high s mh k == + (if mh > high (Seq.index s (k - 1)) then mh else high (Seq.index s (k - 1)))) + (decreases k) = + if k = 1 then begin + // Base case: merge_absorbed_high s mh 1 == max(mh, high(s[0])) + () + end else begin + // k > 1: by IH, merge_absorbed_high s mh (k-1) == max(mh, high(s[k-2])) + merge_absorbed_high_eq_max_last s mh (k - 1); + // By unfold_right: merge_absorbed_high s mh k == max(mah(s, mh, k-1), high(s[k-1])) + merge_absorbed_high_unfold_right s mh (k - 1); + // So: merge_absorbed_high s mh k == max(max(mh, high(s[k-2])), high(s[k-1])) + // By wf, high(s[k-2]) < high(s[k-1]) + high_values_increasing s (k - 2) (k - 1); + assert (high (Seq.index s (k - 2)) < high (Seq.index s (k - 1))); + // Therefore: max(max(mh, high(s[k-2])), high(s[k-1])) == max(mh, high(s[k-1])) + () + end + +/// Main lemma: From the running-max invariant plus wf, derive that mh0 covers absorbed elements +/// +/// If merge_absorbed_high(s, mh0, k) >= s[k].low for some k, and range_map_wf holds, +/// then mh0 >= s[k].low. +/// +/// Proof: By merge_absorbed_high_eq_max_last, mah(s, mh0, k) = max(mh0, high(s[k-1])). +/// By wf, high(s[k-1]) < s[k].low (from range_map_wf_sorted). +/// So max(mh0, high(s[k-1])) >= s[k].low and high(s[k-1]) < s[k].low +/// implies mh0 >= s[k].low. +let mh0_covers_absorbed (s: Seq.seq interval) (mh0: nat) (k: nat) + : Lemma (requires range_map_wf s /\ + 0 < k /\ k < Seq.length s /\ + merge_absorbed_high s mh0 k >= (Seq.index s k).low) + (ensures mh0 >= (Seq.index s k).low) = + // Step 1: Express merge_absorbed_high as max(mh0, high(s[k-1])) + merge_absorbed_high_eq_max_last s mh0 k; + assert (merge_absorbed_high s mh0 k == + (if mh0 > high (Seq.index s (k - 1)) then mh0 else high (Seq.index s (k - 1)))); + + // Step 2: Use wf to show high(s[k-1]) < s[k].low + range_map_wf_sorted s (k - 1) k; + assert (high (Seq.index s (k - 1)) < (Seq.index s k).low); + + // Step 3: From merge_absorbed_high s mh0 k >= s[k].low and high(s[k-1]) < s[k].low, + // deduce mh0 >= s[k].low + let mah_val = merge_absorbed_high s mh0 k in + let s_k_low = (Seq.index s k).low in + let high_prev = high (Seq.index s (k - 1)) in + + assert (mah_val >= s_k_low); + assert (high_prev < s_k_low); + assert (mah_val == (if mh0 > high_prev then mh0 else high_prev)); + + // Since high_prev < s_k_low and max(mh0, high_prev) >= s_k_low, + // we must have mh0 >= s_k_low + if mh0 > high_prev then + assert (mh0 >= s_k_low) + else begin + assert (mah_val == high_prev); + assert (high_prev >= s_k_low); + assert (False) // Contradiction: high_prev < s_k_low but also high_prev >= s_k_low + end + +/// Lemma 1: Trivial unfolding lemma for the merge branch +let add_range_merge_step (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (requires Seq.length s > 0 /\ + (let hd = Seq.index s 0 in + ~(offset + len < hd.low) /\ ~(high hd < offset))) + (ensures (let hd = Seq.index s 0 in + let tl = Seq.tail s in + let ml = (if offset < hd.low then offset else hd.low) in + let mh = (if offset + len > high hd then offset + len else high hd) in + mh > ml /\ + add_range s offset len == add_range tl ml (mh - ml))) + = let hd = Seq.index s 0 in + let ml = (if offset < hd.low then offset else hd.low) in + let mh = (if offset + len > high hd then offset + len else high hd) in + // Show mh > ml + assert (offset + len > offset); + assert (~(offset + len < hd.low)); + assert (offset + len >= hd.low); + assert (~(high hd < offset)); + assert (high hd >= offset); + assert (mh >= offset); + assert (ml <= offset); + // Unfold add_range definition + assert (add_range s offset len == + (let hd' = Seq.index s 0 in + let tl' = Seq.tail s in + if offset + len < hd'.low then Seq.cons ({ low = offset; count = len }) s + else if high hd' < offset then Seq.cons hd' (add_range tl' offset len) + else + let merged_low = if offset < hd'.low then offset else hd'.low in + let merged_high = if offset + len > high hd' then offset + len else high hd' in + add_range tl' merged_low (merged_high - merged_low))) + +/// Lemma 2: Characterize recursive merge +let rec add_range_merge_scan (s: Seq.seq interval) (ml: nat) (mh: nat) (k: nat) + : Lemma (requires range_map_wf s /\ mh > ml /\ + k <= Seq.length s /\ + (k > 0 ==> ml <= (Seq.index s 0).low) /\ + (forall (i:nat). i < k ==> mh >= (Seq.index s i).low) /\ + (k = Seq.length s \/ mh < (Seq.index s k).low)) + (ensures (let fh = merge_absorbed_high s mh k in + fh > ml /\ + add_range s ml (mh - ml) == + Seq.append (Seq.create 1 ({ low = ml; count = fh - ml })) + (Seq.slice s k (Seq.length s)))) + (decreases k) = + if k = 0 then begin + // No overlaps to absorb + merge_absorbed_high_mono s mh 0; + assert (merge_absorbed_high s mh 0 = mh); + + if Seq.length s = 0 then begin + // Empty sequence case + let iv = { low = ml; count = mh - ml } in + assert (add_range s ml (mh - ml) == Seq.create 1 iv); + Seq.lemma_eq_intro (Seq.slice s 0 0) Seq.empty; + Seq.lemma_eq_intro (Seq.append (Seq.create 1 iv) Seq.empty) (Seq.create 1 iv) + end else begin + // mh < s[0].low, insert before first element + let iv = { low = ml; count = mh - ml } in + assert (ml + (mh - ml) = mh); + assert (mh < (Seq.index s 0).low); + assert (add_range s ml (mh - ml) == Seq.cons iv s); + Seq.lemma_eq_intro (Seq.slice s 0 (Seq.length s)) s; + Seq.lemma_eq_intro (Seq.cons iv s) (Seq.append (Seq.create 1 iv) s) + end + end else begin + // k > 0: first element overlaps + let hd = Seq.index s 0 in + let tl = Seq.tail s in + let n = Seq.length s in + + // Establish overlap conditions + assert (mh >= hd.low); // from forall with i=0 + assert (~(ml + (mh - ml) < hd.low)); // since ml + (mh - ml) = mh >= hd.low + + // Show ~(high hd < ml) + assert (ml <= hd.low); // from precondition k > 0 + assert (high hd = hd.low + hd.count); + assert (hd.count > 0); + assert (high hd > hd.low); + assert (high hd >= ml); // since high hd > hd.low >= ml + assert (~(high hd < ml)); + + // Merge branch is taken + let ml' = (if ml < hd.low then ml else hd.low) in + let mh' = (if mh > high hd then mh else high hd) in + + // ml' = ml since ml <= hd.low + assert (ml' = ml); + + // mh' = max(mh, high hd) + assert (mh' >= mh); + assert (mh' >= high hd); + assert (mh' > ml); + + // Establish IH preconditions for tl with k-1 + range_map_wf_tail s; + assert (range_map_wf tl); + + // Show: forall i < k-1. mh' >= (Seq.index tl i).low + // Seq.index tl i = Seq.index s (i+1) + assert (forall (i:nat). i < k - 1 ==> ( + let si1 = Seq.index s (i + 1) in + let ti = Seq.index tl i in + ti == si1 + )); + + assert (forall (i:nat). i < k - 1 ==> mh >= (Seq.index s (i + 1)).low); + assert (forall (i:nat). i < k - 1 ==> mh' >= (Seq.index tl i).low); + + // Show: k-1 = Seq.length tl \/ mh' < (Seq.index tl (k-1)).low + assert (Seq.length tl = n - 1); + if k = n then begin + assert (k - 1 = Seq.length tl) + end else begin + // mh < s[k].low, need to show mh' < s[k].low + assert (mh < (Seq.index s k).low); + + // By wf, high hd < s[1].low + if k >= 2 then begin + assert (separated hd (Seq.index s 1)); + assert (high hd < (Seq.index s 1).low); + range_map_wf_sorted s 1 k; + assert (high (Seq.index s 1) < (Seq.index s k).low); + assert ((Seq.index s 1).low < high (Seq.index s 1)); + assert ((Seq.index s 1).low < (Seq.index s k).low); + assert (high hd < (Seq.index s 1).low); + assert (high hd < (Seq.index s k).low) + end else begin + // k = 1, so we need mh' < s[1].low + assert (k = 1); + assert (separated hd (Seq.index s 1)); + assert (high hd < (Seq.index s 1).low) + end; + + // mh' = max(mh, high hd), both < s[k].low + assert (mh' < (Seq.index s k).low); + assert (Seq.index tl (k - 1) == Seq.index s k); + assert (mh' < (Seq.index tl (k - 1)).low) + end; + + // Show: k-1 > 0 ==> ml <= (Seq.index tl 0).low + if k - 1 > 0 then begin + assert (k >= 2); + assert (Seq.index tl 0 == Seq.index s 1); + assert (ml <= hd.low); + assert (separated hd (Seq.index s 1)); + assert (high hd < (Seq.index s 1).low); + assert (hd.low < (Seq.index s 1).low); + assert (ml <= (Seq.index tl 0).low) + end; + + // Apply IH + add_range_merge_scan tl ml mh' (k - 1); + + // Now we have: add_range tl ml (mh' - ml) = + // append (create 1 {ml, merge_absorbed_high tl mh' (k-1) - ml}) + // (slice tl (k-1) (length tl)) + + // Show: merge_absorbed_high s mh k = merge_absorbed_high tl mh' (k-1) + assert (merge_absorbed_high s mh k = + merge_absorbed_high tl (if mh > high hd then mh else high hd) (k - 1)); + assert (merge_absorbed_high s mh k = merge_absorbed_high tl mh' (k - 1)); + + let fh = merge_absorbed_high s mh k in + + // Show: add_range s ml (mh - ml) = add_range tl ml (mh' - ml) + assert (add_range s ml (mh - ml) == add_range tl ml (mh' - ml)); + + // From IH: add_range tl ml (mh' - ml) = append (create 1 {ml, fh - ml}) (slice tl (k-1) (n-1)) + + // Show: slice tl (k-1) (n-1) = slice s k n + assert (forall (i:nat). i < Seq.length (Seq.slice tl (k - 1) (n - 1)) ==> + Seq.index (Seq.slice tl (k - 1) (n - 1)) i == + Seq.index (Seq.slice s k n) i); + Seq.lemma_eq_intro (Seq.slice tl (k - 1) (n - 1)) (Seq.slice s k n); + + // Conclude + Seq.lemma_eq_intro + (add_range s ml (mh - ml)) + (Seq.append (Seq.create 1 ({ low = ml; count = fh - ml })) + (Seq.slice s k n)) + end + +#pop-options + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 80" + +let rec range_map_wf_slice (s: Seq.seq interval) (i: nat) + : Lemma (requires range_map_wf s /\ i <= Seq.length s) + (ensures range_map_wf (Seq.slice s i (Seq.length s))) + (decreases i) = + if i = 0 then Seq.lemma_eq_intro (Seq.slice s 0 (Seq.length s)) s + else begin + range_map_wf_tail s; + range_map_wf_slice (Seq.tail s) (i - 1); + Seq.lemma_eq_intro (Seq.slice s i (Seq.length s)) (Seq.slice (Seq.tail s) (i - 1) (Seq.length (Seq.tail s))) + end + +#pop-options + +#push-options "--fuel 1 --ifuel 0 --z3rlimit 40 --split_queries always" + +/// Helper: the suffix part of the merge +/// add_range (slice s iv n) offset len == append (create 1 merged) (slice s j n) +let add_range_merge_suffix (s: Seq.seq interval) (offset: nat) (len: pos) (iv j: nat) + : Lemma (requires range_map_wf s /\ + iv < Seq.length s /\ j > iv /\ j <= Seq.length s /\ + ~(offset + len < (Seq.index s iv).low) /\ + ~(high (Seq.index s iv) < offset) /\ + (let ml = (if offset < (Seq.index s iv).low then offset else (Seq.index s iv).low) in + let mh0 = (if offset + len > high (Seq.index s iv) then offset + len else high (Seq.index s iv)) in + (forall (i:nat). i > iv /\ i < j ==> mh0 >= (Seq.index s i).low) /\ + (j = Seq.length s \/ mh0 < (Seq.index s j).low))) + (ensures (let ml = (if offset < (Seq.index s iv).low then offset else (Seq.index s iv).low) in + let mh0 = (if offset + len > high (Seq.index s iv) then offset + len else high (Seq.index s iv)) in + let suffix_tail = Seq.slice s (iv + 1) (Seq.length s) in + let fh = merge_absorbed_high suffix_tail mh0 (j - iv - 1) in + fh > ml /\ + add_range (Seq.slice s iv (Seq.length s)) offset len == + Seq.append (Seq.create 1 ({ low = ml; count = fh - ml })) + (Seq.slice s j (Seq.length s)))) = + let n = Seq.length s in + let ml = (if offset < (Seq.index s iv).low then offset else (Seq.index s iv).low) in + let mh0 = (if offset + len > high (Seq.index s iv) then offset + len else high (Seq.index s iv)) in + let k = j - iv - 1 in + let suffix = Seq.slice s iv n in + let stail = Seq.tail suffix in + + // merge step on first element of suffix + range_map_wf_slice s iv; + add_range_merge_step suffix offset len; + + // merge scan using stail + Seq.lemma_eq_intro stail (Seq.slice s (iv + 1) n); + range_map_wf_slice s (iv + 1); + if k > 0 then range_map_wf_sorted s iv (iv + 1); + add_range_merge_scan stail ml mh0 k; + merge_absorbed_high_mono stail mh0 k; + + // Relate slice stail k to slice s j n + Seq.lemma_eq_intro (Seq.slice stail k (Seq.length stail)) (Seq.slice s j n) + +/// Full merge: combines skip_prefix + merge_suffix +let add_range_merge_full (s: Seq.seq interval) (offset: nat) (len: pos) (iv j: nat) + : Lemma (requires range_map_wf s /\ + iv < Seq.length s /\ j > iv /\ j <= Seq.length s /\ + (forall (i:nat). i < iv ==> high (Seq.index s i) < offset) /\ + ~(offset + len < (Seq.index s iv).low) /\ + ~(high (Seq.index s iv) < offset) /\ + (let ml = (if offset < (Seq.index s iv).low then offset else (Seq.index s iv).low) in + let mh0 = (if offset + len > high (Seq.index s iv) then offset + len else high (Seq.index s iv)) in + (forall (i:nat). i > iv /\ i < j ==> mh0 >= (Seq.index s i).low) /\ + (j = Seq.length s \/ mh0 < (Seq.index s j).low))) + (ensures (let ml = (if offset < (Seq.index s iv).low then offset else (Seq.index s iv).low) in + let mh0 = (if offset + len > high (Seq.index s iv) then offset + len else high (Seq.index s iv)) in + let suffix_tail = Seq.slice s (iv + 1) (Seq.length s) in + let fh = merge_absorbed_high suffix_tail mh0 (j - iv - 1) in + fh > ml /\ + add_range s offset len == + Seq.append (Seq.slice s 0 iv) + (Seq.append (Seq.create 1 ({ low = ml; count = fh - ml })) + (Seq.slice s j (Seq.length s))))) = + add_range_skip_prefix s offset len iv; + add_range_merge_suffix s offset len iv j + +/// Explicit-mh0 version: takes mh0 as a parameter for easier SMT matching +let add_range_merge_full_explicit (s: Seq.seq interval) (offset: nat) (len: pos) (iv j: nat) (mh0: nat) + : Lemma (requires range_map_wf s /\ + iv < Seq.length s /\ j > iv /\ j <= Seq.length s /\ + (forall (i:nat). i < iv ==> high (Seq.index s i) < offset) /\ + ~(offset + len < (Seq.index s iv).low) /\ + ~(high (Seq.index s iv) < offset) /\ + mh0 == (if offset + len > high (Seq.index s iv) then offset + len else high (Seq.index s iv)) /\ + (forall (i:nat). i > iv /\ i < j ==> mh0 >= (Seq.index s i).low) /\ + (j = Seq.length s \/ mh0 < (Seq.index s j).low)) + (ensures (let ml = (if offset < (Seq.index s iv).low then offset else (Seq.index s iv).low) in + let suffix_tail = Seq.slice s (iv + 1) (Seq.length s) in + let fh = merge_absorbed_high suffix_tail mh0 (j - iv - 1) in + fh > ml /\ + add_range s offset len == + Seq.append (Seq.slice s 0 iv) + (Seq.append (Seq.create 1 ({ low = ml; count = fh - ml })) + (Seq.slice s j (Seq.length s))))) = + add_range_merge_full s offset len iv j + +#pop-options + +(*** Length bounds ***) + +/// drain_ranges never increases the number of intervals +let rec drain_ranges_length (s: Seq.seq interval) (n: nat) + : Lemma (ensures Seq.length (drain_ranges s n) <= Seq.length s) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + drain_ranges_length tl n + +/// Separated intervals span at least 2n-1 offsets within [lo, bound). +/// From range_map_wf (gaps >= 1) and count >= 1 per interval. +let rec wf_count_bound (s: Seq.seq interval) (lo bound: nat) + : Lemma (requires range_map_wf s /\ range_map_bounded s bound /\ + Seq.length s > 0 /\ (Seq.index s 0).low >= lo) + (ensures Seq.length s + Seq.length s <= bound - lo + 1) + (decreases Seq.length s) = + let hd = Seq.index s 0 in + if Seq.length s = 1 then + // Single interval: count >= 1, so high hd <= bound and hd.low >= lo + // high hd - lo >= hd.count >= 1 = 2*1 - 1 + () + else begin + let tl = Seq.tail s in + let hd2 = Seq.index s 1 in + // separated hd hd2: high hd < hd2.low (gap >= 1) + assert (Seq.index tl 0 == hd2); + range_map_wf_tail s; + // range_map_bounded for tail + let aux (i:nat{i < Seq.length tl}) : Lemma (high (Seq.index tl i) <= bound) = + assert (Seq.index tl i == Seq.index s (i + 1)) + in + Classical.forall_intro aux; + // Recurse: 2*(|tl|) - 1 <= bound - hd2.low + wf_count_bound tl hd2.low bound; + // hd spans [hd.low, high hd), count >= 1 + // gap: hd2.low > high hd, so hd2.low >= high hd + 1 + assert (high hd - hd.low >= 1); // count >= 1 + assert (hd2.low - high hd >= 1); // separated gap >= 1 + assert (hd.low >= lo) + end + +/// add_range increases length by at most 1 +let rec add_range_length (s: Seq.seq interval) (offset: nat) (len: pos) + : Lemma (ensures Seq.length (add_range s offset len) <= Seq.length s + 1) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then () + else if high hd < offset then + add_range_length tl offset len + else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + add_range_length tl merged_low (merged_high - merged_low) + +/// Drain the first interval (or trim it) up to new_bo. +/// Precondition: first interval contains new_bo (base_aligned + n <= cf). +let drain_repr (s: Seq.seq interval) (new_bo: nat) + : Seq.seq interval = + if Seq.length s = 0 then Seq.empty + else + let first = Seq.index s 0 in + if high first <= new_bo then Seq.tail s + else if first.low < new_bo then + Seq.cons ({ low = new_bo; count = high first - new_bo }) (Seq.tail s) + else s + +/// drain_repr preserves range_map_wf +let drain_repr_wf (s: Seq.seq interval) (new_bo: nat) + : Lemma (requires range_map_wf s /\ Seq.length s > 0 /\ + (Seq.index s 0).low <= new_bo /\ new_bo <= high (Seq.index s 0)) + (ensures range_map_wf (drain_repr s new_bo)) = + let first = Seq.index s 0 in + let tl = Seq.tail s in + if high first <= new_bo then () + else if first.low < new_bo then begin + let trimmed = { low = new_bo; count = high first - new_bo } in + if Seq.length tl = 0 then + range_map_wf_cons trimmed tl + else begin + let next = Seq.index tl 0 in + assert (Seq.index s 1 == next); + assert (separated first next); + assert (high first < next.low); + assert (high trimmed == high first); + assert (high trimmed < next.low); + assert (separated trimmed next); + range_map_wf_cons trimmed tl + end + end else () + +/// drain_repr preserves range_map_bounded +let drain_repr_bounded (s: Seq.seq interval) (new_bo: nat) (bound: nat) + : Lemma (requires range_map_bounded s bound /\ Seq.length s > 0 /\ + (Seq.index s 0).low <= new_bo /\ new_bo <= high (Seq.index s 0)) + (ensures range_map_bounded (drain_repr s new_bo) bound) = + let first = Seq.index s 0 in + let result = drain_repr s new_bo in + if high first <= new_bo then begin + let tl = Seq.tail s in + let aux (i:nat{i < Seq.length tl}) + : Lemma (high (Seq.index tl i) <= bound) + = assert (Seq.index tl i == Seq.index s (i + 1)) + in + Classical.forall_intro aux + end else if first.low < new_bo then begin + let trimmed = { low = new_bo; count = high first - new_bo } in + let tl = Seq.tail s in + assert (high trimmed == high first); + assert (high trimmed <= bound); + let aux (i:nat{i < Seq.length tl}) + : Lemma (high (Seq.index tl i) <= bound) + = assert (Seq.index tl i == Seq.index s (i + 1)) + in + Classical.forall_intro aux; + assert (Seq.index result 0 == trimmed); + let aux2 (i:nat{i < Seq.length result}) + : Lemma (high (Seq.index result i) <= bound) + = if i = 0 then () else assert (Seq.index result i == Seq.index tl (i - 1)) + in + Classical.forall_intro aux2 + end else () + +/// drain_repr decreases (or preserves) length +let drain_repr_length (s: Seq.seq interval) (new_bo: nat) + : Lemma (Seq.length (drain_repr s new_bo) <= Seq.length s) = () + +/// drain_repr mem: positions >= new_bo are preserved +let drain_repr_mem_above (s: Seq.seq interval) (new_bo: nat) (x: nat) + : Lemma (requires range_map_wf s /\ Seq.length s > 0 /\ + (Seq.index s 0).low <= new_bo /\ new_bo <= high (Seq.index s 0) /\ + x >= new_bo) + (ensures mem (drain_repr s new_bo) x == mem s x) = + let first = Seq.index s 0 in + let result = drain_repr s new_bo in + if high first <= new_bo then begin + // first removed, x >= new_bo >= high first, so x not in first + assert (not (in_interval first x)); + if mem s x then mem_tail s x else () + end else if first.low < new_bo then begin + let trimmed = { low = new_bo; count = high first - new_bo } in + let tl = Seq.tail s in + mem_cons trimmed tl x + end else () + +/// add_range preserves first.low >= lo when offset >= lo +let rec add_range_first_low (s: Seq.seq interval) (offset: nat) (len: pos) (lo: nat) + : Lemma (requires range_map_wf s /\ + (Seq.length s = 0 \/ (Seq.index s 0).low >= lo) /\ offset >= lo) + (ensures Seq.length (add_range s offset len) > 0 /\ + (Seq.index (add_range s offset len) 0).low >= lo) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else + let hd = Seq.index s 0 in + let tl = Seq.tail s in + if offset + len < hd.low then () + else if high hd < offset then + // add_range returns cons hd (add_range tl offset len) + () + else + let merged_low = if offset < hd.low then offset else hd.low in + let merged_high = if offset + len > high hd then offset + len else high hd in + range_map_wf_tail s; + add_range_first_low tl merged_low (merged_high - merged_low) lo diff --git a/lib/pulse/lib/Pulse.Lib.RangeMap.fst b/lib/pulse/lib/Pulse.Lib.RangeMap.fst new file mode 100644 index 000000000..4c49c007c --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.RangeMap.fst @@ -0,0 +1,835 @@ +module Pulse.Lib.RangeMap + +/// Range map backed by an AVL tree mapping Range intervals to unit (pure interval tracker). + +#lang-pulse + +open Pulse.Lib.Pervasives + +module SZ = FStar.SizeT +module Seq = FStar.Seq +module Spec = Pulse.Lib.RangeMap.Spec +module B = Pulse.Lib.Box +module T = Pulse.Lib.Spec.AVLTree +module AVL = Pulse.Lib.AVLTree +module G = FStar.Ghost +module R = Pulse.Lib.Reference +/// Concrete range type: [start, start+len) +type range = { start: SZ.t; len: SZ.t } + +/// An entry is a range paired with unit (no byte data) +let entry = range & unit + +(*** B1: Range comparison ***) + +let range_cmp_fn (a b: range) : int = + let av = SZ.v a.start in + let bv = SZ.v b.start in + if av < bv then (-1) + else if av = bv then 0 + else 1 + +let range_cmp : T.cmp range = range_cmp_fn + +(*** Helpers ***) + +let range_valid (r: range) : prop = + SZ.v r.len > 0 /\ + SZ.fits (SZ.v r.start + SZ.v r.len) + +let entry_valid (e: entry) : prop = + range_valid (fst e) + +let rec list_valid (l: list entry) : Tot prop (decreases l) = + match l with + | [] -> True + | hd :: tl -> entry_valid hd /\ list_valid tl + +let range_to_interval (r: range) + : Pure Spec.interval (requires range_valid r) (ensures fun _ -> True) = + { Spec.low = SZ.v r.start; Spec.count = SZ.v r.len } + +let mk_entry (s l: SZ.t) : entry = ({ start = s; len = l }, ()) + +let rec list_to_spec (l: list entry) + : Pure (Seq.seq Spec.interval) + (requires list_valid l) + (ensures fun r -> True) + (decreases l) = + match l with + | [] -> Seq.empty + | hd :: tl -> Seq.cons (range_to_interval (fst hd)) (list_to_spec tl) + +let rec seq_all_valid (s: Seq.seq entry) + : Tot prop (decreases Seq.length s) = + if Seq.length s = 0 then True + else entry_valid (Seq.head s) /\ seq_all_valid (Seq.tail s) + +let rec seq_to_spec (s: Seq.seq entry) + : Pure (Seq.seq Spec.interval) + (requires seq_all_valid s) + (ensures fun r -> Seq.length r == Seq.length s) + (decreases Seq.length s) = + if Seq.length s = 0 then Seq.empty + else Seq.cons (range_to_interval (fst (Seq.head s))) (seq_to_spec (Seq.tail s)) + +let seq_to_spec_head (s: Seq.seq entry) + : Lemma (requires seq_all_valid s /\ Seq.length s > 0) + (ensures Seq.head (seq_to_spec s) == range_to_interval (fst (Seq.head s))) = () + +let tree_wf (t: T.tree range unit) : prop = + seq_all_valid (T.inorder t) + +let tree_to_spec (t: T.tree range unit) + : Pure (Seq.seq Spec.interval) + (requires tree_wf t) + (ensures fun r -> Seq.length r == Seq.length (T.inorder t)) = + seq_to_spec (T.inorder t) + +(*** seq ↔ list conversion ***) + +let rec seq_to_list (s: Seq.seq entry) + : Tot (list entry) (decreases Seq.length s) = + if Seq.length s = 0 then [] + else Seq.head s :: seq_to_list (Seq.tail s) + +let rec seq_to_list_valid (s: Seq.seq entry) + : Lemma (requires seq_all_valid s) + (ensures list_valid (seq_to_list s)) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else seq_to_list_valid (Seq.tail s) + +let rec seq_to_list_spec (s: Seq.seq entry) + : Lemma (requires seq_all_valid s) + (ensures (seq_to_list_valid s; list_to_spec (seq_to_list s) == seq_to_spec s)) + (decreases Seq.length s) = + if Seq.length s = 0 then ( + seq_to_list_valid s; + assert (Seq.equal (list_to_spec (seq_to_list s)) (seq_to_spec s)) + ) else ( + seq_to_list_valid (Seq.tail s); + seq_to_list_spec (Seq.tail s); + seq_to_list_valid s + ) + +(*** tree_min ↔ inorder head ***) + +#push-options "--fuel 3 --z3rlimit 40" + +let rec tree_min_head_inorder (t: T.tree range unit{T.Node? t}) + : Lemma (ensures Seq.length (T.inorder t) > 0 /\ + T.tree_min t == Seq.head (T.inorder t)) + (decreases t) = + match t with + | T.Node dk dv T.Leaf r -> + Seq.append_empty_l (Seq.cons (dk, dv) (T.inorder r)); + assert (Seq.equal (T.inorder t) (Seq.cons (dk, dv) (T.inorder r))) + | T.Node dk dv l r -> + tree_min_head_inorder l; + Seq.lemma_head_append (T.inorder l) (Seq.cons (dk, dv) (T.inorder r)) + +#pop-options + +(*** seq_to_spec indexing ***) + +#push-options "--fuel 2 --ifuel 1" + +let rec seq_all_valid_index (s: Seq.seq entry) (i: nat) + : Lemma (requires seq_all_valid s /\ i < Seq.length s) + (ensures entry_valid (Seq.index s i)) + (decreases Seq.length s) = + if i = 0 then () + else seq_all_valid_index (Seq.tail s) (i - 1) + +let rec seq_to_spec_index (s: Seq.seq entry) (i: nat) + : Lemma (requires seq_all_valid s /\ i < Seq.length s) + (ensures entry_valid (Seq.index s i) /\ + Seq.index (seq_to_spec s) i == range_to_interval (fst (Seq.index s i))) + (decreases Seq.length s) = + seq_all_valid_index s i; + if i = 0 then () + else seq_to_spec_index (Seq.tail s) (i - 1) + +#pop-options + +(*** tree_max ↔ spec last ***) + +let tree_max_last_spec (t: T.tree range unit) + : Lemma (requires T.Node? t /\ tree_wf t) + (ensures (let s = T.inorder t in + Seq.length s > 0 /\ + entry_valid (Seq.index s (Seq.length s - 1)) /\ + Seq.index (tree_to_spec t) (Seq.length (tree_to_spec t) - 1) == + range_to_interval (fst (Seq.index s (Seq.length s - 1))) /\ + T.tree_max t == Seq.index s (Seq.length s - 1))) = + T.tree_max_last_inorder t; + let s = T.inorder t in + let n = Seq.length s in + seq_to_spec_index s (n - 1) + +(*** Pure implementations ***) + +#push-options "--z3rlimit_factor 4 --fuel 2 --ifuel 2" + +let rec add_range_impl (l: list entry) (off len: SZ.t) + : Pure (list entry) + (requires list_valid l /\ SZ.v len > 0 /\ SZ.fits (SZ.v off + SZ.v len)) + (ensures fun r -> list_valid r /\ + list_to_spec r == Spec.add_range (list_to_spec l) (SZ.v off) (SZ.v len)) + (decreases List.Tot.length l) = + match l with + | [] -> + let e = mk_entry off len in + assert (Seq.equal (list_to_spec [e]) + (Spec.add_range (list_to_spec []) (SZ.v off) (SZ.v len))); + [e] + | hd :: tl -> + let hd_low = (fst hd).start in + let hd_count = (fst hd).len in + let hd_high = SZ.add hd_low hd_count in + let off_plus_len = SZ.add off len in + let hd_spec = range_to_interval (fst hd) in + let tl_spec = list_to_spec tl in + assert (list_to_spec l == Seq.cons hd_spec tl_spec); + assert (Seq.length (list_to_spec l) > 0); + assert (Seq.index (list_to_spec l) 0 == hd_spec); + assert (Seq.tail (list_to_spec l) `Seq.equal` tl_spec); + assert (Spec.high hd_spec == SZ.v hd_high); + if SZ.lt off_plus_len hd_low then ( + let e = mk_entry off len in + assert (SZ.v off + SZ.v len < SZ.v hd_low); + assert (SZ.v off_plus_len < hd_spec.Spec.low); + assert (Seq.equal (list_to_spec (e :: l)) + (Seq.cons (range_to_interval (fst e)) (list_to_spec l))); + e :: l + ) + else if SZ.lt hd_high off then ( + assert (Spec.high hd_spec < SZ.v off); + let r = add_range_impl tl off len in + assert (list_to_spec r == Spec.add_range tl_spec (SZ.v off) (SZ.v len)); + assert (Seq.equal (list_to_spec (hd :: r)) + (Seq.cons hd_spec (list_to_spec r))); + hd :: r + ) + else ( + let merged_low = (if SZ.lt off hd_low then off else hd_low) in + let merged_high = (if SZ.gt off_plus_len hd_high then off_plus_len else hd_high) in + assert (SZ.v merged_high > SZ.v merged_low); + let new_len = SZ.sub merged_high merged_low in + assert (SZ.v new_len > 0); + assert (SZ.fits (SZ.v merged_low + SZ.v new_len)); + add_range_impl tl merged_low new_len + ) + +#pop-options + +(*** Rebuild: list → tree ***) + +let rec list_to_tree_fwd (l: list entry) (acc: T.tree range unit) + : Tot (T.tree range unit) (decreases l) = + match l with + | [] -> acc + | hd :: tl -> list_to_tree_fwd tl (T.insert_avl range_cmp acc (fst hd) (snd hd)) + +let rec list_to_seq (l: list entry) + : Tot (Seq.seq entry) (decreases l) = + match l with + | [] -> Seq.empty + | hd :: tl -> Seq.cons hd (list_to_seq tl) + +let rec list_sorted (l: list entry) : prop = + match l with + | [] -> True + | [_] -> True + | a :: b :: rest -> range_cmp_fn (fst a) (fst b) < 0 /\ list_sorted (b :: rest) + +let rec fold_sorted_insert (l: list entry) (s: Seq.seq entry) + : Tot (Seq.seq entry) (decreases l) = + match l with + | [] -> s + | hd :: tl -> fold_sorted_insert tl (T.sorted_insert range_cmp hd s) + +#push-options "--fuel 3 --z3rlimit 60" + +/// Helper: in a sorted sequence, if last < k then head < k +let rec sorted_head_lt (s: Seq.seq entry) (k: entry) + : Lemma (requires T.sorted range_cmp s /\ Seq.length s > 0 /\ + range_cmp_fn (fst (Seq.index s (Seq.length s - 1))) (fst k) < 0) + (ensures range_cmp_fn (fst (Seq.head s)) (fst k) < 0) + (decreases Seq.length s) = + if Seq.length s = 1 then () + else sorted_head_lt (Seq.tail s) k + +let rec sorted_insert_snoc (k: entry) (s: Seq.seq entry) + : Lemma (requires T.sorted range_cmp s /\ + (Seq.length s = 0 \/ range_cmp_fn (fst (Seq.index s (Seq.length s - 1))) (fst k) < 0)) + (ensures T.sorted_insert range_cmp k s == Seq.snoc s k) + (decreases Seq.length s) = + if Seq.length s = 0 then + assert (Seq.equal (T.sorted_insert range_cmp k s) (Seq.snoc s k)) + else ( + sorted_head_lt s k; + assert (range_cmp_fn (fst (Seq.head s)) (fst k) < 0); + if Seq.length s = 1 then + assert (Seq.equal (T.sorted_insert range_cmp k s) (Seq.snoc s k)) + else ( + sorted_insert_snoc k (Seq.tail s); + assert (Seq.equal (Seq.snoc s k) (Seq.cons (Seq.head s) (Seq.snoc (Seq.tail s) k))) + ) + ) + +let rec sorted_snoc (s: Seq.seq entry) (k: entry) + : Lemma (requires T.sorted range_cmp s /\ + (Seq.length s = 0 \/ range_cmp_fn (fst (Seq.index s (Seq.length s - 1))) (fst k) < 0)) + (ensures T.sorted range_cmp (Seq.snoc s k)) + (decreases Seq.length s) = + let s' = Seq.snoc s k in + if Seq.length s = 0 then () + else if Seq.length s = 1 then ( + sorted_head_lt s k + ) + else ( + sorted_snoc (Seq.tail s) k; + assert (Seq.index s' 0 == Seq.index s 0); + assert (Seq.index s' 1 == Seq.index s 1); + assert (Seq.equal (Seq.tail s') (Seq.snoc (Seq.tail s) k)) + ) + +let rec fold_sorted_insert_is_append (l: list entry) (s: Seq.seq entry) + : Lemma (requires list_sorted l /\ T.sorted range_cmp s /\ + (Seq.length s = 0 \/ List.Tot.length l = 0 \/ + range_cmp_fn (fst (Seq.index s (Seq.length s - 1))) (fst (List.Tot.hd l)) < 0)) + (ensures fold_sorted_insert l s == Seq.append s (list_to_seq l)) + (decreases l) = + match l with + | [] -> assert (Seq.equal (Seq.append s (list_to_seq [])) s) + | [hd] -> + sorted_insert_snoc hd s; + assert (Seq.equal (Seq.snoc s hd) (Seq.append s (Seq.create 1 hd))); + assert (Seq.equal (list_to_seq [hd]) (Seq.create 1 hd)) + | hd :: tl -> + sorted_insert_snoc hd s; + sorted_snoc s hd; + assert (Seq.index (Seq.snoc s hd) (Seq.length (Seq.snoc s hd) - 1) == hd); + fold_sorted_insert_is_append tl (Seq.snoc s hd); + Seq.append_assoc s (Seq.create 1 hd) (list_to_seq tl); + assert (Seq.equal (Seq.snoc s hd) (Seq.append s (Seq.create 1 hd))); + assert (Seq.equal (Seq.cons hd (list_to_seq tl)) (Seq.append (Seq.create 1 hd) (list_to_seq tl))); + assert (Seq.equal (list_to_seq (hd :: tl)) (Seq.cons hd (list_to_seq tl))) + +let rec list_to_tree_fwd_inorder (l: list entry) (acc: T.avl range unit range_cmp) + : Lemma (ensures T.inorder (list_to_tree_fwd l acc) == fold_sorted_insert l (T.inorder acc)) + (decreases l) = + match l with + | [] -> () + | hd :: tl -> + T.insert_avl_inorder range_cmp acc (fst hd) (snd hd); + T.insert_avl_proof range_cmp acc (fst hd) (snd hd); + list_to_tree_fwd_inorder tl (T.insert_avl range_cmp acc (fst hd) (snd hd)) + +let list_to_tree_fwd_correct (l: list entry) + : Lemma (requires list_sorted l) + (ensures T.inorder (list_to_tree_fwd l T.Leaf) == list_to_seq l) = + list_to_tree_fwd_inorder l T.Leaf; + fold_sorted_insert_is_append l Seq.empty; + assert (Seq.equal (Seq.append Seq.empty (list_to_seq l)) (list_to_seq l)) + +let rec list_to_tree_fwd_avl (l: list entry) (acc: T.avl range unit range_cmp) + : Lemma (ensures T.is_avl range_cmp (list_to_tree_fwd l acc)) + (decreases l) = + match l with + | [] -> () + | hd :: tl -> + T.insert_avl_proof range_cmp acc (fst hd) (snd hd); + list_to_tree_fwd_avl tl (T.insert_avl range_cmp acc (fst hd) (snd hd)) + +let rec list_to_tree_fwd_snoc_gen (l: list entry) (x: entry) (acc: T.tree range unit) + : Lemma (ensures list_to_tree_fwd (List.Tot.append l [x]) acc == + T.insert_avl range_cmp (list_to_tree_fwd l acc) (fst x) (snd x)) + (decreases l) = + match l with + | [] -> () + | hd :: tl -> list_to_tree_fwd_snoc_gen tl x (T.insert_avl range_cmp acc (fst hd) (snd hd)) + +let list_to_tree_fwd_snoc (l: list entry) (x: entry) + : Lemma (ensures list_to_tree_fwd (List.Tot.append l [x]) T.Leaf == + T.insert_avl range_cmp (list_to_tree_fwd l T.Leaf) (fst x) (snd x)) = + list_to_tree_fwd_snoc_gen l x T.Leaf + +#pop-options + +(*** list_to_seq ↔ seq_to_spec bridge ***) + +let rec list_to_seq_spec_eq (l: list entry) + : Lemma (requires list_valid l) + (ensures seq_all_valid (list_to_seq l) /\ + seq_to_spec (list_to_seq l) == list_to_spec l) + (decreases l) = + match l with + | [] -> () + | hd :: tl -> + list_to_seq_spec_eq tl; + let s = list_to_seq (hd :: tl) in + assert (Seq.head s == hd); + assert (Seq.equal (Seq.tail s) (list_to_seq tl)) + +(*** Extract-rebuild bridge lemmas ***) + +/// Lemma 1: sorted_remove of head element gives tail +#push-options "--fuel 2 --ifuel 1 --z3rlimit 30" + +let sorted_remove_head (#k #v: Type) (cmp: T.cmp k) (s: Seq.seq (k & v)) + : Lemma (requires Seq.length s > 0) + (ensures Seq.equal (T.sorted_remove cmp (fst (Seq.head s)) s) (Seq.tail s)) = + let hd = Seq.head s in + assert (cmp (fst hd) (fst hd) == 0); + () + +#pop-options + +/// Lemma 2: delete_min removes minimum (leftmost) element from BST +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +let delete_min_is_tail_inorder (t: T.tree range unit) + : Lemma (requires T.is_bst range_cmp t /\ T.no_dup_tree range_cmp t /\ T.Node? t) + (ensures Seq.equal + (T.inorder (T.delete_avl range_cmp t (fst (T.tree_min t)))) + (Seq.tail (T.inorder t))) = + tree_min_head_inorder t; + T.delete_avl_inorder range_cmp t (fst (T.tree_min t)); + sorted_remove_head range_cmp (T.inorder t); + () + +#pop-options + +/// Lemma 3: list_valid from seq_valid +#push-options "--fuel 2 --ifuel 1 --z3rlimit 30" + +let rec list_valid_from_seq_valid (l: list entry) + : Lemma (requires seq_all_valid (list_to_seq l)) + (ensures list_valid l) + (decreases l) = + match l with + | [] -> () + | hd :: tl -> + // seq_all_valid (list_to_seq (hd :: tl)) + // ==> seq_all_valid (Seq.cons hd (list_to_seq tl)) + // ==> entry_valid (Seq.head (Seq.cons hd ...)) /\ seq_all_valid (Seq.tail (Seq.cons hd ...)) + // ==> entry_valid hd /\ seq_all_valid (list_to_seq tl) + assert (Seq.head (list_to_seq (hd :: tl)) == hd); + assert (Seq.equal (Seq.tail (list_to_seq (hd :: tl))) (list_to_seq tl)); + list_valid_from_seq_valid tl + +#pop-options + +/// Lemma 4: range_map_wf implies list_sorted (length >= 2 case) +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +let rec range_map_wf_list_sorted (l: list entry) + : Lemma (requires list_valid l /\ + Spec.range_map_wf (list_to_spec l) /\ + List.Tot.length l >= 2) + (ensures list_sorted l) + (decreases l) = + match l with + | [] -> () // Impossible: length >= 2 + | [_] -> () // Impossible: length >= 2 + | a :: b :: rest -> + // We need to show: range_cmp_fn a b < 0 /\ list_sorted (b :: rest) + + // From range_map_wf (list_to_spec (a :: b :: rest)): + // Seq.index (list_to_spec l) 0 is range_to_interval a + // Seq.index (list_to_spec l) 1 is range_to_interval b + let spec_seq = list_to_spec l in + assert (Seq.length spec_seq >= 2); + + let a_spec = range_to_interval (fst a) in + let b_spec = range_to_interval (fst b) in + + assert (Seq.index spec_seq 0 == a_spec); + assert (Seq.index spec_seq 1 == b_spec); + + // range_map_wf says: separated (index 0) (index 1) + // separated means: high a_spec < b_spec.low + assert (Spec.high a_spec < b_spec.Spec.low); + assert (Spec.high a_spec == a_spec.Spec.low + a_spec.Spec.count); + assert (a_spec.Spec.low == SZ.v (fst a).start); + assert (a_spec.Spec.count == SZ.v (fst a).len); + assert (b_spec.Spec.low == SZ.v (fst b).start); + + assert (SZ.v (fst a).start + SZ.v (fst a).len < SZ.v (fst b).start); + assert (SZ.v (fst a).len > 0); + assert (SZ.v (fst a).start < SZ.v (fst b).start); + + assert (range_cmp_fn (fst a) (fst b) == -1); + assert (range_cmp_fn (fst a) (fst b) < 0); + + // Now prove list_sorted (b :: rest) + match rest with + | [] -> () // list_sorted [b] is True by definition + | c :: _ -> + // Need to prove list_sorted (b :: c :: ...) + // This requires: Spec.range_map_wf (list_to_spec (b :: c :: ...)) + // Which follows from Spec.range_map_wf (list_to_spec (a :: b :: c :: ...)) + + // range_map_wf (list_to_spec l) implies range_map_wf (Seq.tail (list_to_spec l)) + assert (Seq.equal (Seq.tail spec_seq) (list_to_spec (b :: rest))); + + // Use induction + range_map_wf_list_sorted (b :: rest) + +#pop-options + +/// Lemma 5: range_map_wf implies list_sorted (all cases) +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +let range_map_wf_list_sorted_full (l: list entry) + : Lemma (requires list_valid l /\ Spec.range_map_wf (list_to_spec l)) + (ensures list_sorted l) = + match l with + | [] -> () // list_sorted [] is True + | [_] -> () // list_sorted [_] is True + | _ :: _ :: _ -> + // Length >= 2, use the main lemma + range_map_wf_list_sorted l + +#pop-options + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +/// Bridge: list_sorted implies T.sorted_strict on list_to_seq +let rec list_sorted_to_sorted_strict (l: list entry) + : Lemma (requires list_sorted l) + (ensures T.sorted_strict range_cmp (list_to_seq l)) + (decreases l) = + match l with + | [] -> () + | [_] -> () + | a :: b :: rest -> + list_sorted_to_sorted_strict (b :: rest); + assert (Seq.head (list_to_seq l) == a); + assert (Seq.index (list_to_seq l) 1 == b); + assert (Seq.equal (Seq.tail (list_to_seq l)) (list_to_seq (b :: rest))) + +#pop-options + +(*** Validity preservation lemmas ***) + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 30" + +/// M2: Removing an element from a valid seq preserves validity +let rec seq_all_valid_sorted_remove (cmp: T.cmp range) (k: range) + (s: Seq.seq entry) + : Lemma (requires seq_all_valid s) + (ensures seq_all_valid (T.sorted_remove cmp k s)) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else if cmp (fst (Seq.head s)) k = 0 then () + else ( + seq_all_valid_sorted_remove cmp k (Seq.tail s); + let result = T.sorted_remove cmp k s in + assert (Seq.length result > 0); + assert (Seq.head result == Seq.head s); + assert (Seq.equal (Seq.tail result) (T.sorted_remove cmp k (Seq.tail s))) + ) + +/// M3: list_valid implies seq_all_valid on list_to_seq +let rec list_valid_to_seq_all_valid (l: list entry) + : Lemma (requires list_valid l) + (ensures seq_all_valid (list_to_seq l)) + (decreases l) = + match l with + | [] -> () + | hd :: tl -> + list_valid_to_seq_all_valid tl; + assert (Seq.head (list_to_seq (hd :: tl)) == hd); + assert (Seq.equal (Seq.tail (list_to_seq (hd :: tl))) (list_to_seq tl)) + +#pop-options + +(*** Extraction/rebuild helper lemmas ***) + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 60" + +/// Helper: cons head (tail s) == s for non-empty sequences +let cons_head_tail (s: Seq.seq entry) + : Lemma (requires Seq.length s > 0) + (ensures Seq.equal (Seq.cons (Seq.head s) (Seq.tail s)) s) = () + +/// list_to_seq distributes over list append +let rec list_to_seq_append (l1 l2: list entry) + : Lemma (ensures Seq.equal (list_to_seq (List.Tot.append l1 l2)) + (Seq.append (list_to_seq l1) (list_to_seq l2))) + (decreases l1) = + match l1 with + | [] -> () + | hd :: tl -> list_to_seq_append tl l2 + +/// After extracting min and prepending, the append invariant is maintained +let extract_step_invariant + (acc_old: list entry) + (ft_cur: T.tree range unit) + (original_inorder: Seq.seq entry) + : Lemma (requires + T.is_bst range_cmp ft_cur /\ T.no_dup_tree range_cmp ft_cur /\ T.Node? ft_cur /\ + Seq.append (list_to_seq (List.Tot.rev acc_old)) (T.inorder ft_cur) == original_inorder) + (ensures ( + let min = T.tree_min ft_cur in + let ft_new = T.delete_avl range_cmp ft_cur (fst min) in + Seq.append (list_to_seq (List.Tot.rev (min :: acc_old))) (T.inorder ft_new) == original_inorder)) + = let min = T.tree_min ft_cur in + let ft_new = T.delete_avl range_cmp ft_cur (fst min) in + delete_min_is_tail_inorder ft_cur; + tree_min_head_inorder ft_cur; + List.Tot.Properties.rev_append [min] acc_old; + assert (List.Tot.rev (min :: acc_old) == List.Tot.append (List.Tot.rev acc_old) [min]); + list_to_seq_append (List.Tot.rev acc_old) [min]; + assert (Seq.equal (list_to_seq [min]) (Seq.create 1 min)); + Seq.append_assoc (list_to_seq (List.Tot.rev acc_old)) (Seq.create 1 min) (T.inorder ft_new); + cons_head_tail (T.inorder ft_cur); + assert (Seq.equal (Seq.append (Seq.create 1 min) (T.inorder ft_new)) (T.inorder ft_cur)) + +#pop-options + +(*** Range set type and predicate ***) + +let range_map_t = B.box (AVL.tree_t range unit) + +/// Tracks whether the extraction loop should continue. +/// When b = false, the tree must be empty (Leaf). +[@@no_mkeys] +let extract_cont (b: bool) (ft: T.tree range unit) : slprop = + pure (b == not (T.is_empty ft)) + +let is_range_map (rs: range_map_t) (repr: Seq.seq Spec.interval) : slprop = + exists* (ct: AVL.tree_t range unit) (t: T.tree range unit). + B.pts_to rs ct ** + AVL.is_tree ct t ** + pure (T.is_bst range_cmp t /\ + T.no_dup_tree range_cmp t /\ + tree_wf t /\ + tree_to_spec t == repr /\ + Spec.range_map_wf repr) + +(*** Pulse operations ***) + +fn range_map_create () + requires emp + returns rs: range_map_t + ensures is_range_map rs (Seq.empty #Spec.interval) +{ + let ct = AVL.create range unit; + let rs = B.alloc ct; + fold (is_range_map rs (Seq.empty #Spec.interval)); + rs +} + +fn range_map_free (rs: range_map_t) (#repr: erased (Seq.seq Spec.interval)) + requires is_range_map rs repr + ensures emp +{ + unfold is_range_map; + with ct t. _; + let h = B.op_Bang rs; + AVL.free h; + B.free rs +} + +/// Get contiguous coverage length from a given base offset +fn range_map_contiguous_from (rs: range_map_t) (base: SZ.t) (#repr: erased (Seq.seq Spec.interval)) + requires is_range_map rs repr + returns n: SZ.t + ensures is_range_map rs repr ** pure (SZ.v n == Spec.contiguous_from repr (SZ.v base)) +{ + unfold is_range_map; + with ct t. _; + let h = B.op_Bang rs; + let b = AVL.is_empty h; + if b { + fold (is_range_map rs repr); + 0sz + } else { + let min_pair = AVL.find_min range_cmp h; + tree_min_head_inorder t; + seq_to_spec_head (T.inorder t); + let r = fst min_pair; + let r_high = SZ.add r.start r.len; + if (SZ.lte r.start base && SZ.lt base r_high) { + fold (is_range_map rs repr); + SZ.sub r_high base + } else { + fold (is_range_map rs repr); + 0sz + } + } +} + +fn range_map_max_endpoint (rs: range_map_t) (#repr: erased (Seq.seq Spec.interval)) + requires is_range_map rs repr + returns n: SZ.t + ensures is_range_map rs repr ** pure (SZ.v n == Spec.range_map_max_endpoint repr) +{ + unfold is_range_map; + with ct t. _; + let h = B.op_Bang rs; + let b = AVL.is_empty h; + if b { + fold (is_range_map rs repr); + 0sz + } else { + let max_pair = AVL.find_max range_cmp h; + tree_max_last_spec t; + let r = fst max_pair; + let result = SZ.add r.start r.len; + fold (is_range_map rs repr); + result + } +} + +#push-options "--z3rlimit 60 --fuel 2 --ifuel 1" + +fn range_map_add (rs: range_map_t) (offset: SZ.t) (len: SZ.t{SZ.v len > 0}) + (#repr: erased (Seq.seq Spec.interval)) + requires is_range_map rs repr ** pure (SZ.fits (SZ.v offset + SZ.v len)) + ensures is_range_map rs (Spec.add_range repr (SZ.v offset) (SZ.v len)) +{ + unfold is_range_map; + with ct t. _; + + let h = B.op_Bang rs; + + let mut acc: list entry = []; + let mut tree = h; + let initial_empty = AVL.is_empty h; + + assert (pure (Seq.equal (Seq.append (list_to_seq (List.Tot.rev ([] #entry))) (T.inorder (G.reveal t))) (T.inorder (G.reveal t)))); + + fold (extract_cont (not initial_empty) (G.reveal t)); + + while ( + let tree_ref = !tree; + let acc_ref = !acc; + with b_old ft_cur. unfold (extract_cont b_old ft_cur); + let is_emp = AVL.is_empty tree_ref; + + if (not is_emp) { + let min = AVL.find_min range_cmp tree_ref; + tree_min_head_inorder ft_cur; + + let tree' = AVL.delete_avl range_cmp tree_ref (fst min); + T.delete_avl_preserves_bst range_cmp ft_cur (fst min); + T.delete_avl_preserves_no_dup_tree range_cmp ft_cur (fst min); + T.delete_avl_inorder range_cmp ft_cur (fst min); + seq_all_valid_sorted_remove range_cmp (fst min) (T.inorder ft_cur); + delete_min_is_tail_inorder ft_cur; + + acc := min :: acc_ref; + extract_step_invariant acc_ref ft_cur (T.inorder (G.reveal t)); + + tree := tree'; + + let e = AVL.is_empty tree'; + fold (extract_cont (not e) (T.delete_avl range_cmp ft_cur (fst (T.tree_min ft_cur)))); + not e + } else { + fold (extract_cont false ft_cur); + false + } + ) + invariant b. exists* tree_val acc_val ft_cur. + R.pts_to tree tree_val ** + R.pts_to acc acc_val ** + AVL.is_tree tree_val ft_cur ** + extract_cont b ft_cur ** + pure (T.is_bst range_cmp ft_cur /\ + T.no_dup_tree range_cmp ft_cur /\ + tree_wf ft_cur /\ + Seq.append (list_to_seq (List.Tot.rev acc_val)) (T.inorder ft_cur) == T.inorder (G.reveal t)) + { () }; + + with tree_val acc_val ft_cur. _; + unfold (extract_cont false ft_cur); + + // false == not (T.is_empty ft_cur) → T.is_empty ft_cur → ft_cur == Leaf + Seq.append_empty_r (list_to_seq (List.Tot.rev acc_val)); + let tree_final = !tree; + AVL.free tree_final; + + let acc_final = !acc; + let extracted = List.Tot.rev acc_final; + + assert (pure (list_to_seq extracted == T.inorder (G.reveal t))); + list_valid_from_seq_valid extracted; + list_to_seq_spec_eq extracted; + + Spec.add_range_wf repr (SZ.v offset) (SZ.v len); + let transformed = add_range_impl extracted offset len; + + range_map_wf_list_sorted_full transformed; + list_valid_to_seq_all_valid transformed; + list_to_tree_fwd_correct transformed; + list_to_tree_fwd_avl transformed T.Leaf; + + let mut new_tree = AVL.create range unit; + let mut remaining = transformed; + let mut processed_add: list entry = []; + + while ( + let r = !remaining; + Cons? r + ) + invariant exists* new_tree_val remaining_val ft_new proc_val. + R.pts_to new_tree new_tree_val ** + R.pts_to remaining remaining_val ** + R.pts_to processed_add proc_val ** + AVL.is_tree new_tree_val ft_new ** + pure (ft_new == list_to_tree_fwd proc_val T.Leaf /\ + T.is_avl range_cmp ft_new /\ + List.Tot.append proc_val remaining_val == transformed) + { + with new_tree_val remaining_val ft_new proc_val. _; + + let new_tree_curr = !new_tree; + let remaining_curr = !remaining; + let proc_curr = !processed_add; + + let Cons hd tl = remaining_curr; + + let new_tree' = AVL.insert_avl range_cmp new_tree_curr (fst hd) (snd hd); + T.insert_avl_proof range_cmp ft_new (fst hd) (snd hd); + + list_to_tree_fwd_snoc proc_curr hd; + List.Tot.Properties.append_assoc proc_curr [hd] tl; + + remaining := tl; + new_tree := new_tree'; + processed_add := List.Tot.append proc_curr [hd] + }; + + with new_tree_val remaining_val ft_new proc_val. _; + + assert (pure (List.Tot.append proc_val [] == transformed)); + List.Tot.Properties.append_l_nil proc_val; + + let final_tree = !new_tree; + B.op_Colon_Equals rs final_tree; + + // is_bst from is_avl + list_to_tree_fwd_correct transformed; + list_to_tree_fwd_avl transformed T.Leaf; + + // no_dup_tree from sorted_strict + bst + list_sorted_to_sorted_strict transformed; + T.bst_strict_sorted_no_dup range_cmp (list_to_tree_fwd transformed T.Leaf); + + // tree_wf: seq_all_valid (T.inorder ft_new) + list_valid_to_seq_all_valid transformed; + + // tree_to_spec: chain through list_to_spec + list_to_seq_spec_eq transformed; + + fold (is_range_map rs (Spec.add_range repr (SZ.v offset) (SZ.v len))) +} + +#pop-options diff --git a/lib/pulse/lib/Pulse.Lib.RangeVec.fst b/lib/pulse/lib/Pulse.Lib.RangeVec.fst new file mode 100644 index 000000000..0115fc7d9 --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.RangeVec.fst @@ -0,0 +1,1021 @@ +module Pulse.Lib.RangeVec + +/// Range tracker backed by a sorted vector of non-overlapping intervals. +/// Drop-in replacement for RangeMap (AVL-based) with better cache locality +/// and clean KaRaMeL extraction (no .fsti -> no monomorphization bug). + +#lang-pulse + +open Pulse.Lib.Pervasives + +module SZ = FStar.SizeT +module Seq = FStar.Seq +module Spec = Pulse.Lib.RangeMap.Spec +module V = Pulse.Lib.Vector +module B = Pulse.Lib.Box +module G = FStar.Ghost +module R = Pulse.Lib.Reference + +(*** Platform and bounds ***) + +/// 64-bit platform assumption — standard for Pulse/SizeT code +assume val platform_is_64bit : squash SZ.fits_u64 + +/// Upper bound on range vector entries. +/// Strictly greater than the maximum number of separated intervals that can +/// fit in a CircularBuffer with alloc_length ≤ pow2_63 (which is ≤ pow2_62). +/// The bound ensures vector capacity doubling is always representable. +assume val max_range_vec_entries : n:pos{n == pow2 62 + 1} + +(*** Types ***) + +noeq type range = { start: SZ.t; len: SZ.t } + +noextract +let range_valid (r: range) : prop = + SZ.v r.len > 0 /\ + SZ.fits (SZ.v r.start + SZ.v r.len) + +noextract +let range_to_interval (r: range) + : Pure Spec.interval (requires range_valid r) (ensures fun _ -> True) = + { Spec.low = SZ.v r.start; Spec.count = SZ.v r.len } + +let default_range : range = { start = 0sz; len = 1sz } + +noextract +let rec seq_all_valid (s: Seq.seq range) + : Tot prop (decreases Seq.length s) = + if Seq.length s = 0 then True + else range_valid (Seq.head s) /\ seq_all_valid (Seq.tail s) + +noextract +let rec seq_all_valid_index (s: Seq.seq range) (i: nat) + : Lemma (requires seq_all_valid s /\ i < Seq.length s) + (ensures range_valid (Seq.index s i)) + (decreases Seq.length s) = + if i = 0 then () + else seq_all_valid_index (Seq.tail s) (i - 1) + +noextract +let rec seq_to_spec (s: Seq.seq range) + : Pure (Seq.seq Spec.interval) + (requires seq_all_valid s) + (ensures fun r -> Seq.length r == Seq.length s) + (decreases Seq.length s) = + if Seq.length s = 0 then Seq.empty + else Seq.cons (range_to_interval (Seq.head s)) (seq_to_spec (Seq.tail s)) + +#push-options "--fuel 2 --ifuel 1" + +noextract +let rec seq_to_spec_index (s: Seq.seq range) (i: nat) + : Lemma (requires seq_all_valid s /\ i < Seq.length s) + (ensures range_valid (Seq.index s i) /\ + Seq.index (seq_to_spec s) i == range_to_interval (Seq.index s i)) + (decreases Seq.length s) = + seq_all_valid_index s i; + if i = 0 then () + else seq_to_spec_index (Seq.tail s) (i - 1) + +noextract +let rec seq_all_valid_forall (s: Seq.seq range) + : Lemma (requires seq_all_valid s) + (ensures forall (k:nat). k < Seq.length s ==> range_valid (Seq.index s k)) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else begin + seq_all_valid_forall (Seq.tail s); + let aux (k:nat{k < Seq.length s}) + : Lemma (range_valid (Seq.index s k)) = + seq_all_valid_index s k + in + Classical.forall_intro aux + end + +#pop-options + +#push-options "--fuel 3 --ifuel 2 --z3rlimit 20" + +(* Helper lemma: seq_all_valid for snoc *) +noextract +let rec seq_all_valid_snoc (s: Seq.seq range) (r: range) + : Lemma (requires seq_all_valid s /\ range_valid r) + (ensures seq_all_valid (Seq.snoc s r)) + (decreases Seq.length s) = + if Seq.length s = 0 then () + else begin + seq_all_valid_snoc (Seq.tail s) r; + // Help SMT understand the structure + assert (Seq.length (Seq.snoc s r) > 0); + assert (Seq.head (Seq.snoc s r) == Seq.head s); + assert (Seq.tail (Seq.snoc s r) `Seq.equal` Seq.snoc (Seq.tail s) r) + end + +(* Lemma 1: seq_to_spec commutes with snoc *) +noextract +let rec seq_to_spec_snoc (s: Seq.seq range) (r: range) + : Lemma (requires seq_all_valid s /\ range_valid r) + (ensures seq_all_valid (Seq.snoc s r) /\ + seq_to_spec (Seq.snoc s r) == Seq.snoc (seq_to_spec s) (range_to_interval r)) + (decreases Seq.length s) = + seq_all_valid_snoc s r; + if Seq.length s = 0 then begin + Seq.lemma_eq_intro (Seq.snoc s r) (Seq.create 1 r); + Seq.lemma_eq_intro (seq_to_spec (Seq.snoc s r)) (Seq.snoc (seq_to_spec s) (range_to_interval r)) + end else begin + seq_to_spec_snoc (Seq.tail s) r; + let s' = Seq.snoc s r in + assert (Seq.head s' == Seq.head s); + assert (Seq.tail s' `Seq.equal` Seq.snoc (Seq.tail s) r); + let a = range_to_interval (Seq.head s) in + let b = seq_to_spec (Seq.tail s) in + let c = range_to_interval r in + Seq.lemma_eq_intro (Seq.cons a (Seq.snoc b c)) (Seq.snoc (Seq.cons a b) c) + end + +(* Helper lemma: seq_all_valid for append *) +noextract +let rec seq_all_valid_append (s1 s2: Seq.seq range) + : Lemma (requires seq_all_valid s1 /\ seq_all_valid s2) + (ensures seq_all_valid (Seq.append s1 s2)) + (decreases Seq.length s1) = + if Seq.length s1 = 0 then + Seq.lemma_eq_intro (Seq.append s1 s2) s2 + else begin + seq_all_valid_append (Seq.tail s1) s2; + let s' = Seq.append s1 s2 in + assert (Seq.length s' > 0); + Seq.lemma_eq_intro (Seq.tail s') (Seq.append (Seq.tail s1) s2); + assert (Seq.head s' == Seq.head s1) + end + +(* Lemma 2: seq_to_spec commutes with append *) +noextract +let rec seq_to_spec_append (s1 s2: Seq.seq range) + : Lemma (requires seq_all_valid s1 /\ seq_all_valid s2) + (ensures seq_all_valid (Seq.append s1 s2) /\ + seq_to_spec (Seq.append s1 s2) == Seq.append (seq_to_spec s1) (seq_to_spec s2)) + (decreases Seq.length s1) = + seq_all_valid_append s1 s2; + if Seq.length s1 = 0 then begin + Seq.lemma_eq_intro (Seq.append s1 s2) s2; + Seq.lemma_eq_intro (seq_to_spec (Seq.append s1 s2)) (Seq.append (seq_to_spec s1) (seq_to_spec s2)) + end else begin + seq_to_spec_append (Seq.tail s1) s2; + let s' = Seq.append s1 s2 in + assert (Seq.head s' == Seq.head s1); + Seq.lemma_eq_intro (Seq.tail s') (Seq.append (Seq.tail s1) s2); + // cons a (append b c) == append (cons a b) c + let a = range_to_interval (Seq.head s1) in + let b = seq_to_spec (Seq.tail s1) in + let c = seq_to_spec s2 in + Seq.lemma_eq_intro (Seq.cons a (Seq.append b c)) (Seq.append (Seq.cons a b) c) + end + +(* Lemma 3: seq_all_valid preserves through slice *) +noextract +let rec seq_all_valid_slice (s: Seq.seq range) (i j: nat) + : Lemma (requires seq_all_valid s /\ i <= j /\ j <= Seq.length s) + (ensures seq_all_valid (Seq.slice s i j)) + (decreases Seq.length s) = + if i >= j then () + else if i = 0 then begin + if j = 0 then () + else if j = Seq.length s then () + else seq_all_valid_slice (Seq.tail s) 0 (j - 1) + end + else seq_all_valid_slice (Seq.tail s) (i - 1) (j - 1) + +(* Lemma 4: seq_to_spec commutes with slice *) +noextract +let seq_to_spec_slice (s: Seq.seq range) (i j: nat) + : Lemma (requires seq_all_valid s /\ i <= j /\ j <= Seq.length s) + (ensures seq_all_valid (Seq.slice s i j) /\ + seq_to_spec (Seq.slice s i j) == Seq.slice (seq_to_spec s) i j) = + seq_all_valid_slice s i j; + let sliced_range = Seq.slice s i j in + let sliced_spec = seq_to_spec sliced_range in + let spec_sliced = Seq.slice (seq_to_spec s) i j in + let aux (k: nat{k < Seq.length sliced_spec}) + : Lemma (Seq.index sliced_spec k == Seq.index spec_sliced k) = + seq_to_spec_index sliced_range k; + seq_all_valid_index s (i + k); + seq_to_spec_index s (i + k) + in + Classical.forall_intro aux; + Seq.lemma_eq_intro sliced_spec spec_sliced + +#pop-options + +(*** Predicate ***) + +let range_vec_t = V.vector range + +let is_range_vec (rv: range_vec_t) (repr: Seq.seq Spec.interval) : slprop = + exists* (s: Seq.seq range) (cap: SZ.t). + V.is_vector rv s cap ** + pure (seq_all_valid s /\ + seq_to_spec s == repr /\ + Spec.range_map_wf repr /\ + Seq.length s <= max_range_vec_entries /\ + (Seq.length s < SZ.v cap \/ SZ.fits (SZ.v cap + SZ.v cap))) + +(*** Create / Free ***) + +fn range_vec_create () + requires emp + returns rv: range_vec_t + ensures is_range_vec rv (Seq.empty #Spec.interval) +{ + let rv = V.create default_range 1sz; + let _ = V.pop_back rv; + with cap'. _; + fold (is_range_vec rv (Seq.empty #Spec.interval)); + rv +} + +fn range_vec_free (rv: range_vec_t) (#repr: erased (Seq.seq Spec.interval)) + requires is_range_vec rv repr + ensures emp +{ + unfold is_range_vec; + with s cap. _; + V.free rv +} + +(*** Queries ***) + +fn range_vec_contiguous_from (rv: range_vec_t) (base: SZ.t) (#repr: erased (Seq.seq Spec.interval)) + requires is_range_vec rv repr + returns n: SZ.t + ensures is_range_vec rv repr ** pure (SZ.v n == Spec.contiguous_from repr (SZ.v base)) +{ + unfold is_range_vec; + with s cap. _; + let sz = V.size rv; + if (SZ.eq sz 0sz) { + fold (is_range_vec rv repr); + 0sz + } else { + let first = V.at rv 0sz; + seq_to_spec_index s 0; + let r_high = SZ.add first.start first.len; + if (SZ.lte first.start base && SZ.lt base r_high) { + fold (is_range_vec rv repr); + SZ.sub r_high base + } else { + fold (is_range_vec rv repr); + 0sz + } + } +} + +fn range_vec_max_endpoint (rv: range_vec_t) (#repr: erased (Seq.seq Spec.interval)) + requires is_range_vec rv repr + returns n: SZ.t + ensures is_range_vec rv repr ** pure (SZ.v n == Spec.range_map_max_endpoint repr) +{ + unfold is_range_vec; + with s cap. _; + let sz = V.size rv; + if (SZ.eq sz 0sz) { + fold (is_range_vec rv repr); + 0sz + } else { + let last_idx = SZ.sub sz 1sz; + let last = V.at rv last_idx; + seq_to_spec_index s (SZ.v last_idx); + let result = SZ.add last.start last.len; + fold (is_range_vec rv repr); + result + } +} + + +(*** Add range — core operation ***) + +noextract +let seq_insert (#a:Type) (s: Seq.seq a) (i: nat) (x: a) : Seq.seq a = + if i <= Seq.length s then + Seq.append (Seq.slice s 0 i) (Seq.cons x (Seq.slice s i (Seq.length s))) + else s + +noextract +let seq_remove (#a:Type) (s: Seq.seq a) (i: nat) (count: nat) : Seq.seq a = + if i + count <= Seq.length s then + Seq.append (Seq.slice s 0 i) (Seq.slice s (i + count) (Seq.length s)) + else s + +(* Bridge: pointwise shift result matches seq_remove *) +noextract +let shift_to_seq_remove (#a:Type) (s s_cur: Seq.seq a) (i count: nat) + : Lemma (requires i + count <= Seq.length s /\ count > 0 /\ + Seq.length s_cur == Seq.length s /\ + (forall (k:nat). k < i ==> Seq.index s_cur k == Seq.index s k) /\ + (forall (k:nat). k >= i /\ k < Seq.length s - count ==> + Seq.index s_cur k == Seq.index s (k + count)) /\ + (forall (k:nat). k >= Seq.length s - count /\ k < Seq.length s ==> + Seq.index s_cur k == Seq.index s k)) + (ensures Seq.slice s_cur 0 (Seq.length s - count) == seq_remove s i count) = + let dst_end = Seq.length s - count in + let candidate = Seq.slice s_cur 0 dst_end in + let target = seq_remove s i count in + Seq.lemma_eq_intro candidate target + +(* Bridge: pointwise shift-right result matches seq_insert *) +noextract +let shift_to_seq_insert (#a:Type) (s s_cur: Seq.seq a) (i: nat) (r: a) + : Lemma (requires i < Seq.length s /\ + Seq.length s_cur == Seq.length s + 1 /\ + (forall (m:nat). m < i ==> Seq.index s_cur m == Seq.index s m) /\ + Seq.index s_cur i == r /\ + (forall (m:nat). m > i /\ m < Seq.length s_cur ==> + Seq.index s_cur m == Seq.index s (m - 1))) + (ensures s_cur == seq_insert s i r) = + Seq.lemma_eq_intro s_cur (seq_insert s i r) + +(* Bridge: shift-right + set produces seq_insert. Call BEFORE V.set. *) +noextract +let shift_set_is_seq_insert (#a:Type) (s s_shifted: Seq.seq a) (i: nat) (r: a) + : Lemma (requires i < Seq.length s /\ + Seq.length s_shifted == Seq.length s + 1 /\ + (forall (m:nat). m < i ==> Seq.index s_shifted m == Seq.index s m) /\ + (forall (m:nat). m > i /\ m < Seq.length s_shifted ==> + Seq.index s_shifted m == Seq.index s (m - 1))) + (ensures Seq.upd s_shifted i r == seq_insert s i r) = + Seq.lemma_eq_intro (Seq.upd s_shifted i r) (seq_insert s i r) + +(* Bridge: snoc is seq_insert at end *) +noextract +let snoc_is_seq_insert (#a:Type) (s: Seq.seq a) (r: a) (i: nat) + : Lemma (requires i >= Seq.length s /\ i <= Seq.length s) + (ensures Seq.snoc s r == seq_insert s i r) = + Seq.lemma_eq_intro (Seq.snoc s r) (seq_insert s i r) + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 20" + +(* seq_all_valid of seq_insert *) +noextract +let seq_all_valid_insert (s: Seq.seq range) (i: nat) (r: range) + : Lemma (requires seq_all_valid s /\ range_valid r /\ i <= Seq.length s) + (ensures seq_all_valid (seq_insert s i r)) = + seq_all_valid_slice s 0 i; + seq_all_valid_slice s i (Seq.length s); + let suffix = Seq.slice s i (Seq.length s) in + let cons_r = Seq.cons r suffix in + Seq.lemma_eq_intro (Seq.tail cons_r) suffix; + seq_all_valid_append (Seq.slice s 0 i) cons_r + +(* seq_to_spec of seq_insert — relates to Seq operations on spec level *) +noextract +let seq_to_spec_insert (s: Seq.seq range) (i: nat) (r: range) + : Lemma (requires seq_all_valid s /\ range_valid r /\ i <= Seq.length s) + (ensures seq_all_valid (seq_insert s i r) /\ + seq_to_spec (seq_insert s i r) == + Seq.append (Seq.slice (seq_to_spec s) 0 i) + (Seq.cons (range_to_interval r) + (Seq.slice (seq_to_spec s) i (Seq.length s)))) = + seq_all_valid_insert s i r; + seq_all_valid_slice s 0 i; + seq_all_valid_slice s i (Seq.length s); + seq_to_spec_slice s 0 i; + seq_to_spec_slice s i (Seq.length s); + let prefix = Seq.slice s 0 i in + let suffix = Seq.slice s i (Seq.length s) in + let cons_r = Seq.cons r suffix in + Seq.lemma_eq_intro (Seq.tail cons_r) suffix; + seq_to_spec_append prefix cons_r + +(* Bridge: capacity condition after insert. + From original |s|= 2 then + // sz < cap → sz+1 <= cap < 2*cap = cap' ✓ + // sz == cap (via fits(cap+cap)) → sz+1 = cap+1 < 2*cap for cap≥2 ✓ + assert (sz + 1 < cap + cap) + else + SZ.fits_at_least_16 4 // cap == 1: SZ.fits(4) ✓ + end else begin + // No resize: cap' == cap. + if sz + 1 < cap then () + else begin + // sz + 1 == cap. cap = sz + 1 <= max_range_vec_entries <= pow2 62. + // cap + cap <= 2 * max_range_vec_entries <= 2 * pow2 62 = pow2 63 < pow2 64. + // With fits_u64: SZ.fits(cap + cap). + SZ.fits_u64_implies_fits (cap + cap) + end + end + +(* Forall highs-below-offset lifts from ranges to spec *) +noextract +let forall_high_below_to_spec (s: Seq.seq range) (n: nat) (bound: nat) + : Lemma (requires seq_all_valid s /\ n <= Seq.length s /\ + (forall (k:nat). k < n ==> SZ.v (Seq.index s k).start + SZ.v (Seq.index s k).len < bound)) + (ensures (forall (k:nat). k < n ==> Spec.high (Seq.index (seq_to_spec s) k) < bound)) = + let aux (k: nat{k < n}) + : Lemma (Spec.high (Seq.index (seq_to_spec s) k) < bound) = + seq_to_spec_index s k; + seq_all_valid_index s k + in + Classical.forall_intro aux + +#pop-options + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 30" + +(* Overlap forall lifts from ranges to spec *) +noextract +let forall_overlap_to_spec (s: Seq.seq range) (iv j: nat) (mh: nat) + : Lemma (requires seq_all_valid s /\ iv <= j /\ j <= Seq.length s /\ + (forall (k:nat). k >= iv /\ k < j ==> + SZ.v (Seq.index s k).start <= mh)) + (ensures (forall (k:nat). k >= iv /\ k < j ==> + mh >= (Seq.index (seq_to_spec s) k).Spec.low)) = + let aux (k: nat{k >= iv /\ k < j}) + : Lemma (mh >= (Seq.index (seq_to_spec s) k).Spec.low) = + seq_to_spec_index s k; + seq_all_valid_index s k + in + Classical.forall_intro aux + +#pop-options + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 20" + +(* seq_all_valid of seq_remove *) +noextract +let seq_all_valid_remove (s: Seq.seq range) (i count: nat) + : Lemma (requires seq_all_valid s /\ i + count <= Seq.length s) + (ensures seq_all_valid (seq_remove s i count)) = + seq_all_valid_slice s 0 i; + seq_all_valid_slice s (i + count) (Seq.length s); + seq_all_valid_append (Seq.slice s 0 i) (Seq.slice s (i + count) (Seq.length s)) + +(* seq_to_spec of seq_remove *) +noextract +let seq_to_spec_remove (s: Seq.seq range) (i count: nat) + : Lemma (requires seq_all_valid s /\ i + count <= Seq.length s) + (ensures seq_all_valid (seq_remove s i count) /\ + seq_to_spec (seq_remove s i count) == + Seq.append (Seq.slice (seq_to_spec s) 0 i) + (Seq.slice (seq_to_spec s) (i + count) (Seq.length s))) = + seq_all_valid_remove s i count; + seq_all_valid_slice s 0 i; + seq_all_valid_slice s (i + count) (Seq.length s); + seq_to_spec_slice s 0 i; + seq_to_spec_slice s (i + count) (Seq.length s); + seq_to_spec_append (Seq.slice s 0 i) (Seq.slice s (i + count) (Seq.length s)) + +(* Key bridge: Seq.upd at iv followed by seq_remove of [iv+1..j) gives merge result *) +noextract +let seq_upd_remove_spec (s: Seq.seq range) (iv j: nat) (r: range) + : Lemma (requires seq_all_valid s /\ iv < Seq.length s /\ j > iv /\ j <= Seq.length s /\ range_valid r) + (ensures (let result = + (if j > iv + 1 + then seq_remove (Seq.upd s iv r) (iv + 1) (j - iv - 1) + else Seq.upd s iv r) in + seq_all_valid result /\ + seq_to_spec result == + Seq.append (Seq.slice (seq_to_spec s) 0 iv) + (Seq.append (Seq.create 1 (range_to_interval r)) + (Seq.slice (seq_to_spec s) j (Seq.length s))))) = + let n = Seq.length s in + let s' = Seq.upd s iv r in + // upd preserves validity + let upd_valid () : Lemma (seq_all_valid s') = + let prefix = Seq.slice s 0 iv in + let suffix = Seq.slice s (iv + 1) n in + seq_all_valid_slice s 0 iv; + seq_all_valid_slice s (iv + 1) n; + Seq.lemma_eq_intro s' (Seq.append prefix (Seq.append (Seq.create 1 r) suffix)); + Seq.lemma_eq_intro (Seq.tail (Seq.cons r suffix)) suffix; + seq_all_valid_append (Seq.create 1 r) suffix; + Seq.lemma_eq_intro (Seq.append (Seq.create 1 r) suffix) (Seq.cons r suffix); + seq_all_valid_append prefix (Seq.cons r suffix) + in + upd_valid (); + if j > iv + 1 then begin + let count = j - iv - 1 in + // seq_remove s' (iv+1) count == append (slice s' 0 (iv+1)) (slice s' j n) + // slice s' 0 (iv+1) == append (slice s 0 iv) (create 1 r) + Seq.lemma_eq_intro (Seq.slice s' 0 (iv + 1)) + (Seq.append (Seq.slice s 0 iv) (Seq.create 1 r)); + // slice s' j n == slice s j n + Seq.lemma_eq_intro (Seq.slice s' j n) (Seq.slice s j n); + // So: seq_remove s' (iv+1) count == append (slice s 0 iv) (append (create 1 r) (slice s j n)) + let result = seq_remove s' (iv + 1) count in + Seq.lemma_eq_intro result + (Seq.append (Seq.slice s 0 iv) + (Seq.append (Seq.create 1 r) (Seq.slice s j n))); + // validity of result + seq_all_valid_slice s 0 iv; + seq_all_valid_slice s j n; + Seq.lemma_eq_intro (Seq.tail (Seq.cons r (Seq.slice s j n))) (Seq.slice s j n); + seq_all_valid_append (Seq.create 1 r) (Seq.slice s j n); + Seq.lemma_eq_intro (Seq.append (Seq.create 1 r) (Seq.slice s j n)) (Seq.cons r (Seq.slice s j n)); + seq_all_valid_append (Seq.slice s 0 iv) (Seq.cons r (Seq.slice s j n)); + // seq_to_spec of result + seq_to_spec_slice s 0 iv; + seq_to_spec_slice s j n; + seq_to_spec_append (Seq.create 1 r) (Seq.slice s j n); + seq_to_spec_append (Seq.slice s 0 iv) (Seq.cons r (Seq.slice s j n)) + end else begin + // j == iv + 1, no removal needed + Seq.lemma_eq_intro s' (Seq.append (Seq.slice s 0 iv) (Seq.append (Seq.create 1 r) (Seq.slice s j n))); + seq_all_valid_slice s 0 iv; + seq_all_valid_slice s j n; + Seq.lemma_eq_intro (Seq.tail (Seq.cons r (Seq.slice s j n))) (Seq.slice s j n); + seq_all_valid_append (Seq.create 1 r) (Seq.slice s j n); + Seq.lemma_eq_intro (Seq.append (Seq.create 1 r) (Seq.slice s j n)) (Seq.cons r (Seq.slice s j n)); + seq_all_valid_append (Seq.slice s 0 iv) (Seq.cons r (Seq.slice s j n)); + seq_to_spec_slice s 0 iv; + seq_to_spec_slice s j n; + seq_to_spec_append (Seq.create 1 r) (Seq.slice s j n); + seq_to_spec_append (Seq.slice s 0 iv) (Seq.cons r (Seq.slice s j n)) + end + +(* Bridge: lift exit condition from range-level to spec-level with mh0 *) +noextract +let exit_condition_to_spec (s: Seq.seq range) (repr: Seq.seq Spec.interval) (jv: nat) + (mh0_val final_high_val: nat) + : Lemma (requires seq_all_valid s /\ repr == seq_to_spec s /\ jv <= Seq.length s /\ + final_high_val >= mh0_val /\ + (jv == Seq.length s \/ final_high_val < SZ.v (Seq.index s jv).start)) + (ensures jv == Seq.length repr \/ mh0_val < (Seq.index repr jv).Spec.low) + = if jv < Seq.length s then seq_to_spec_index s jv + else () + +(* Bridge lemma for merge loop body: packages merge_absorbed_high step + mh0 coverage *) +noextract +let merge_loop_body_step (s: Seq.seq range) (iv jv: nat) (mh_val mh0_val: nat) + : Lemma (requires + iv + 1 <= Seq.length s /\ jv > iv /\ jv < Seq.length s /\ + seq_all_valid s /\ + range_valid (Seq.index s jv) /\ + mh_val >= SZ.v (Seq.index s jv).start /\ + (let suffix_tail = Seq.slice (seq_to_spec s) (iv + 1) (Seq.length s) in + let k = jv - iv - 1 in + k < Seq.length suffix_tail /\ + mh_val == Spec.merge_absorbed_high suffix_tail mh0_val k /\ + Spec.range_map_wf suffix_tail)) + (ensures + (let suffix_tail = Seq.slice (seq_to_spec s) (iv + 1) (Seq.length s) in + let r_high_val = SZ.v (Seq.index s jv).start + SZ.v (Seq.index s jv).len in + let new_mh = (if r_high_val > mh_val then r_high_val else mh_val) in + // 1. merge_absorbed_high step + new_mh == Spec.merge_absorbed_high suffix_tail mh0_val (jv - iv) /\ + // 2. mh0 covers all absorbed + mh0_val >= SZ.v (Seq.index s jv).start)) + = let k = jv - iv - 1 in + let suffix_tail = Seq.slice (seq_to_spec s) (iv + 1) (Seq.length s) in + // Connect suffix_tail indexing to original seq + seq_to_spec_index s jv; + seq_all_valid_index s jv; + assert (Seq.index suffix_tail k == range_to_interval (Seq.index s jv)); + // merge_absorbed_high unfold + Spec.merge_absorbed_high_unfold_right suffix_tail mh0_val k; + // mh0 covers absorbed + if k > 0 then + Spec.mh0_covers_absorbed suffix_tail mh0_val k + else () + +#pop-options + +#push-options "--z3rlimit 80 --fuel 2 --ifuel 1" + +/// Helper: shift elements [i..n) right by 1, set position i to r. +fn vec_insert_at (rv: range_vec_t) (i: SZ.t) (r: range) + (#s: erased (Seq.seq range)) (#cap: erased SZ.t) + requires V.is_vector rv s cap ** + pure (SZ.v i <= Seq.length s /\ + Seq.length s < max_range_vec_entries /\ + (Seq.length s < SZ.v cap \/ SZ.fits (SZ.v cap + SZ.v cap))) + ensures exists* (s': Seq.seq range) (cap': SZ.t). + V.is_vector rv s' cap' ** + pure (Seq.length s' == Seq.length s + 1 /\ s' == seq_insert s (SZ.v i) r /\ + (Seq.length s + 1 < SZ.v cap' \/ SZ.fits (SZ.v cap' + SZ.v cap'))) +{ + V.size_bounded rv; + V.push_back rv r; + with cap1. _; + insert_capacity_condition (Seq.length (G.reveal s)) (SZ.v (G.reveal cap)) (SZ.v cap1); + let sz = V.size rv; + if (SZ.gt sz 1sz && SZ.lt i (SZ.sub sz 1sz)) { + let mut j = SZ.sub sz 1sz; + let mut cont = true; + while (!cont) + invariant exists* jv cv s_cur cap_cur. + R.pts_to j jv ** R.pts_to cont cv ** + V.is_vector rv s_cur cap_cur ** + pure (SZ.v jv >= SZ.v i /\ SZ.v jv < Seq.length s_cur /\ + Seq.length s_cur == Seq.length (G.reveal s) + 1 /\ + cap_cur == cap1 /\ + // Prefix unchanged + (forall (m:nat). m < SZ.v jv ==> + Seq.index s_cur m == Seq.index (G.reveal s) m) /\ + // Shifted region + (forall (m:nat). m > SZ.v jv /\ m < Seq.length s_cur ==> + Seq.index s_cur m == Seq.index (G.reveal s) (m - 1)) /\ + // Exit + (not cv ==> SZ.v jv == SZ.v i)) + { + let jv = !j; + if (SZ.gt jv i) { + let prev = V.at rv (SZ.sub jv 1sz); + V.set rv jv prev; + let new_j = SZ.sub jv 1sz; + j := new_j; + if (SZ.eq new_j i) { cont := false } + } else { + cont := false + } + }; + // Bind loop existentials; call bridge lemma BEFORE V.set + with _jv2 _cv2 s_after_shift _cap_shift. _; + shift_set_is_seq_insert (G.reveal s) s_after_shift (SZ.v i) r; + V.set rv i r + } else { + // else: sz <= 1 or i >= sz - 1, so i >= |s| + assert (pure (SZ.v i >= Seq.length (G.reveal s))); + assert (pure (SZ.v i <= Seq.length (G.reveal s))); + snoc_is_seq_insert (G.reveal s) r (SZ.v i); + assert (pure (Seq.snoc (G.reveal s) r == seq_insert (G.reveal s) (SZ.v i) r)) + } +} + +#pop-options + +/// Helper: remove count elements starting at position i +fn vec_remove_range (rv: range_vec_t) (i: SZ.t) (count: SZ.t) + (#s: erased (Seq.seq range)) (#cap: erased SZ.t) + requires V.is_vector rv s cap ** + pure (SZ.v i + SZ.v count <= Seq.length s /\ SZ.v count > 0) + ensures exists* (s': Seq.seq range) (cap': SZ.t). + V.is_vector rv s' cap' ** + pure (s' == seq_remove s (SZ.v i) (SZ.v count) /\ + Seq.length s' + SZ.v count == Seq.length s /\ + (Seq.length s' < SZ.v cap' \/ SZ.fits (SZ.v cap' + SZ.v cap'))) +{ + let sz = V.size rv; + let dst_end = SZ.sub sz count; + // Phase 1: shift elements left — copy s[j+count] to s[j] for j in [i..dst_end) + let mut j = i; + let mut shift_cont = true; + while (!shift_cont) + invariant exists* jv sc s_cur cap_cur. + R.pts_to j jv ** R.pts_to shift_cont sc ** + V.is_vector rv s_cur cap_cur ** + pure (SZ.v jv >= SZ.v i /\ SZ.v jv <= SZ.v dst_end /\ + Seq.length s_cur == Seq.length s /\ + cap_cur == G.reveal cap /\ + // Prefix unchanged + (forall (k:nat). k < SZ.v i ==> Seq.index s_cur k == Seq.index (G.reveal s) k) /\ + // Shifted region + (forall (k:nat). k >= SZ.v i /\ k < SZ.v jv ==> + Seq.index s_cur k == Seq.index (G.reveal s) (k + SZ.v count)) /\ + // Suffix unchanged + (forall (k:nat). k >= SZ.v jv /\ k < Seq.length s_cur ==> + Seq.index s_cur k == Seq.index (G.reveal s) k) /\ + // Exit + (not sc ==> SZ.v jv == SZ.v dst_end)) + { + let jv = !j; + if (SZ.lt jv dst_end) { + let val_ = V.at rv (SZ.add jv count); + V.set rv jv val_; + j := SZ.add jv 1sz + } else { + shift_cont := false + } + }; + // After shift: first dst_end elements form seq_remove + with _jv1 _sc1 s_shifted _cap_shifted. _; + shift_to_seq_remove (G.reveal s) s_shifted (SZ.v i) (SZ.v count); + // Phase 2: pop count elements from the end + let target = G.hide (seq_remove (G.reveal s) (SZ.v i) (SZ.v count)); + let mut k = 0sz; + let mut pop_cont = true; + while (!pop_cont) + invariant exists* kv pc s_cur cap_cur. + R.pts_to k kv ** R.pts_to pop_cont pc ** + V.is_vector rv s_cur cap_cur ** + pure (SZ.v kv <= SZ.v count /\ + Seq.length s_cur + SZ.v kv == Seq.length (G.reveal s) /\ + // Content: first dst_end elements as a slice match seq_remove + Seq.slice s_cur 0 (SZ.v dst_end) == G.reveal target /\ + // Capacity: established after first pop (kv > 0) + (SZ.v kv > 0 ==> + (Seq.length s_cur < SZ.v cap_cur \/ SZ.fits (SZ.v cap_cur + SZ.v cap_cur))) /\ + (not pc ==> SZ.v kv >= SZ.v count)) + { + let kv = !k; + if (SZ.lt kv count) { + let _ = V.pop_back rv; + let new_k = SZ.add kv 1sz; + k := new_k; + if (SZ.eq new_k count) { + pop_cont := false + } + } else { + pop_cont := false + } + }; + // After pop: s_cur has dst_end elements, slice 0..dst_end == s_cur == target + with _kv1 _pc1 s_final _cap_final. _; + Seq.lemma_eq_intro s_final (G.reveal target) +} + +(*** Drain — remove/trim first entry up to new_bo ***) + +/// Bridge: seq_to_spec of tail +let seq_to_spec_tail (s: Seq.seq range) + : Lemma (requires Seq.length s > 0 /\ seq_all_valid s) + (ensures seq_all_valid (Seq.tail s) /\ + seq_to_spec (Seq.tail s) == Seq.tail (seq_to_spec s)) = + let tl = Seq.tail s in + seq_all_valid_slice s 1 (Seq.length s); + assert (Seq.slice s 1 (Seq.length s) `Seq.equal` tl); + let spec_tl = seq_to_spec tl in + let tail_spec = Seq.tail (seq_to_spec s) in + let aux (i: nat{i < Seq.length spec_tl}) + : Lemma (Seq.index spec_tl i == Seq.index tail_spec i) = + seq_to_spec_index tl i; + seq_all_valid_index s (i + 1); + seq_to_spec_index s (i + 1) + in + Classical.forall_intro aux; + Seq.lemma_eq_intro spec_tl tail_spec + +/// Bridge: updating first element maps to spec +let seq_to_spec_upd0 (s: Seq.seq range) (r: range) + : Lemma (requires Seq.length s > 0 /\ seq_all_valid s /\ range_valid r) + (ensures seq_all_valid (Seq.upd s 0 r) /\ + seq_to_spec (Seq.upd s 0 r) == Seq.upd (seq_to_spec s) 0 (range_to_interval r)) = + let s' = Seq.upd s 0 r in + let aux_valid (i: nat{i < Seq.length s'}) + : Lemma (range_valid (Seq.index s' i)) = + if i = 0 then () + else begin + Seq.lemma_index_upd2 s 0 r i; + seq_all_valid_index s i + end + in + Classical.forall_intro aux_valid; + let spec_s' = seq_to_spec s' in + let upd_spec = Seq.upd (seq_to_spec s) 0 (range_to_interval r) in + let aux (i: nat{i < Seq.length spec_s'}) + : Lemma (Seq.index spec_s' i == Seq.index upd_spec i) = + seq_to_spec_index s' i; + if i = 0 then () + else begin + seq_all_valid_index s i; + seq_to_spec_index s i + end + in + Classical.forall_intro aux; + Seq.lemma_eq_intro spec_s' upd_spec + +fn range_vec_drain (rv: range_vec_t) (new_bo: SZ.t) + (#repr: erased (Seq.seq Spec.interval)) + requires is_range_vec rv repr ** + pure (Seq.length repr > 0 /\ + (Seq.index repr 0).low <= SZ.v new_bo /\ + SZ.v new_bo <= Spec.high (Seq.index repr 0)) + ensures is_range_vec rv (Spec.drain_repr repr (SZ.v new_bo)) +{ + unfold is_range_vec; + with s cap. _; + let sz = V.size rv; + seq_to_spec_index s 0; + let first = V.at rv 0sz; + let r_high = SZ.add first.start first.len; + let first_spec = Spec.Mkinterval (SZ.v first.start) (SZ.v first.len); + assert (pure (Seq.index repr 0 == first_spec)); + if (SZ.lte r_high new_bo) { + // Remove first entry entirely + vec_remove_range rv 0sz 1sz; + with s' cap'. _; + // Bridge: seq_remove s 0 1 = tail s + Seq.lemma_eq_intro (seq_remove s 0 1) (Seq.tail s); + seq_to_spec_tail s; + Spec.drain_repr_wf repr (SZ.v new_bo); + Spec.drain_repr_length repr (SZ.v new_bo); + fold (is_range_vec rv (Spec.drain_repr repr (SZ.v new_bo))) + } else if (SZ.lt first.start new_bo) { + let new_len = SZ.sub r_high new_bo; + let new_range = { start = new_bo; len = new_len }; + V.set rv 0sz new_range; + seq_to_spec_upd0 s new_range; + // Connect upd to cons form and to drain_repr + Seq.lemma_eq_intro (Seq.upd repr 0 (range_to_interval new_range)) + (Seq.cons (range_to_interval new_range) (Seq.tail repr)); + Seq.lemma_eq_intro (Seq.upd repr 0 (range_to_interval new_range)) + (Spec.drain_repr repr (SZ.v new_bo)); + Spec.drain_repr_wf repr (SZ.v new_bo); + Spec.drain_repr_length repr (SZ.v new_bo); + fold (is_range_vec rv (Spec.drain_repr repr (SZ.v new_bo))) + } else { + // No-op: new_bo == first.start, drain_repr returns s unchanged + fold (is_range_vec rv (Spec.drain_repr repr (SZ.v new_bo))) + } +} + +#push-options "--z3rlimit 400 --fuel 2 --ifuel 1" + +fn range_vec_add (rv: range_vec_t) (offset: SZ.t) (len: SZ.t{SZ.v len > 0}) + (#repr: erased (Seq.seq Spec.interval)) + requires is_range_vec rv repr ** + pure (SZ.fits (SZ.v offset + SZ.v len) /\ + Seq.length repr < max_range_vec_entries) + ensures is_range_vec rv (Spec.add_range repr (SZ.v offset) (SZ.v len)) +{ + unfold is_range_vec; + with s cap. _; + let sz = V.size rv; + let off_plus_len = SZ.add offset len; + + // Phase 1: find insertion point (first i where high(rv[i]) >= offset) + seq_all_valid_forall s; + let mut idx = 0sz; + let mut search = true; + while (!search) + invariant exists* iv sv s_cur cap_cur. + R.pts_to idx iv ** R.pts_to search sv ** + V.is_vector rv s_cur cap_cur ** + pure (seq_all_valid s_cur /\ + s_cur == G.reveal s /\ cap_cur == G.reveal cap /\ + SZ.v iv <= Seq.length s_cur /\ + (forall (k:nat). k < Seq.length s_cur ==> range_valid (Seq.index s_cur k)) /\ + (forall (k:nat). k < SZ.v iv ==> + SZ.v (Seq.index s_cur k).start + SZ.v (Seq.index s_cur k).len < SZ.v offset) /\ + // Exit: when done, either iv==sz or high(s[iv]) >= offset + (not sv ==> (SZ.v iv == Seq.length s_cur \/ + SZ.v (Seq.index s_cur (SZ.v iv)).start + SZ.v (Seq.index s_cur (SZ.v iv)).len >= SZ.v offset))) + { + let iv = !idx; + if (SZ.lt iv sz) { + let r = V.at rv iv; + let high = SZ.add r.start r.len; + if (SZ.lt high offset) { + idx := SZ.add iv 1sz + } else { + search := false + } + } else { + search := false + } + }; + + let iv = !idx; + + if (SZ.eq sz 0sz || SZ.eq iv sz) { + // Append at end (empty vec or all ranges are before offset) + let r : range = { start = offset; len = len }; + // Prove spec facts while s is still live + forall_high_below_to_spec s (SZ.v iv) (SZ.v offset); + Spec.add_range_all_before repr (SZ.v offset) (SZ.v len); + seq_to_spec_snoc s r; + seq_all_valid_insert s (SZ.v iv) r; + Spec.add_range_wf repr (SZ.v offset) (SZ.v len); + vec_insert_at rv iv r; + with s' cap'. _; + snoc_is_seq_insert (G.reveal s) r (SZ.v iv); + fold (is_range_vec rv (Spec.add_range repr (SZ.v offset) (SZ.v len))) + } else { + let first_r = V.at rv iv; + if (SZ.lt off_plus_len first_r.start) { + // No overlap — insert before iv + let new_r : range = { start = offset; len = len }; + // Prove spec facts while s is still live + forall_high_below_to_spec s (SZ.v iv) (SZ.v offset); + seq_to_spec_index s (SZ.v iv); + seq_all_valid_index s (SZ.v iv); + Spec.add_range_insert_no_overlap repr (SZ.v offset) (SZ.v len) (SZ.v iv); + seq_to_spec_insert s (SZ.v iv) new_r; + seq_all_valid_insert s (SZ.v iv) new_r; + Spec.add_range_wf repr (SZ.v offset) (SZ.v len); + vec_insert_at rv iv new_r; + with s' cap'. _; + fold (is_range_vec rv (Spec.add_range repr (SZ.v offset) (SZ.v len))) + } else { + // Merge: compute merged bounds [merged_low, merged_high) + let merged_low = (if SZ.lt offset first_r.start then offset else first_r.start); + let first_high = SZ.add first_r.start first_r.len; + let mh0_val = (if SZ.gt off_plus_len first_high then off_plus_len else first_high); + let mut merged_high = mh0_val; + + // Ghost: capture initial mh0 and suffix_tail for merge_absorbed_high tracking + let mh0 = G.hide (SZ.v mh0_val); + let repr_snap = G.hide repr; + + // Extend merge rightward through overlapping/adjacent ranges + let mut j = SZ.add iv 1sz; + let mut merge_cont = true; + while (!merge_cont) + invariant exists* jv mh mc s_cur cap_cur. + R.pts_to j jv ** R.pts_to merged_high mh ** R.pts_to merge_cont mc ** + V.is_vector rv s_cur cap_cur ** + pure (seq_all_valid s_cur /\ + s_cur == G.reveal s /\ cap_cur == G.reveal cap /\ + SZ.v jv > SZ.v iv /\ SZ.v jv <= Seq.length s_cur /\ + SZ.v mh > SZ.v merged_low /\ + SZ.fits (SZ.v mh) /\ + (forall (k:nat). k < Seq.length s_cur ==> range_valid (Seq.index s_cur k)) /\ + // Overlap: mh covers all ranges in [iv..jv) + (forall (k:nat). k >= SZ.v iv /\ k < SZ.v jv ==> + SZ.v mh >= SZ.v (Seq.index s_cur k).start) /\ + // mh0 covers ranges in (iv..jv) — for merge_full precondition + (forall (k:nat). k > SZ.v iv /\ k < SZ.v jv ==> + G.reveal mh0 >= SZ.v (Seq.index s_cur k).start) /\ + // Exit: when loop done, either jv==sz or mh < s[jv].start + (not mc ==> (SZ.v jv == Seq.length s_cur \/ + SZ.v mh < SZ.v (Seq.index s_cur (SZ.v jv)).start)) /\ + // Track: mh == merge_absorbed_high(suffix_tail, mh0, jv-iv-1) + (let suffix_tail = Seq.slice (seq_to_spec s_cur) (SZ.v iv + 1) (Seq.length s_cur) in + SZ.v mh == Spec.merge_absorbed_high suffix_tail (G.reveal mh0) (SZ.v jv - SZ.v iv - 1))) + { + let jv = !j; + if (SZ.lt jv sz) { + let r = V.at rv jv; + let mh = !merged_high; + if (SZ.gte mh r.start) { + let r_high = SZ.add r.start r.len; + // Use bridge lemma for loop invariant step + Spec.range_map_wf_slice repr (SZ.v iv + 1); + merge_loop_body_step s (SZ.v iv) (SZ.v jv) (SZ.v mh) (G.reveal mh0); + if (SZ.gt r_high mh) { merged_high := r_high }; + j := SZ.add jv 1sz + } else { + merge_cont := false + } + } else { + merge_cont := false + } + }; + + // Write merged range at iv, remove subsumed ranges [iv+1..j) + let jv = !j; + let final_high = !merged_high; + // Bounds are valid: final_high > merged_low from loop invariant + let final_len = SZ.sub final_high merged_low; + + // Lift range-level facts to spec level BEFORE consuming s via V.set + forall_high_below_to_spec s (SZ.v iv) (SZ.v offset); + seq_to_spec_index s (SZ.v iv); + seq_all_valid_index s (SZ.v iv); + forall_overlap_to_spec s (SZ.v iv) (SZ.v jv) (SZ.v final_high); + // Lift mh0 forall to spec level + forall_overlap_to_spec s (SZ.v iv + 1) (SZ.v jv) (G.reveal mh0); + // Connect our ghost mh0 to spec's computed mh0 + assert (pure (Spec.high (Seq.index repr (SZ.v iv)) == SZ.v first_high)); + assert (pure (G.reveal mh0 == + (if SZ.v offset + SZ.v len > Spec.high (Seq.index repr (SZ.v iv)) + then SZ.v offset + SZ.v len + else Spec.high (Seq.index repr (SZ.v iv))))); + // Exit condition: mh0 < repr[j].low (from final_high >= mh0 and final_high < s[j].start) + Spec.merge_absorbed_high_mono + (Seq.slice repr (SZ.v iv + 1) (Seq.length repr)) + (G.reveal mh0) + (SZ.v jv - SZ.v iv - 1); + exit_condition_to_spec s repr (SZ.v jv) (G.reveal mh0) (SZ.v final_high); + + // Call merge_full with explicit mh0 parameter + Spec.add_range_merge_full_explicit repr (SZ.v offset) (SZ.v len) (SZ.v iv) (SZ.v jv) (G.reveal mh0); + + // Now do the imperative set + let merged_r : range = { start = merged_low; len = final_len }; + V.set rv iv merged_r; + + // Handle remove case + if (SZ.gt jv (SZ.add iv 1sz)) { + let remove_count = SZ.sub jv (SZ.add iv 1sz); + vec_remove_range rv (SZ.add iv 1sz) remove_count; + with s_final cap_final. _; + seq_upd_remove_spec s (SZ.v iv) (SZ.v jv) merged_r; + Spec.add_range_wf repr (SZ.v offset) (SZ.v len); + assert (pure (range_to_interval merged_r == Spec.({ low = SZ.v merged_low; count = SZ.v final_len }))); + fold (is_range_vec rv (Spec.add_range repr (SZ.v offset) (SZ.v len))) + } else { + // jv == iv + 1, no removal needed. V.set gives concrete V.is_vector rv (Seq.upd s iv merged_r) cap + seq_upd_remove_spec s (SZ.v iv) (SZ.v jv) merged_r; + Spec.add_range_wf repr (SZ.v offset) (SZ.v len); + // range_to_interval merged_r matches the spec's merged interval + assert (pure (range_to_interval merged_r == Spec.({ low = SZ.v merged_low; count = SZ.v final_len }))); + fold (is_range_vec rv (Spec.add_range repr (SZ.v offset) (SZ.v len))) + } + } + } +} + +#pop-options diff --git a/lib/pulse/lib/Pulse.Lib.Spec.AVLTree.fst b/lib/pulse/lib/Pulse.Lib.Spec.AVLTree.fst index fb4a06e99..7d087a2a3 100644 --- a/lib/pulse/lib/Pulse.Lib.Spec.AVLTree.fst +++ b/lib/pulse/lib/Pulse.Lib.Spec.AVLTree.fst @@ -10,336 +10,332 @@ module M = FStar.Math.Lib (**** The tree structure *) -type tree (a: Type) = - | Leaf : tree a - | Node: data: a -> left: tree a -> right: tree a -> tree a +type tree (k: Type) (v: Type) = + | Leaf : tree k v + | Node: key: k -> value: v -> left: tree k v -> right: tree k v -> tree k v (**** Binary search trees *) -type node_data (a b: Type) = { - key: a; - payload: b; -} - -let kv_tree (a: Type) (b: Type) = tree (node_data a b) -type cmp (a: Type) = compare: (a -> a -> int){ +type cmp (k: Type) = compare: (k -> k -> int){ squash (forall x. compare x x == 0) /\ squash (forall x y. compare x y > 0 <==> compare y x < 0) /\ squash (forall x y z. compare x y >= 0 /\ compare y z >= 0 ==> compare x z >= 0) } -let rec forall_keys (#a: Type) (t: tree a) (cond: a -> bool) : bool = +let rec forall_keys (#k: Type) (#v: Type) (t: tree k v) (cond: k -> bool) : bool = match t with | Leaf -> true - | Node data left right -> - cond data && forall_keys left cond && forall_keys right cond + | Node nd_key nd_val left right -> + cond nd_key && forall_keys left cond && forall_keys right cond -let key_left (#a: Type) (compare:cmp a) (root key: a) = +let key_left (#k: Type) (compare:cmp k) (root key: k) = compare root key >= 0 -let key_right (#a: Type) (compare : cmp a) (root key: a) = +let key_right (#k: Type) (compare : cmp k) (root key: k) = compare root key <= 0 -let rec is_bst (#a: Type) (compare : cmp a) (x: tree a) : bool = +let rec is_bst (#k: Type) (#v: Type) (compare : cmp k) (x: tree k v) : bool = match x with | Leaf -> true - | Node data left right -> + | Node nd_key nd_val left right -> is_bst compare left && is_bst compare right && - forall_keys left (key_left compare data) && - forall_keys right (key_right compare data) + forall_keys left (key_left compare nd_key) && + forall_keys right (key_right compare nd_key) -let bst (a: Type) (cmp:cmp a) = x:tree a {is_bst cmp x} +let bst (k: Type) (v: Type) (cmp:cmp k) = x:tree k v {is_bst cmp x} (*** Operations *) (**** empty *) -let is_empty (#a: Type) (r: tree a) : b:bool{b == true <==> r == Leaf} = +let is_empty (#k: Type) (#v: Type) (r: tree k v) : b:bool{b == true <==> r == Leaf} = match r with | Leaf -> true | _ -> false (**** Lookup *) -let rec mem (#a: Type) (r: tree a) (x: a) : prop = +let rec mem (#k: Type) (#v: Type) (r: tree k v) (x: k) : prop = match r with | Leaf -> False - | Node data left right -> - (data == x) \/ (mem right x) \/ mem left x + | Node nd_key nd_val left right -> + (nd_key == x) \/ (mem right x) \/ mem left x -let rec bst_search (#a: Type) (cmp:cmp a) (x: bst a cmp) (key: a) : option a = +let rec bst_search (#k: Type) (#v: Type) (cmp:cmp k) (x: bst k v cmp) (key: k) : option (k & v) = match x with | Leaf -> None - | Node data left right -> - let delta = cmp data key in + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta < 0 then bst_search cmp right key else if delta > 0 then bst_search cmp left key else - Some data + Some (nd_key, nd_val) (**** Height *) -let rec height (#a: Type) (x: tree a) : nat = +let rec height (#k: Type) (#v: Type) (x: tree k v) : nat = match x with | Leaf -> 0 - | Node data left right -> + | Node nd_key nd_val left right -> if height left > height right then (height left) + 1 else (height right) + 1 (**** Append *) -let rec append_left (#a: Type) (x: tree a) (v: a) : tree a = +let rec append_left (#k: Type) (#v: Type) (x: tree k v) (ak: k) (av: v) : tree k v = match x with - | Leaf -> Node v Leaf Leaf - | Node x left right -> Node x (append_left left v) right + | Leaf -> Node ak av Leaf Leaf + | Node xk xv left right -> Node xk xv (append_left left ak av) right -let rec append_right (#a: Type) (x: tree a) (v: a) : tree a = +let rec append_right (#k: Type) (#v: Type) (x: tree k v) (ak: k) (av: v) : tree k v = match x with - | Leaf -> Node v Leaf Leaf - | Node x left right -> Node x left (append_right right v) + | Leaf -> Node ak av Leaf Leaf + | Node xk xv left right -> Node xk xv left (append_right right ak av) (**** Insertion *) (**** BST insertion *) -let rec insert_bst (#a: Type) (cmp:cmp a) (x: bst a cmp) (key: a) : tree a = +let rec insert_bst (#k: Type) (#v: Type) (cmp:cmp k) (x: bst k v cmp) (key: k) (val_: v) : tree k v = match x with - | Leaf -> Node key Leaf Leaf - | Node data left right -> - let delta = cmp data key in + | Leaf -> Node key val_ Leaf Leaf + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta >= 0 then begin - let new_left = insert_bst cmp left key in - Node data new_left right + let new_left = insert_bst cmp left key val_ in + Node nd_key nd_val new_left right end else begin - let new_right = insert_bst cmp right key in - Node data left new_right + let new_right = insert_bst cmp right key val_ in + Node nd_key nd_val left new_right end let rec insert_bst_preserves_forall_keys - (#a: Type) - (cmp:cmp a) - (x: bst a cmp) - (key: a) - (cond: a -> bool) + (#k: Type) (#v: Type) + (cmp:cmp k) + (x: bst k v cmp) + (key: k) + (val_: v) + (cond: k -> bool) : Lemma (requires (forall_keys x cond /\ cond key)) - (ensures (forall_keys (insert_bst cmp x key) cond)) + (ensures (forall_keys (insert_bst cmp x key val_) cond)) = match x with | Leaf -> () - | Node data left right -> - let delta = cmp data key in + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta >= 0 then - insert_bst_preserves_forall_keys cmp left key cond + insert_bst_preserves_forall_keys cmp left key val_ cond else - insert_bst_preserves_forall_keys cmp right key cond + insert_bst_preserves_forall_keys cmp right key val_ cond let rec insert_bst_preserves_bst - (#a: Type) - (cmp:cmp a) - (x: bst a cmp) - (key: a) - : Lemma(is_bst cmp (insert_bst cmp x key)) + (#k: Type) (#v: Type) + (cmp:cmp k) + (x: bst k v cmp) + (key: k) + (val_: v) + : Lemma(is_bst cmp (insert_bst cmp x key val_)) = match x with | Leaf -> () - | Node data left right -> - let delta = cmp data key in + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta >= 0 then begin - insert_bst_preserves_forall_keys cmp left key (key_left cmp data); - insert_bst_preserves_bst cmp left key + insert_bst_preserves_forall_keys cmp left key val_ (key_left cmp nd_key); + insert_bst_preserves_bst cmp left key val_ end else begin - insert_bst_preserves_forall_keys cmp right key (key_right cmp data); - insert_bst_preserves_bst cmp right key + insert_bst_preserves_forall_keys cmp right key val_ (key_right cmp nd_key); + insert_bst_preserves_bst cmp right key val_ end (**** AVL insertion *) -let rec is_balanced (#a: Type) (x: tree a) : bool = +let rec is_balanced (#k: Type) (#v: Type) (x: tree k v) : bool = match x with | Leaf -> true - | Node data left right -> + | Node nd_key nd_val left right -> M.abs(height right - height left) <= 1 && is_balanced(right) && is_balanced(left) -let is_avl (#a: Type) (cmp:cmp a) (x: tree a) : prop = +let is_avl (#k: Type) (#v: Type) (cmp:cmp k) (x: tree k v) : prop = is_bst cmp x /\ is_balanced x -let avl (a: Type) (cmp:cmp a) = x: tree a {is_avl cmp x} +let avl (k: Type) (v: Type) (cmp:cmp k) = x: tree k v {is_avl cmp x} -let rotate_left (#a: Type) (r: tree a) : option (tree a) = +let rotate_left (#k: Type) (#v: Type) (r: tree k v) : option (tree k v) = match r with - | Node x t1 (Node z t2 t3) -> Some (Node z (Node x t1 t2) t3) + | Node xk xv t1 (Node zk zv t2 t3) -> Some (Node zk zv (Node xk xv t1 t2) t3) | _ -> None -let rotate_right (#a: Type) (r: tree a) : option (tree a) = +let rotate_right (#k: Type) (#v: Type) (r: tree k v) : option (tree k v) = match r with - | Node x (Node z t1 t2) t3 -> Some (Node z t1 (Node x t2 t3)) + | Node xk xv (Node zk zv t1 t2) t3 -> Some (Node zk zv t1 (Node xk xv t2 t3)) | _ -> None -let rotate_right_left (#a: Type) (r: tree a) : option (tree a) = +let rotate_right_left (#k: Type) (#v: Type) (r: tree k v) : option (tree k v) = match r with - | Node x t1 (Node z (Node y t2 t3) t4) -> Some (Node y (Node x t1 t2) (Node z t3 t4)) + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> Some (Node yk yv (Node xk xv t1 t2) (Node zk zv t3 t4)) | _ -> None -let rotate_left_right (#a: Type) (r: tree a) : option (tree a) = +let rotate_left_right (#k: Type) (#v: Type) (r: tree k v) : option (tree k v) = match r with - | Node x (Node z t1 (Node y t2 t3)) t4 -> Some (Node y (Node z t1 t2) (Node x t3 t4)) + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> Some (Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4)) | _ -> None (** rotate preserves bst *) -let rec forall_keys_trans (#a: Type) (t: tree a) (cond1 cond2: a -> bool) +let rec forall_keys_trans (#k: Type) (#v: Type) (t: tree k v) (cond1 cond2: k -> bool) : Lemma (requires (forall x. cond1 x ==> cond2 x) /\ forall_keys t cond1) (ensures forall_keys t cond2) = match t with | Leaf -> () - | Node data left right -> + | Node nd_key nd_val left right -> forall_keys_trans left cond1 cond2; forall_keys_trans right cond1 cond2 -let rotate_left_bst (#a:Type) (cmp:cmp a) (r:tree a) +let rotate_left_bst (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) : Lemma (requires is_bst cmp r /\ Some? (rotate_left r)) (ensures is_bst cmp (Some?.v (rotate_left r))) = match r with - | Node x t1 (Node z t2 t3) -> - assert (is_bst cmp (Node z t2 t3)); - assert (is_bst cmp (Node x t1 t2)); - forall_keys_trans t1 (key_left cmp x) (key_left cmp z) + | Node xk xv t1 (Node zk zv t2 t3) -> + assert (is_bst cmp (Node zk zv t2 t3)); + assert (is_bst cmp (Node xk xv t1 t2)); + forall_keys_trans t1 (key_left cmp xk) (key_left cmp zk) -let rotate_right_bst (#a:Type) (cmp:cmp a) (r:tree a) +let rotate_right_bst (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) : Lemma (requires is_bst cmp r /\ Some? (rotate_right r)) (ensures is_bst cmp (Some?.v (rotate_right r))) = match r with - | Node x (Node z t1 t2) t3 -> - assert (is_bst cmp (Node z t1 t2)); - assert (is_bst cmp (Node x t2 t3)); - forall_keys_trans t3 (key_right cmp x) (key_right cmp z) + | Node xk xv (Node zk zv t1 t2) t3 -> + assert (is_bst cmp (Node zk zv t1 t2)); + assert (is_bst cmp (Node xk xv t2 t3)); + forall_keys_trans t3 (key_right cmp xk) (key_right cmp zk) -let rotate_right_left_bst (#a:Type) (cmp:cmp a) (r:tree a) +let rotate_right_left_bst (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) : Lemma (requires is_bst cmp r /\ Some? (rotate_right_left r)) (ensures is_bst cmp (Some?.v (rotate_right_left r))) = match r with - | Node x t1 (Node z (Node y t2 t3) t4) -> - assert (is_bst cmp (Node z (Node y t2 t3) t4)); - assert (is_bst cmp (Node y t2 t3)); - let left = Node x t1 t2 in - let right = Node z t3 t4 in - - assert (forall_keys (Node y t2 t3) (key_right cmp x)); - assert (forall_keys t2 (key_right cmp x)); + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> + assert (is_bst cmp (Node zk zv (Node yk yv t2 t3) t4)); + assert (is_bst cmp (Node yk yv t2 t3)); + let left = Node xk xv t1 t2 in + let right = Node zk zv t3 t4 in + + assert (forall_keys (Node yk yv t2 t3) (key_right cmp xk)); + assert (forall_keys t2 (key_right cmp xk)); assert (is_bst cmp left); assert (is_bst cmp right); - forall_keys_trans t1 (key_left cmp x) (key_left cmp y); - assert (forall_keys left (key_left cmp y)); + forall_keys_trans t1 (key_left cmp xk) (key_left cmp yk); + assert (forall_keys left (key_left cmp yk)); - forall_keys_trans t4 (key_right cmp z) (key_right cmp y); - assert (forall_keys right (key_right cmp y)) + forall_keys_trans t4 (key_right cmp zk) (key_right cmp yk); + assert (forall_keys right (key_right cmp yk)) -let rotate_left_right_bst (#a:Type) (cmp:cmp a) (r:tree a) +let rotate_left_right_bst (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) : Lemma (requires is_bst cmp r /\ Some? (rotate_left_right r)) (ensures is_bst cmp (Some?.v (rotate_left_right r))) = match r with - | Node x (Node z t1 (Node y t2 t3)) t4 -> - // Node y (Node z t1 t2) (Node x t3 t4) - assert (is_bst cmp (Node z t1 (Node y t2 t3))); - assert (is_bst cmp (Node y t2 t3)); - let left = Node z t1 t2 in - let right = Node x t3 t4 in + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> + // Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4) + assert (is_bst cmp (Node zk zv t1 (Node yk yv t2 t3))); + assert (is_bst cmp (Node yk yv t2 t3)); + let left = Node zk zv t1 t2 in + let right = Node xk xv t3 t4 in assert (is_bst cmp left); - assert (forall_keys (Node y t2 t3) (key_left cmp x)); - assert (forall_keys t2 (key_left cmp x)); + assert (forall_keys (Node yk yv t2 t3) (key_left cmp xk)); + assert (forall_keys t2 (key_left cmp xk)); assert (is_bst cmp right); - forall_keys_trans t1 (key_left cmp z) (key_left cmp y); - assert (forall_keys left (key_left cmp y)); + forall_keys_trans t1 (key_left cmp zk) (key_left cmp yk); + assert (forall_keys left (key_left cmp yk)); - forall_keys_trans t4 (key_right cmp x) (key_right cmp y); - assert (forall_keys right (key_right cmp y)) + forall_keys_trans t4 (key_right cmp xk) (key_right cmp yk); + assert (forall_keys right (key_right cmp yk)) (** Same elements before and after rotate **) -let rotate_left_key_left (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_left_key_left (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_left cmp root) /\ Some? (rotate_left r)) (ensures forall_keys (Some?.v (rotate_left r)) (key_left cmp root)) = match r with - | Node x t1 (Node z t2 t3) -> - assert (forall_keys (Node z t2 t3) (key_left cmp root)); - assert (forall_keys (Node x t1 t2) (key_left cmp root)) + | Node xk xv t1 (Node zk zv t2 t3) -> + assert (forall_keys (Node zk zv t2 t3) (key_left cmp root)); + assert (forall_keys (Node xk xv t1 t2) (key_left cmp root)) -let rotate_left_key_right (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_left_key_right (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_right cmp root) /\ Some? (rotate_left r)) (ensures forall_keys (Some?.v (rotate_left r)) (key_right cmp root)) = match r with - | Node x t1 (Node z t2 t3) -> - assert (forall_keys (Node z t2 t3) (key_right cmp root)); - assert (forall_keys (Node x t1 t2) (key_right cmp root)) + | Node xk xv t1 (Node zk zv t2 t3) -> + assert (forall_keys (Node zk zv t2 t3) (key_right cmp root)); + assert (forall_keys (Node xk xv t1 t2) (key_right cmp root)) -let rotate_right_key_left (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_right_key_left (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_left cmp root) /\ Some? (rotate_right r)) (ensures forall_keys (Some?.v (rotate_right r)) (key_left cmp root)) = match r with - | Node x (Node z t1 t2) t3 -> - assert (forall_keys (Node z t1 t2) (key_left cmp root)); - assert (forall_keys (Node x t2 t3) (key_left cmp root)) + | Node xk xv (Node zk zv t1 t2) t3 -> + assert (forall_keys (Node zk zv t1 t2) (key_left cmp root)); + assert (forall_keys (Node xk xv t2 t3) (key_left cmp root)) -let rotate_right_key_right (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_right_key_right (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_right cmp root) /\ Some? (rotate_right r)) (ensures forall_keys (Some?.v (rotate_right r)) (key_right cmp root)) = match r with - | Node x (Node z t1 t2) t3 -> - assert (forall_keys (Node z t1 t2) (key_right cmp root)); - assert (forall_keys (Node x t2 t3) (key_right cmp root)) + | Node xk xv (Node zk zv t1 t2) t3 -> + assert (forall_keys (Node zk zv t1 t2) (key_right cmp root)); + assert (forall_keys (Node xk xv t2 t3) (key_right cmp root)) -let rotate_right_left_key_left (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_right_left_key_left (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_left cmp root) /\ Some? (rotate_right_left r)) (ensures forall_keys (Some?.v (rotate_right_left r)) (key_left cmp root)) = match r with - | Node x t1 (Node z (Node y t2 t3) t4) -> - assert (forall_keys (Node z (Node y t2 t3) t4) (key_left cmp root)); - assert (forall_keys (Node y t2 t3) (key_left cmp root)); - let left = Node x t1 t2 in - let right = Node z t3 t4 in + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> + assert (forall_keys (Node zk zv (Node yk yv t2 t3) t4) (key_left cmp root)); + assert (forall_keys (Node yk yv t2 t3) (key_left cmp root)); + let left = Node xk xv t1 t2 in + let right = Node zk zv t3 t4 in assert (forall_keys left (key_left cmp root)); assert (forall_keys right (key_left cmp root)) -let rotate_right_left_key_right (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_right_left_key_right (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_right cmp root) /\ Some? (rotate_right_left r)) (ensures forall_keys (Some?.v (rotate_right_left r)) (key_right cmp root)) = match r with - | Node x t1 (Node z (Node y t2 t3) t4) -> - assert (forall_keys (Node z (Node y t2 t3) t4) (key_right cmp root)); - assert (forall_keys (Node y t2 t3) (key_right cmp root)); - let left = Node x t1 t2 in - let right = Node z t3 t4 in + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> + assert (forall_keys (Node zk zv (Node yk yv t2 t3) t4) (key_right cmp root)); + assert (forall_keys (Node yk yv t2 t3) (key_right cmp root)); + let left = Node xk xv t1 t2 in + let right = Node zk zv t3 t4 in assert (forall_keys left (key_right cmp root)); assert (forall_keys right (key_right cmp root)) -let rotate_left_right_key_left (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_left_right_key_left (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_left cmp root) /\ Some? (rotate_left_right r)) (ensures forall_keys (Some?.v (rotate_left_right r)) (key_left cmp root)) = match r with - | Node x (Node z t1 (Node y t2 t3)) t4 -> - // Node y (Node z t1 t2) (Node x t3 t4) - assert (forall_keys (Node z t1 (Node y t2 t3)) (key_left cmp root)); - assert (forall_keys (Node y t2 t3) (key_left cmp root)); - let left = Node z t1 t2 in - let right = Node x t3 t4 in + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> + // Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4) + assert (forall_keys (Node zk zv t1 (Node yk yv t2 t3)) (key_left cmp root)); + assert (forall_keys (Node yk yv t2 t3) (key_left cmp root)); + let left = Node zk zv t1 t2 in + let right = Node xk xv t3 t4 in assert (forall_keys left (key_left cmp root)); assert (forall_keys right (key_left cmp root)) -let rotate_left_right_key_right (#a:Type) (cmp:cmp a) (r:tree a) (root:a) +let rotate_left_right_key_right (#k: Type) (#v: Type) (cmp:cmp k) (r:tree k v) (root:k) : Lemma (requires forall_keys r (key_right cmp root) /\ Some? (rotate_left_right r)) (ensures forall_keys (Some?.v (rotate_left_right r)) (key_right cmp root)) = match r with - | Node x (Node z t1 (Node y t2 t3)) t4 -> - // Node y (Node z t1 t2) (Node x t3 t4) - assert (forall_keys (Node z t1 (Node y t2 t3)) (key_right cmp root)); - assert (forall_keys (Node y t2 t3) (key_right cmp root)); - let left = Node z t1 t2 in - let right = Node x t3 t4 in + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> + // Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4) + assert (forall_keys (Node zk zv t1 (Node yk yv t2 t3)) (key_right cmp root)); + assert (forall_keys (Node yk yv t2 t3) (key_right cmp root)); + let left = Node zk zv t1 t2 in + let right = Node xk xv t3 t4 in assert (forall_keys left (key_right cmp root)); assert (forall_keys right (key_right cmp root)) @@ -347,16 +343,16 @@ let rotate_left_right_key_right (#a:Type) (cmp:cmp a) (r:tree a) (root:a) (** Balancing operation for AVLs *) -let rebalance_avl (#a: Type) (x: tree a) : tree a = +let rebalance_avl (#k: Type) (#v: Type) (x: tree k v) : tree k v = match x with | Leaf -> x - | Node data left right -> + | Node nd_key nd_val left right -> if is_balanced x then x else ( if height left - height right > 1 then ( - let Node ldata lleft lright = left in + let Node lk lv lleft lright = left in if height lright > height lleft then ( match rotate_left_right x with | Some y -> y @@ -368,7 +364,7 @@ let rebalance_avl (#a: Type) (x: tree a) : tree a = ) ) else if height left - height right < -1 then ( - let Node rdata rleft rright = right in + let Node rk rv rleft rright = right in if height rleft > height rright then ( match rotate_right_left x with | Some y -> y @@ -384,13 +380,13 @@ let rebalance_avl (#a: Type) (x: tree a) : tree a = ) -let rebalance_avl_proof (#a: Type) (cmp:cmp a) (x: tree a) - (root:a) +let rebalance_avl_proof (#k: Type) (#v: Type) (cmp:cmp k) (x: tree k v) + (root:k) : Lemma (requires is_bst cmp x /\ ( match x with | Leaf -> True - | Node data left right -> + | Node nd_key nd_val left right -> is_balanced left /\ is_balanced right /\ height left - height right <= 2 /\ height right - height left <= 2 @@ -403,26 +399,26 @@ let rebalance_avl_proof (#a: Type) (cmp:cmp a) (x: tree a) = match x with | Leaf -> () - | Node data left right -> + | Node nd_key nd_val left right -> let x_f = rebalance_avl x in - let Node f_data f_left f_right = x_f in + let Node f_key f_val f_left f_right = x_f in if is_balanced x then () else ( if height left - height right > 1 then ( assert (height left = height right + 2); - let Node ldata lleft lright = left in + let Node lk lv lleft lright = left in if height lright > height lleft then ( assert (height left = height lright + 1); rotate_left_right_bst cmp x; Classical.move_requires (rotate_left_right_key_left cmp x) root; Classical.move_requires (rotate_left_right_key_right cmp x) root; - let Node y t2 t3 = lright in - let Node x (Node z t1 (Node y t2 t3)) t4 = x in - assert (f_data == y); - assert (f_left == Node z t1 t2); - assert (f_right == Node x t3 t4); - assert (lright == Node y t2 t3); + let Node yk yv t2 t3 = lright in + let Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 = x in + assert (f_key == yk); + assert (f_left == Node zk zv t1 t2); + assert (f_right == Node xk xv t3 t4); + assert (lright == Node yk yv t2 t3); // Left part @@ -433,7 +429,7 @@ let rebalance_avl_proof (#a: Type) (cmp:cmp a) (x: tree a) assert (is_balanced t1); - assert (is_balanced (Node y t2 t3)); + assert (is_balanced (Node yk yv t2 t3)); assert (is_balanced t2); assert (is_balanced f_left); @@ -458,23 +454,23 @@ let rebalance_avl_proof (#a: Type) (cmp:cmp a) (x: tree a) ) ) else if height left - height right < -1 then ( - let Node rdata rleft rright = right in + let Node rk rv rleft rright = right in if height rleft > height rright then ( rotate_right_left_bst cmp x; Classical.move_requires (rotate_right_left_key_left cmp x) root; Classical.move_requires (rotate_right_left_key_right cmp x) root; - let Node x t1 (Node z (Node y t2 t3) t4) = x in - assert (f_data == y); - assert (f_left == Node x t1 t2); - assert (f_right == Node z t3 t4); + let Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) = x in + assert (f_key == yk); + assert (f_left == Node xk xv t1 t2); + assert (f_right == Node zk zv t3 t4); // Right part assert (is_balanced rleft); assert (height t3 - height t4 <= 1); assert (height t4 - height t4 <= 1); - assert (is_balanced (Node y t2 t3)); + assert (is_balanced (Node yk yv t2 t3)); assert (is_balanced f_right); // Left part @@ -496,55 +492,55 @@ let rebalance_avl_proof (#a: Type) (cmp:cmp a) (x: tree a) (** Insertion **) -let rec insert_avl (#a: Type) (cmp:cmp a) (x: tree a) (key: a) : tree a = +let rec insert_avl (#k: Type) (#v: Type) (cmp:cmp k) (x: tree k v) (key: k) (val_: v) : tree k v = match x with - | Leaf -> Node key Leaf Leaf - | Node data left right -> - let delta = cmp data key in + | Leaf -> Node key val_ Leaf Leaf + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta >= 0 then ( - let new_left = insert_avl cmp left key in - let tmp = Node data new_left right in + let new_left = insert_avl cmp left key val_ in + let tmp = Node nd_key nd_val new_left right in rebalance_avl tmp ) else ( - let new_right = insert_avl cmp right key in - let tmp = Node data left new_right in + let new_right = insert_avl cmp right key val_ in + let tmp = Node nd_key nd_val left new_right in rebalance_avl tmp ) -let rec tree_max (#a: Type) (x: tree a {Node? x}) : a = +let rec tree_max (#k: Type) (#v: Type) (x: tree k v {Node? x}) : (k & v) = match x with - | Node d _ Leaf -> d - | Node _ _ r -> tree_max r + | Node dk dv _ Leaf -> (dk, dv) + | Node _ _ _ r -> tree_max r -(** Deletion **) -let rec delete_avl (#a: Type) (cmp:cmp a) (x: tree a) (key: a) : tree a = +(** Deletion *) +let rec delete_avl (#k: Type) (#v: Type) (cmp:cmp k) (x: tree k v) (key: k) : tree k v = match x with | Leaf -> Leaf - | Node data left right -> - let delta = cmp data key in + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta = 0 then ( match left, right with | Leaf , Leaf -> Leaf - | Node _ _ _ , Leaf -> left - | Leaf , Node _ _ _ -> right + | Node _ _ _ _ , Leaf -> left + | Leaf , Node _ _ _ _ -> right | _ -> let m = tree_max left in - let new_left = delete_avl cmp left m in - let tmp = Node m new_left right in + let new_left = delete_avl cmp left (fst m) in + let tmp = Node (fst m) (snd m) new_left right in rebalance_avl tmp ) else ( if delta < 0 then ( assert (delta < 0); - let new_left = delete_avl cmp left key in - let tmp = Node data new_left right in + let new_right = delete_avl cmp right key in + let tmp = Node nd_key nd_val left new_right in rebalance_avl tmp ) else ( assert (delta > 0); - let new_right = delete_avl cmp right key in - let tmp = Node data left new_right in + let new_left = delete_avl cmp left key in + let tmp = Node nd_key nd_val new_left right in rebalance_avl tmp ) ) @@ -553,12 +549,12 @@ let rec delete_avl (#a: Type) (cmp:cmp a) (x: tree a) (key: a) : tree a = #push-options "--z3rlimit 50" -let rec insert_avl_proof_aux (#a: Type) (cmp:cmp a) (x: avl a cmp) (key: a) - (root:a) +let rec insert_avl_proof_aux (#k: Type) (#v: Type) (cmp:cmp k) (x: avl k v cmp) (key: k) (val_: v) + (root:k) : Lemma (requires is_avl cmp x) (ensures ( - let res = insert_avl cmp x key in + let res = insert_avl cmp x key val_ in is_avl cmp res /\ height x <= height res /\ height res <= height x + 1 /\ @@ -568,24 +564,24 @@ let rec insert_avl_proof_aux (#a: Type) (cmp:cmp a) (x: avl a cmp) (key: a) ) = match x with | Leaf -> () - | Node data left right -> - let delta = cmp data key in + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in if delta >= 0 then ( - let new_left = insert_avl cmp left key in - let tmp = Node data new_left right in + let new_left = insert_avl cmp left key val_ in + let tmp = Node nd_key nd_val new_left right in - insert_avl_proof_aux cmp left key data; + insert_avl_proof_aux cmp left key val_ nd_key; // Need this one for propagating that all elements are smaller than root - insert_avl_proof_aux cmp left key root; + insert_avl_proof_aux cmp left key val_ root; rebalance_avl_proof cmp tmp root ) else ( - let new_right = insert_avl cmp right key in - let tmp = Node data left new_right in + let new_right = insert_avl cmp right key val_ in + let tmp = Node nd_key nd_val left new_right in - insert_avl_proof_aux cmp right key data; - insert_avl_proof_aux cmp right key root; + insert_avl_proof_aux cmp right key val_ nd_key; + insert_avl_proof_aux cmp right key val_ root; rebalance_avl_proof cmp tmp root ) @@ -594,7 +590,1553 @@ let rec insert_avl_proof_aux (#a: Type) (cmp:cmp a) (x: avl a cmp) (key: a) #pop-options -let insert_avl_proof (#a: Type) (cmp:cmp a) (x: avl a cmp) (key: a) - : Lemma (requires is_avl cmp x) (ensures is_avl cmp (insert_avl cmp x key)) - = Classical.forall_intro (Classical.move_requires (insert_avl_proof_aux cmp x key)) +let insert_avl_proof (#k: Type) (#v: Type) (cmp:cmp k) (x: avl k v cmp) (key: k) (val_: v) + : Lemma (requires is_avl cmp x) (ensures is_avl cmp (insert_avl cmp x key val_)) + = Classical.forall_intro (Classical.move_requires (insert_avl_proof_aux cmp x key val_)) + +/// Minimum element of a non-empty tree (leftmost) +let rec tree_min (#k: Type) (#v: Type) (x: tree k v{Node? x}) : (k & v) = + match x with + | Node dk dv Leaf _ -> (dk, dv) + | Node _ _ l _ -> tree_min l + +/// Find largest element <= key (predecessor query). Returns None if no such element. +let rec find_le (#k: Type) (#v: Type) (cmp: cmp k) (x: tree k v) (key: k) : option (k & v) = + match x with + | Leaf -> None + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + if delta > 0 then + // nd_key > key, go left + find_le cmp left key + else if delta = 0 then + // nd_key = key, exact match + Some (nd_key, nd_val) + else + // nd_key < key, nd_key is a candidate; check if right subtree has a better one + match find_le cmp right key with + | Some r -> Some r + | None -> Some (nd_key, nd_val) + +/// Find smallest element >= key (successor query). Returns None if no such element. +let rec find_ge (#k: Type) (#v: Type) (cmp: cmp k) (x: tree k v) (key: k) : option (k & v) = + match x with + | Leaf -> None + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + if delta < 0 then + // nd_key < key, go right + find_ge cmp right key + else if delta = 0 then + // nd_key = key, exact match + Some (nd_key, nd_val) + else + // nd_key > key, nd_key is a candidate; check if left subtree has a better one + match find_ge cmp left key with + | Some r -> Some r + | None -> Some (nd_key, nd_val) + + +(*** Inorder traversal and sorted sequence correspondence *) + +/// In-order traversal of a tree, producing a sequence +let rec inorder (#k: Type) (#v: Type) (t: tree k v) : Tot (Seq.seq (k & v)) (decreases t) = + match t with + | Leaf -> Seq.empty + | Node dk dv l r -> Seq.append (inorder l) (Seq.cons (dk, dv) (inorder r)) + +/// All elements of a sequence satisfy a boolean predicate on keys +let rec seq_forall (#k: Type) (#v: Type) (f: k -> bool) (s: Seq.seq (k & v)) + : Tot bool (decreases Seq.length s) = + if Seq.length s = 0 then true + else f (fst (Seq.head s)) && seq_forall f (Seq.tail s) + +/// A sequence is sorted w.r.t. a comparison function +let rec sorted (#k: Type) (#v: Type) (compare: cmp k) (s: Seq.seq (k & v)) + : Tot bool (decreases Seq.length s) = + if Seq.length s <= 1 then true + else compare (fst (Seq.head s)) (fst (Seq.index s 1)) <= 0 && sorted compare (Seq.tail s) + +/// Insert an element at the correct position in a sorted sequence +let rec sorted_insert (#k: Type) (#v: Type) (compare: cmp k) (kv: (k & v)) (s: Seq.seq (k & v)) + : Tot (Seq.seq (k & v)) (decreases Seq.length s) = + if Seq.length s = 0 then Seq.create 1 kv + else + let hd = Seq.head s in + if compare (fst hd) (fst kv) >= 0 then + Seq.cons kv s + else + Seq.cons hd (sorted_insert compare kv (Seq.tail s)) + +/// Remove the first occurrence of an element equal to key from a sorted sequence +let rec sorted_remove (#k: Type) (#v: Type) (compare: cmp k) (key: k) (s: Seq.seq (k & v)) + : Tot (Seq.seq (k & v)) (decreases Seq.length s) = + if Seq.length s = 0 then Seq.empty + else + let hd = Seq.head s in + if compare (fst hd) key = 0 then Seq.tail s + else Seq.cons hd (sorted_remove compare key (Seq.tail s)) + +(** A2: Rotation lemmas — rotations preserve inorder traversal *) + +#push-options "--fuel 3 --z3rlimit 40" + +let rotate_left_inorder (#k: Type) (#v: Type) (r: tree k v) + : Lemma (requires Some? (rotate_left r)) + (ensures Seq.equal (inorder (Some?.v (rotate_left r))) (inorder r)) + = match r with + | Node xk xv t1 (Node zk zv t2 t3) -> + Seq.append_assoc (inorder t1) (Seq.cons (xk, xv) (inorder t2)) (Seq.cons (zk, zv) (inorder t3)) + +let rotate_right_inorder (#k: Type) (#v: Type) (r: tree k v) + : Lemma (requires Some? (rotate_right r)) + (ensures Seq.equal (inorder (Some?.v (rotate_right r))) (inorder r)) + = match r with + | Node xk xv (Node zk zv t1 t2) t3 -> + Seq.append_assoc (inorder t1) (Seq.cons (zk, zv) (inorder t2)) (Seq.cons (xk, xv) (inorder t3)) + +let rotate_right_left_inorder (#k: Type) (#v: Type) (r: tree k v) + : Lemma (requires Some? (rotate_right_left r)) + (ensures Seq.equal (inorder (Some?.v (rotate_right_left r))) (inorder r)) + = match r with + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> + let l1 = inorder t1 in let l2 = inorder t2 in + let l3 = inorder t3 in let l4 = inorder t4 in + // Original: l1 ++ cons x ((l2 ++ cons y l3) ++ cons z l4) + // Target: (l1 ++ cons x l2) ++ cons y (l3 ++ cons z l4) + Seq.append_assoc l2 (Seq.cons (yk, yv) l3) (Seq.cons (zk, zv) l4); + Seq.append_assoc (Seq.create 1 (yk, yv)) l3 (Seq.cons (zk, zv) l4); + Seq.append_assoc (Seq.create 1 (xk, xv)) l2 (Seq.cons (yk, yv) (Seq.append l3 (Seq.cons (zk, zv) l4))); + Seq.append_assoc l1 (Seq.cons (xk, xv) l2) (Seq.cons (yk, yv) (Seq.append l3 (Seq.cons (zk, zv) l4))) + +let rotate_left_right_inorder (#k: Type) (#v: Type) (r: tree k v) + : Lemma (requires Some? (rotate_left_right r)) + (ensures Seq.equal (inorder (Some?.v (rotate_left_right r))) (inorder r)) + = match r with + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> + let l1 = inorder t1 in let l2 = inorder t2 in + let l3 = inorder t3 in let l4 = inorder t4 in + // Original: (l1 ++ cons z (l2 ++ cons y l3)) ++ cons x l4 + // Target: (l1 ++ cons z l2) ++ cons y (l3 ++ cons x l4) + Seq.append_assoc (Seq.create 1 (zk, zv)) l2 (Seq.cons (yk, yv) l3); + Seq.append_assoc l1 (Seq.cons (zk, zv) l2) (Seq.cons (yk, yv) l3); + Seq.append_assoc (Seq.append l1 (Seq.cons (zk, zv) l2)) (Seq.cons (yk, yv) l3) (Seq.cons (xk, xv) l4); + Seq.append_assoc (Seq.create 1 (yk, yv)) l3 (Seq.cons (xk, xv) l4) + +let rebalance_inorder (#k: Type) (#v: Type) (t: tree k v) + : Lemma (ensures Seq.equal (inorder (rebalance_avl t)) (inorder t)) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + if is_balanced t then () + else if height left - height right > 1 then ( + let Node lk lv lleft lright = left in + if height lright > height lleft then + rotate_left_right_inorder t + else + rotate_right_inorder t + ) else if height left - height right < -1 then ( + let Node rk rv rleft rright = right in + if height rleft > height rright then + rotate_right_left_inorder t + else + rotate_left_inorder t + ) else () + +#pop-options + +(** A3: BST implies sorted inorder *) + +#push-options "--fuel 2 --z3rlimit 40" + +/// Helper: seq_forall distributes over cons +let seq_forall_cons (#k: Type) (#v: Type) (f: k -> bool) (kv: (k & v)) (s: Seq.seq (k & v)) + : Lemma (requires f (fst kv) /\ seq_forall f s) + (ensures seq_forall f (Seq.cons kv s)) + = let s' = Seq.cons kv s in + Seq.lemma_head_append (Seq.create 1 kv) s; + Seq.lemma_tail_append (Seq.create 1 kv) s; + assert (Seq.head s' == kv); + assert (Seq.equal (Seq.tail s') s) + +/// Helper: seq_forall distributes over append +let rec seq_forall_append (#k: Type) (#v: Type) (f: k -> bool) (s1 s2: Seq.seq (k & v)) + : Lemma (requires seq_forall f s1 /\ seq_forall f s2) + (ensures seq_forall f (Seq.append s1 s2)) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 s2) s2) + ) else ( + seq_forall_append f (Seq.tail s1) s2; + Seq.lemma_head_append s1 s2; + Seq.lemma_tail_append s1 s2 + ) + +/// Bridge: forall_keys on tree implies seq_forall on inorder +let rec forall_keys_inorder (#k: Type) (#v: Type) (t: tree k v) (cond: k -> bool) + : Lemma (requires forall_keys t cond) + (ensures seq_forall cond (inorder t)) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> + forall_keys_inorder l cond; + forall_keys_inorder r cond; + seq_forall_cons cond (dk, dv) (inorder r); + seq_forall_append cond (inorder l) (Seq.cons (dk, dv) (inorder r)) + +/// Helper: sorted left + all left ≤ d + d ≤ all right + sorted right → sorted (left ++ [d] ++ right) + +let sorted_cons_right (#k: Type) (#v: Type) (compare: cmp k) (d: (k & v)) (s: Seq.seq (k & v)) + : Lemma (requires sorted compare s /\ (Seq.length s > 0 ==> compare (fst d) (fst (Seq.head s)) <= 0)) + (ensures sorted compare (Seq.cons d s)) + = let cs = Seq.cons d s in + Seq.lemma_head_append (Seq.create 1 d) s; + Seq.lemma_tail_append (Seq.create 1 d) s; + assert (Seq.head cs == d); + assert (Seq.equal (Seq.tail cs) s); + if Seq.length s = 0 then () + else ( + assert (Seq.index cs 1 == Seq.head s) + ) + +let rec sorted_append (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) + : Lemma (requires + sorted compare s1 /\ + sorted compare s2 /\ + seq_forall (key_left compare (fst d)) s1 /\ + seq_forall (key_right compare (fst d)) s2) + (ensures sorted compare (Seq.append s1 (Seq.cons d s2))) + (decreases Seq.length s1) + = let ds2 = Seq.cons d s2 in + if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 ds2) ds2); + sorted_cons_right compare d s2 + ) else ( + let hd = Seq.head s1 in + let tl = Seq.tail s1 in + sorted_append compare tl d s2; + Seq.lemma_head_append s1 ds2; + Seq.lemma_tail_append s1 ds2; + if Seq.length tl = 0 then ( + assert (Seq.equal tl (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append tl ds2) ds2); + Seq.lemma_head_append (Seq.create 1 d) s2; + assert (key_left compare (fst d) (fst hd)) + ) else ( + Seq.lemma_head_append tl ds2 + ) + ) + +/// The main theorem: BST inorder is sorted +let rec is_bst_sorted_inorder (#k: Type) (#v: Type) (compare: cmp k) (t: tree k v) + : Lemma (requires is_bst compare t) + (ensures sorted compare (inorder t)) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> + is_bst_sorted_inorder compare l; + is_bst_sorted_inorder compare r; + forall_keys_inorder l (key_left compare dk); + forall_keys_inorder r (key_right compare dk); + sorted_append compare (inorder l) (dk, dv) (inorder r) + +#pop-options + +(** A4: Insert correspondence — inorder(insert_bst t k) == sorted_insert k (inorder t) *) + +#push-options "--fuel 3 --z3rlimit 60" + +/// Helper: sorted_insert into append — kv goes left when d >= kv +let rec sorted_insert_append_left (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) (kv: (k & v)) + : Lemma + (requires sorted compare (Seq.append s1 (Seq.cons d s2)) /\ + seq_forall (key_left compare (fst d)) s1 /\ + seq_forall (key_right compare (fst d)) s2 /\ + compare (fst d) (fst kv) >= 0) + (ensures Seq.equal + (sorted_insert compare kv (Seq.append s1 (Seq.cons d s2))) + (Seq.append (sorted_insert compare kv s1) (Seq.cons d s2))) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 (Seq.cons d s2)) (Seq.cons d s2)); + Seq.lemma_head_append (Seq.create 1 d) s2; + assert (Seq.head (Seq.cons d s2) == d) + ) else ( + let hd = Seq.head s1 in + let tl = Seq.tail s1 in + Seq.lemma_head_append s1 (Seq.cons d s2); + Seq.lemma_tail_append s1 (Seq.cons d s2); + if compare (fst hd) (fst kv) >= 0 then () + else + sorted_insert_append_left compare tl d s2 kv + ) + +/// Helper: sorted_insert into append — kv goes right when d < kv +let rec sorted_insert_append_right (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) (kv: (k & v)) + : Lemma + (requires sorted compare (Seq.append s1 (Seq.cons d s2)) /\ + seq_forall (key_left compare (fst d)) s1 /\ + seq_forall (key_right compare (fst d)) s2 /\ + compare (fst d) (fst kv) < 0) + (ensures Seq.equal + (sorted_insert compare kv (Seq.append s1 (Seq.cons d s2))) + (Seq.append s1 (Seq.cons d (sorted_insert compare kv s2)))) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 (Seq.cons d s2)) (Seq.cons d s2)); + Seq.lemma_head_append (Seq.create 1 d) s2; + Seq.lemma_tail_append (Seq.create 1 d) s2; + assert (Seq.equal (Seq.tail (Seq.cons d s2)) s2) + ) else ( + let hd = Seq.head s1 in + let tl = Seq.tail s1 in + Seq.lemma_head_append s1 (Seq.cons d s2); + Seq.lemma_tail_append s1 (Seq.cons d s2); + assert (key_left compare (fst d) (fst hd)); + sorted_insert_append_right compare tl d s2 kv + ) + +/// inorder(insert_bst t k) == sorted_insert k (inorder t) +let rec insert_bst_inorder (#k: Type) (#v: Type) (compare: cmp k) (t: bst k v compare) (key: k) (val_: v) + : Lemma (ensures Seq.equal (inorder (insert_bst compare t key val_)) (sorted_insert compare (key, val_) (inorder t))) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> + let delta = compare dk key in + is_bst_sorted_inorder compare t; + forall_keys_inorder l (key_left compare dk); + forall_keys_inorder r (key_right compare dk); + if delta >= 0 then ( + insert_bst_inorder compare l key val_; + sorted_insert_append_left compare (inorder l) (dk, dv) (inorder r) (key, val_) + ) else ( + insert_bst_inorder compare r key val_; + sorted_insert_append_right compare (inorder l) (dk, dv) (inorder r) (key, val_) + ) + +/// inorder(insert_avl t k) == sorted_insert k (inorder t) +let rec insert_avl_inorder (#k: Type) (#v: Type) (compare: cmp k) (t: avl k v compare) (key: k) (val_: v) + : Lemma (ensures Seq.equal (inorder (insert_avl compare t key val_)) (sorted_insert compare (key, val_) (inorder t))) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> + let delta = compare dk key in + is_bst_sorted_inorder compare t; + forall_keys_inorder l (key_left compare dk); + forall_keys_inorder r (key_right compare dk); + if delta >= 0 then ( + insert_avl_inorder compare l key val_; + let new_left = insert_avl compare l key val_ in + let tmp = Node dk dv new_left r in + rebalance_inorder tmp; + sorted_insert_append_left compare (inorder l) (dk, dv) (inorder r) (key, val_) + ) else ( + insert_avl_inorder compare r key val_; + let new_right = insert_avl compare r key val_ in + let tmp = Node dk dv l new_right in + rebalance_inorder tmp; + sorted_insert_append_right compare (inorder l) (dk, dv) (inorder r) (key, val_) + ) + +#pop-options + +(** A5: Delete correspondence — inorder(delete_avl t k) == sorted_remove k (inorder t) *) + +#push-options "--fuel 3 --z3rlimit 60" + +/// tree_max is the last element of inorder +let rec tree_max_last_inorder (#k: Type) (#v: Type) (t: tree k v{Node? t}) + : Lemma (ensures Seq.length (inorder t) > 0 /\ + tree_max t == Seq.index (inorder t) (Seq.length (inorder t) - 1)) + (decreases t) + = match t with + | Node dk dv l Leaf -> + Seq.lemma_len_append (inorder l) (Seq.cons (dk, dv) Seq.empty); + Seq.lemma_index_app2 (inorder l) (Seq.cons (dk, dv) Seq.empty) (Seq.length (inorder l) + 1 - 1) + | Node dk dv l r -> + tree_max_last_inorder r; + let ir = inorder r in + Seq.lemma_len_append (inorder l) (Seq.cons (dk, dv) ir); + let full = Seq.append (inorder l) (Seq.cons (dk, dv) ir) in + let full_len = Seq.length full in + Seq.lemma_index_app2 (inorder l) (Seq.cons (dk, dv) ir) (full_len - 1); + let idx_in_cons = full_len - 1 - Seq.length (inorder l) in + Seq.lemma_index_app2 (Seq.create 1 (dk, dv)) ir idx_in_cons + +/// Helper: sorted_remove on a sequence that doesn't contain key is identity +let rec sorted_remove_absent (#k: Type) (#v: Type) (compare: cmp k) (key: k) (s: Seq.seq (k & v)) + : Lemma (requires (forall (i:nat). i < Seq.length s ==> compare (fst (Seq.index s i)) key <> 0)) + (ensures Seq.equal (sorted_remove compare key s) s) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else ( + assert (compare (fst (Seq.index s 0)) key <> 0); + assert (compare (fst (Seq.head s)) key <> 0); + sorted_remove_absent compare key (Seq.tail s) + ) + +/// seq_forall implies pointwise condition +let rec seq_forall_index (#k: Type) (#v: Type) (f: k -> bool) (s: Seq.seq (k & v)) + : Lemma (requires seq_forall f s) + (ensures forall (i:nat). i < Seq.length s ==> f (fst (Seq.index s i))) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else ( + seq_forall_index f (Seq.tail s); + let aux (i:nat{i < Seq.length s}) : Lemma (f (fst (Seq.index s i))) = + if i = 0 then () + else ( + assert (i - 1 < Seq.length (Seq.tail s)); + assert (Seq.index s i == Seq.index (Seq.tail s) (i - 1)) + ) + in + Classical.forall_intro (fun (i:nat{i < Seq.length s}) -> aux i) + ) + +/// Helper: sorted_remove on append when key < all of s1 — passes through s1 to find d +let rec sorted_remove_append_left (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) (key: k) + : Lemma + (requires sorted compare (Seq.append s1 (Seq.cons d s2)) /\ + seq_forall (key_left compare (fst d)) s1 /\ + seq_forall (key_right compare (fst d)) s2 /\ + compare (fst d) key > 0) + (ensures Seq.equal + (sorted_remove compare key (Seq.append s1 (Seq.cons d s2))) + (Seq.append (sorted_remove compare key s1) (Seq.cons d s2))) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 (Seq.cons d s2)) (Seq.cons d s2)); + Seq.lemma_head_append (Seq.create 1 d) s2; + Seq.lemma_tail_append (Seq.create 1 d) s2; + assert (Seq.head (Seq.cons d s2) == d); + assert (compare (fst d) key <> 0); + // key < d <= all of s2, so key not in s2 + seq_forall_index (key_right compare (fst d)) s2; + sorted_remove_absent compare key s2; + assert (Seq.equal (sorted_remove compare key s2) s2); + assert (Seq.equal (Seq.tail (Seq.cons d s2)) s2); + assert (Seq.equal (Seq.append (Seq.empty #(k & v)) (Seq.cons d s2)) (Seq.cons d s2)) + ) else ( + let hd = Seq.head s1 in + let tl = Seq.tail s1 in + Seq.lemma_head_append s1 (Seq.cons d s2); + Seq.lemma_tail_append s1 (Seq.cons d s2); + if compare (fst hd) key = 0 then () + else + sorted_remove_append_left compare tl d s2 key + ) + +/// Helper: sorted_remove on append when key goes right past d +let rec sorted_remove_append_right (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) (key: k) + : Lemma + (requires sorted compare (Seq.append s1 (Seq.cons d s2)) /\ + seq_forall (key_left compare (fst d)) s1 /\ + seq_forall (key_right compare (fst d)) s2 /\ + compare (fst d) key < 0) + (ensures Seq.equal + (sorted_remove compare key (Seq.append s1 (Seq.cons d s2))) + (Seq.append s1 (Seq.cons d (sorted_remove compare key s2)))) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 (Seq.cons d s2)) (Seq.cons d s2)); + Seq.lemma_head_append (Seq.create 1 d) s2; + Seq.lemma_tail_append (Seq.create 1 d) s2; + assert (Seq.equal (Seq.tail (Seq.cons d s2)) s2); + assert (compare (fst d) key <> 0) + ) else ( + let hd = Seq.head s1 in + let tl = Seq.tail s1 in + Seq.lemma_head_append s1 (Seq.cons d s2); + Seq.lemma_tail_append s1 (Seq.cons d s2); + assert (key_left compare (fst d) (fst hd)); + // compare (fst d) (fst hd) >= 0, compare (fst d) key < 0, so compare hd key <= compare hd d <= 0 < compare key (fst d) + // Actually we need compare hd key != 0. Since d >= hd and d < key, we have hd <= d < key, so hd < key + // meaning compare hd key < 0 != 0. + assert (compare (fst d) (fst hd) >= 0); + // By transitivity: hd <= d and d < key means hd < key means compare hd key < 0 + sorted_remove_append_right compare tl d s2 key + ) + +/// Helper: sorted_remove on append when key = d — removes d from middle +/// Requires: no element in s1 compares equal to d (strict BST left property) +let rec sorted_remove_append_mid (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) + : Lemma + (requires sorted compare (Seq.append s1 (Seq.cons d s2)) /\ + seq_forall (key_left compare (fst d)) s1 /\ + (forall (i:nat). i < Seq.length s1 ==> compare (fst (Seq.index s1 i)) (fst d) <> 0)) + (ensures Seq.equal + (sorted_remove compare (fst d) (Seq.append s1 (Seq.cons d s2))) + (Seq.append s1 s2)) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then ( + assert (Seq.equal s1 (Seq.empty #(k & v))); + assert (Seq.equal (Seq.append s1 (Seq.cons d s2)) (Seq.cons d s2)); + Seq.lemma_head_append (Seq.create 1 d) s2; + Seq.lemma_tail_append (Seq.create 1 d) s2; + assert (Seq.equal (Seq.tail (Seq.cons d s2)) s2); + assert (compare (fst d) (fst d) = 0); + assert (Seq.equal (Seq.append s1 s2) s2) + ) else ( + let hd = Seq.head s1 in + let tl = Seq.tail s1 in + Seq.lemma_head_append s1 (Seq.cons d s2); + Seq.lemma_tail_append s1 (Seq.cons d s2); + assert (compare (fst (Seq.index s1 0)) (fst d) <> 0); + assert (compare (fst hd) (fst d) <> 0); + sorted_remove_append_mid compare tl d s2 + ) + +/// No-duplicate tree: each node's data is compare-distinct from all subtree elements +let rec no_dup_tree (#k: Type) (#v: Type) (compare: cmp k) (t: tree k v) : Tot bool (decreases t) = + match t with + | Leaf -> true + | Node dk dv l r -> + forall_keys l (fun x -> compare dk x <> 0) && + forall_keys r (fun x -> compare dk x <> 0) && + no_dup_tree compare l && + no_dup_tree compare r + +/// forall_keys on tree implies condition on tree_max +let rec forall_keys_tree_max (#k: Type) (#v: Type) (t: tree k v) (f: k -> bool) + : Lemma (requires forall_keys t f /\ Node? t) + (ensures f (fst (tree_max t))) + (decreases t) + = match t with + | Node _ _ _ Leaf -> () + | Node _ _ _ r -> forall_keys_tree_max r f + +/// sorted_remove is invariant under compare-equal key substitution +let rec sorted_remove_cmp_eq (#k: Type) (#v: Type) (compare: cmp k) (k1 k2: k) (s: Seq.seq (k & v)) + : Lemma (requires compare k1 k2 = 0) + (ensures Seq.equal (sorted_remove compare k1 s) (sorted_remove compare k2 s)) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else if compare (fst (Seq.head s)) k1 = 0 then () + else sorted_remove_cmp_eq compare k1 k2 (Seq.tail s) + +/// Helper to decompose no_dup_tree at a Node +let no_dup_tree_node (#k: Type) (#v: Type) (compare: cmp k) (dk: k) (dv: v) (l r: tree k v) + : Lemma (requires no_dup_tree compare (Node dk dv l r)) + (ensures forall_keys l (fun x -> compare dk x <> 0) /\ + forall_keys r (fun x -> compare dk x <> 0) /\ + no_dup_tree compare l /\ + no_dup_tree compare r) + = normalize_term_spec (no_dup_tree compare (Node dk dv l r)) + +/// Removing tree_max from inorder and re-appending gives back inorder +let rec remove_max_reappend (#k: Type) (#v: Type) (compare: cmp k) (t: tree k v) + : Lemma (requires is_bst compare t /\ no_dup_tree compare t /\ Node? t) + (ensures (let m = tree_max t in + Seq.equal (Seq.append (sorted_remove compare (fst m) (inorder t)) (Seq.create 1 m)) (inorder t))) + (decreases t) + = match t with + | Node dk dv l Leaf -> + no_dup_tree_node compare dk dv l Leaf; + is_bst_sorted_inorder compare t; + forall_keys_inorder l (key_left compare dk); + forall_keys_inorder l (fun x -> compare dk x <> 0); + seq_forall_index (fun x -> compare dk x <> 0) (inorder l); + sorted_remove_append_mid compare (inorder l) (dk, dv) (Seq.empty #(k & v)); + Seq.append_assoc (inorder l) (Seq.empty #(k & v)) (Seq.create 1 (dk, dv)) + | Node dk dv l r -> + no_dup_tree_node compare dk dv l r; + let m = tree_max r in + remove_max_reappend compare r; + is_bst_sorted_inorder compare t; + forall_keys_inorder l (key_left compare dk); + forall_keys_inorder r (key_right compare dk); + forall_keys_tree_max r (fun x -> compare dk x <> 0); + forall_keys_tree_max r (key_right compare dk); + assert (compare dk (fst m) <> 0 /\ compare dk (fst m) <= 0); + sorted_remove_append_right compare (inorder l) (dk, dv) (inorder r) (fst m); + Seq.append_assoc (inorder l) (Seq.cons (dk, dv) (sorted_remove compare (fst m) (inorder r))) (Seq.create 1 m); + Seq.append_assoc (Seq.create 1 (dk, dv)) (sorted_remove compare (fst m) (inorder r)) (Seq.create 1 m) + +/// inorder(delete_avl t k) == sorted_remove k (inorder t) (requires no-dup BST) +let rec delete_avl_inorder (#k: Type) (#v: Type) (compare: cmp k) (t: tree k v) (key: k) + : Lemma (requires is_bst compare t /\ no_dup_tree compare t) + (ensures Seq.equal (inorder (delete_avl compare t key)) (sorted_remove compare key (inorder t))) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> + no_dup_tree_node compare dk dv l r; + let delta = compare dk key in + is_bst_sorted_inorder compare t; + forall_keys_inorder l (key_left compare dk); + forall_keys_inorder r (key_right compare dk); + if delta = 0 then ( + match l, r with + | Leaf, Leaf -> + Seq.lemma_head_append (Seq.create 1 (dk, dv)) (Seq.empty #(k & v)); + Seq.lemma_tail_append (Seq.create 1 (dk, dv)) (Seq.empty #(k & v)) + | Node _ _ _ _, Leaf -> + forall_keys_inorder l (fun x -> compare dk x <> 0); + seq_forall_index (fun x -> compare dk x <> 0) (inorder l); + sorted_remove_append_mid compare (inorder l) (dk, dv) (Seq.empty #(k & v)); + sorted_remove_cmp_eq compare dk key (Seq.append (inorder l) (Seq.cons (dk, dv) (Seq.empty #(k & v)))) + | Leaf, Node _ _ _ _ -> + Seq.lemma_head_append (Seq.create 1 (dk, dv)) (inorder r); + Seq.lemma_tail_append (Seq.create 1 (dk, dv)) (inorder r) + | _ -> + let m = tree_max l in + // IH: inorder(delete l m) = sorted_remove m (inorder l) + delete_avl_inorder compare l (fst m); + // rebalance preserves inorder + rebalance_inorder (Node (fst m) (snd m) (delete_avl compare l (fst m)) r); + // remove_max_reappend: sorted_remove m (inorder l) ++ [m] = inorder l + remove_max_reappend compare l; + // Use assoc to show: sr_m_l ++ [m] ++ inorder r = inorder l ++ inorder r + Seq.append_assoc (sorted_remove compare (fst m) (inorder l)) (Seq.create 1 m) (inorder r); + // RHS: sorted_remove key (inorder t) = inorder l ++ inorder r + forall_keys_inorder l (fun x -> compare dk x <> 0); + seq_forall_index (fun x -> compare dk x <> 0) (inorder l); + sorted_remove_append_mid compare (inorder l) (dk, dv) (inorder r); + sorted_remove_cmp_eq compare dk key (Seq.append (inorder l) (Seq.cons (dk, dv) (inorder r))) + ) else if delta < 0 then ( + // data < key, recurse on right (fixed) + delete_avl_inorder compare r key; + rebalance_inorder (Node dk dv l (delete_avl compare r key)); + sorted_remove_append_right compare (inorder l) (dk, dv) (inorder r) key + ) else ( + // data > key, recurse on left (fixed) + delete_avl_inorder compare l key; + rebalance_inorder (Node dk dv (delete_avl compare l key) r); + sorted_remove_append_left compare (inorder l) (dk, dv) (inorder r) key + ) + +(** A6: Membership correspondence *) + +// TODO: mem_inorder needs rework for (k & v) pairs +// mem t x checks key membership, but Seq.mem operates on (k & v) pairs + +#pop-options + +(** A7: delete_avl preserves BST *) + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +/// tree_max is maximal — all keys in tree satisfy key_left cmp (fst (tree_max t) +let rec tree_max_is_maximal (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) + : Lemma (requires Node? t /\ is_bst cmp t) + (ensures forall_keys t (key_left cmp (fst (tree_max t)))) + (decreases t) + = match t with + | Node dk dv l Leaf -> () + | Node dk dv l r -> + tree_max_is_maximal cmp r; + forall_keys_tree_max r (key_right cmp dk); + forall_keys_trans l (key_left cmp dk) (key_left cmp (fst (tree_max r))) + +/// rebalance_avl preserves is_bst +let rebalance_bst (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) + : Lemma (requires is_bst cmp t) + (ensures is_bst cmp (rebalance_avl t)) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + if is_balanced t then () + else if height left - height right > 1 then ( + let Node lk lv lleft lright = left in + if height lright > height lleft then + rotate_left_right_bst cmp t + else + rotate_right_bst cmp t + ) else if height left - height right < -1 then ( + let Node rk rv rleft rright = right in + if height rleft > height rright then + rotate_right_left_bst cmp t + else + rotate_left_bst cmp t + ) else () + +/// rebalance_avl preserves forall_keys for key_left +let rebalance_forall_left (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (root: k) + : Lemma (requires is_bst cmp t /\ forall_keys t (key_left cmp root)) + (ensures forall_keys (rebalance_avl t) (key_left cmp root)) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + if is_balanced t then () + else if height left - height right > 1 then ( + let Node lk lv lleft lright = left in + if height lright > height lleft then + rotate_left_right_key_left cmp t root + else + rotate_right_key_left cmp t root + ) else if height left - height right < -1 then ( + let Node rk rv rleft rright = right in + if height rleft > height rright then + rotate_right_left_key_left cmp t root + else + rotate_left_key_left cmp t root + ) else () + +/// rebalance_avl preserves forall_keys for key_right +let rebalance_forall_right (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (root: k) + : Lemma (requires is_bst cmp t /\ forall_keys t (key_right cmp root)) + (ensures forall_keys (rebalance_avl t) (key_right cmp root)) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + if is_balanced t then () + else if height left - height right > 1 then ( + let Node lk lv lleft lright = left in + if height lright > height lleft then + rotate_left_right_key_right cmp t root + else + rotate_right_key_right cmp t root + ) else if height left - height right < -1 then ( + let Node rk rv rleft rright = right in + if height rleft > height rright then + rotate_right_left_key_right cmp t root + else + rotate_left_key_right cmp t root + ) else () + +#pop-options + +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" + +let rec delete_avl_proof_aux (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (key: k) (root: k) + : Lemma (requires is_bst cmp t) + (ensures ( + let res = delete_avl cmp t key in + is_bst cmp res /\ + (forall_keys t (key_left cmp root) ==> forall_keys res (key_left cmp root)) /\ + (forall_keys t (key_right cmp root) ==> forall_keys res (key_right cmp root)) + )) + (decreases t) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + if delta = 0 then ( + match left, right with + | Leaf, Leaf -> () + | Node _ _ _ _, Leaf -> () + | Leaf, Node _ _ _ _ -> () + | _, _ -> + let m = tree_max left in + delete_avl_proof_aux cmp left (fst m) nd_key; + delete_avl_proof_aux cmp left (fst m) (fst m); + delete_avl_proof_aux cmp left (fst m) root; + let new_left = delete_avl cmp left (fst m) in + let tmp = Node (fst m) (snd m) new_left right in + tree_max_is_maximal cmp left; + forall_keys_tree_max left (key_left cmp nd_key); + forall_keys_trans right (key_right cmp nd_key) (key_right cmp (fst m)); + rebalance_bst cmp tmp; + let aux_left () : Lemma (forall_keys t (key_left cmp root) ==> forall_keys (rebalance_avl tmp) (key_left cmp root)) = + if forall_keys t (key_left cmp root) then ( + forall_keys_tree_max left (key_left cmp root); + assert (forall_keys tmp (key_left cmp root)); + rebalance_forall_left cmp tmp root + ) + in + aux_left (); + let aux_right () : Lemma (forall_keys t (key_right cmp root) ==> forall_keys (rebalance_avl tmp) (key_right cmp root)) = + if forall_keys t (key_right cmp root) then ( + forall_keys_tree_max left (key_right cmp root); + assert (forall_keys tmp (key_right cmp root)); + rebalance_forall_right cmp tmp root + ) + in + aux_right () + ) else if delta < 0 then ( + delete_avl_proof_aux cmp right key nd_key; + delete_avl_proof_aux cmp right key root; + let new_right = delete_avl cmp right key in + let tmp = Node nd_key nd_val left new_right in + rebalance_bst cmp tmp; + let aux_left () : Lemma (forall_keys t (key_left cmp root) ==> forall_keys (rebalance_avl tmp) (key_left cmp root)) = + if forall_keys t (key_left cmp root) then ( + assert (forall_keys tmp (key_left cmp root)); + rebalance_forall_left cmp tmp root + ) + in + aux_left (); + let aux_right () : Lemma (forall_keys t (key_right cmp root) ==> forall_keys (rebalance_avl tmp) (key_right cmp root)) = + if forall_keys t (key_right cmp root) then ( + assert (forall_keys tmp (key_right cmp root)); + rebalance_forall_right cmp tmp root + ) + in + aux_right () + ) else ( + delete_avl_proof_aux cmp left key nd_key; + delete_avl_proof_aux cmp left key root; + let new_left = delete_avl cmp left key in + let tmp = Node nd_key nd_val new_left right in + rebalance_bst cmp tmp; + let aux_left () : Lemma (forall_keys t (key_left cmp root) ==> forall_keys (rebalance_avl tmp) (key_left cmp root)) = + if forall_keys t (key_left cmp root) then ( + assert (forall_keys tmp (key_left cmp root)); + rebalance_forall_left cmp tmp root + ) + in + aux_left (); + let aux_right () : Lemma (forall_keys t (key_right cmp root) ==> forall_keys (rebalance_avl tmp) (key_right cmp root)) = + if forall_keys t (key_right cmp root) then ( + assert (forall_keys tmp (key_right cmp root)); + rebalance_forall_right cmp tmp root + ) + in + aux_right () + ) + +/// delete_avl preserves is_bst +let delete_avl_preserves_bst (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (key: k) + : Lemma (requires is_bst cmp t) + (ensures is_bst cmp (delete_avl cmp t key)) + = match t with + | Leaf -> () + | Node nd_key _ _ _ -> delete_avl_proof_aux cmp t key nd_key +#pop-options + +/// rotate_left preserves forall_keys for any predicate +#push-options "--fuel 2 --ifuel 2" +let rotate_left_forall_keys (#k: Type) (#v: Type) (r: tree k v) (f: k -> bool) + : Lemma (requires forall_keys r f /\ Some? (rotate_left r)) + (ensures forall_keys (Some?.v (rotate_left r)) f) + = () +#pop-options + +/// rotate_right preserves forall_keys for any predicate +#push-options "--fuel 2 --ifuel 2" +let rotate_right_forall_keys (#k: Type) (#v: Type) (r: tree k v) (f: k -> bool) + : Lemma (requires forall_keys r f /\ Some? (rotate_right r)) + (ensures forall_keys (Some?.v (rotate_right r)) f) + = () +#pop-options + +/// rotate_left_right preserves forall_keys for any predicate +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" +let rotate_left_right_forall_keys (#k: Type) (#v: Type) (r: tree k v) (f: k -> bool) + : Lemma (requires forall_keys r f /\ Some? (rotate_left_right r)) + (ensures forall_keys (Some?.v (rotate_left_right r)) f) + = match r with + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> + // rotate_left_right: Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4) + normalize_term_spec (forall_keys r f); + normalize_term_spec (forall_keys (Some?.v (rotate_left_right r)) f) + | _ -> () +#pop-options + +/// rotate_right_left preserves forall_keys for any predicate +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" +let rotate_right_left_forall_keys (#k: Type) (#v: Type) (r: tree k v) (f: k -> bool) + : Lemma (requires forall_keys r f /\ Some? (rotate_right_left r)) + (ensures forall_keys (Some?.v (rotate_right_left r)) f) + = match r with + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> + normalize_term_spec (forall_keys r f); + normalize_term_spec (forall_keys (Some?.v (rotate_right_left r)) f) + | _ -> () +#pop-options + +/// rebalance_avl preserves forall_keys +#push-options "--z3rlimit 50" +let rebalance_preserves_forall_keys (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (f: k -> bool) + : Lemma (requires forall_keys t f) + (ensures forall_keys (rebalance_avl t) f) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + if is_balanced t then () + else if height left - height right > 1 then ( + let Node lk lv lleft lright = left in + if height lright > height lleft then + rotate_left_right_forall_keys t f + else + rotate_right_forall_keys t f + ) else if height right - height left > 1 then ( + let Node rk rv rleft rright = right in + if height rleft > height rright then + rotate_right_left_forall_keys t f + else + rotate_left_forall_keys t f + ) else () +#pop-options + +/// rotate_left preserves no_dup_tree (requires is_bst) +#push-options "--fuel 2 --ifuel 2 --z3rlimit 10" +let rotate_left_no_dup (#k: Type) (#v: Type) (cmp: cmp k) (r: tree k v) + : Lemma (requires is_bst cmp r /\ no_dup_tree cmp r /\ Some? (rotate_left r)) + (ensures no_dup_tree cmp (Some?.v (rotate_left r))) + = match r with + | Node xk xv t1 (Node zk zv t2 t3) -> + // Result: Node zk zv (Node xk xv t1 t2) t3 + // Unpack original no_dup_tree + no_dup_tree_node cmp xk xv t1 (Node zk zv t2 t3); + no_dup_tree_node cmp zk zv t2 t3; + + // From BST: t1 < x < z (transitivity) + // All keys in t1 satisfy (k <= x) and (k >= x), but x < z, so k < z + forall_keys_trans t1 (key_left cmp xk) (fun kk -> cmp zk kk <> 0); + + // Explicitly normalize both subtrees and result + normalize_term_spec (no_dup_tree cmp (Node xk xv t1 t2)); + normalize_term_spec (no_dup_tree cmp (Node zk zv (Node xk xv t1 t2) t3)) +#pop-options + +/// rotate_right preserves no_dup_tree (requires is_bst) +#push-options "--fuel 2 --ifuel 2 --z3rlimit 10" +let rotate_right_no_dup (#k: Type) (#v: Type) (cmp: cmp k) (r: tree k v) + : Lemma (requires is_bst cmp r /\ no_dup_tree cmp r /\ Some? (rotate_right r)) + (ensures no_dup_tree cmp (Some?.v (rotate_right r))) + = match r with + | Node xk xv (Node zk zv t1 t2) t3 -> + // Result: Node zk zv t1 (Node xk xv t2 t3) + // Unpack original no_dup_tree + no_dup_tree_node cmp xk xv (Node zk zv t1 t2) t3; + no_dup_tree_node cmp zk zv t1 t2; + + // From BST: z < x < t3 (transitivity) + // All keys in t3 satisfy (k >= x) and z < x, so z < k + forall_keys_trans t3 (key_right cmp xk) (fun kk -> cmp zk kk <> 0); + + // Explicitly normalize both subtrees and result + normalize_term_spec (no_dup_tree cmp (Node xk xv t2 t3)); + normalize_term_spec (no_dup_tree cmp (Node zk zv t1 (Node xk xv t2 t3))) +#pop-options + +/// rotate_left_right preserves no_dup_tree (requires is_bst) +#push-options "--fuel 2 --ifuel 2 --z3rlimit 10" +let rotate_left_right_no_dup (#k: Type) (#v: Type) (cmp: cmp k) (r: tree k v) + : Lemma (requires is_bst cmp r /\ no_dup_tree cmp r /\ Some? (rotate_left_right r)) + (ensures no_dup_tree cmp (Some?.v (rotate_left_right r))) + = match r with + | Node xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4 -> + // Result: Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4) + // Unpack original no_dup_tree + no_dup_tree_node cmp xk xv (Node zk zv t1 (Node yk yv t2 t3)) t4; + no_dup_tree_node cmp zk zv t1 (Node yk yv t2 t3); + no_dup_tree_node cmp yk yv t2 t3; + + // From BST: t1 < z < y < x < t4 (chain of transitivity) + forall_keys_trans t1 (key_left cmp zk) (fun kk -> cmp yk kk <> 0); + forall_keys_trans t4 (key_right cmp xk) (fun kk -> cmp yk kk <> 0); + + // Show that y is distinct from z + // forall_keys (Node yk yv t2 t3) (fun kk -> cmp zk kk <> 0) from original + // This includes y itself, so cmp zk yk <> 0, hence cmp yk zk <> 0 + + // Explicitly normalize and assert needed facts + assert (forall_keys t1 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t2 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t3 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t4 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t1 (fun kk -> cmp zk kk <> 0)); + assert (forall_keys t2 (fun kk -> cmp zk kk <> 0)); + assert (forall_keys t3 (fun kk -> cmp xk kk <> 0)); + assert (forall_keys t4 (fun kk -> cmp xk kk <> 0)); + normalize_term_spec (no_dup_tree cmp (Node zk zv t1 t2)); + normalize_term_spec (no_dup_tree cmp (Node xk xv t3 t4)); + normalize_term_spec (no_dup_tree cmp (Node yk yv (Node zk zv t1 t2) (Node xk xv t3 t4))) +#pop-options + +/// rotate_right_left preserves no_dup_tree (requires is_bst) +#push-options "--fuel 2 --ifuel 2 --z3rlimit 10" +let rotate_right_left_no_dup (#k: Type) (#v: Type) (cmp: cmp k) (r: tree k v) + : Lemma (requires is_bst cmp r /\ no_dup_tree cmp r /\ Some? (rotate_right_left r)) + (ensures no_dup_tree cmp (Some?.v (rotate_right_left r))) + = match r with + | Node xk xv t1 (Node zk zv (Node yk yv t2 t3) t4) -> + // Result: Node yk yv (Node xk xv t1 t2) (Node zk zv t3 t4) + // Unpack original no_dup_tree + no_dup_tree_node cmp xk xv t1 (Node zk zv (Node yk yv t2 t3) t4); + no_dup_tree_node cmp zk zv (Node yk yv t2 t3) t4; + no_dup_tree_node cmp yk yv t2 t3; + + // From BST: t1 < x < y < z < t4 (chain of transitivity) + forall_keys_trans t1 (key_left cmp xk) (fun kk -> cmp yk kk <> 0); + forall_keys_trans t4 (key_right cmp zk) (fun kk -> cmp yk kk <> 0); + + // Explicitly assert needed facts + assert (forall_keys t1 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t2 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t3 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t4 (fun kk -> cmp yk kk <> 0)); + assert (forall_keys t1 (fun kk -> cmp xk kk <> 0)); + assert (forall_keys t2 (fun kk -> cmp xk kk <> 0)); + assert (forall_keys t3 (fun kk -> cmp zk kk <> 0)); + assert (forall_keys t4 (fun kk -> cmp zk kk <> 0)); + normalize_term_spec (no_dup_tree cmp (Node xk xv t1 t2)); + normalize_term_spec (no_dup_tree cmp (Node zk zv t3 t4)); + normalize_term_spec (no_dup_tree cmp (Node yk yv (Node xk xv t1 t2) (Node zk zv t3 t4))) +#pop-options + +/// rebalance_avl preserves no_dup_tree (requires is_bst) +#push-options "--fuel 2 --ifuel 2 --z3rlimit 50" +let rebalance_preserves_no_dup (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) + : Lemma (requires is_bst cmp t /\ no_dup_tree cmp t) + (ensures no_dup_tree cmp (rebalance_avl t)) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + if is_balanced t then () + else if height left - height right > 1 then ( + let Node lk lv lleft lright = left in + if height lright > height lleft then + rotate_left_right_no_dup cmp t + else + rotate_right_no_dup cmp t + ) else if height right - height left > 1 then ( + let Node rk rv rleft rright = right in + if height rleft > height rright then + rotate_right_left_no_dup cmp t + else + rotate_left_no_dup cmp t + ) else () +#pop-options + +/// BST ordering implies distinctness: if m_key < data_key < all keys in r, then all keys in r are != m_key +#push-options "--fuel 2 --ifuel 2 --z3rlimit 250" +let rec bst_left_right_distinct (#k: Type) (#v: Type) (cmp: cmp k) (m_key data_key: k) (r: tree k v) + : Lemma (requires is_bst cmp r /\ + forall_keys r (key_right cmp data_key) /\ + key_left cmp data_key m_key /\ + cmp data_key m_key <> 0 /\ + forall_keys r (fun kk -> cmp data_key kk <> 0)) + (ensures forall_keys r (fun kk -> cmp m_key kk <> 0)) + (decreases r) + = match r with + | Leaf -> () + | Node nd_key nd_val l r_right -> + bst_left_right_distinct cmp m_key data_key l; + bst_left_right_distinct cmp m_key data_key r_right +#pop-options + +/// Extensionality for forall_keys: two functions that agree pointwise give the same result +#push-options "--fuel 2 --ifuel 1 --z3rlimit 20" +let rec forall_keys_ext (#k: Type) (#v: Type) (t: tree k v) (f g: k -> bool) + : Lemma (requires forall_keys t f /\ (forall (x:k). f x == g x)) + (ensures forall_keys t g) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> forall_keys_ext l f g; forall_keys_ext r f g + +/// Assemble no_dup_tree from components, bridging lambda representations via forall_keys_ext +let no_dup_tree_assemble (#k: Type) (#v: Type) (cmp: cmp k) (dk: k) (dv: v) (l r: tree k v) (f: k -> bool) + : Lemma (requires forall_keys l f /\ forall_keys r f /\ + no_dup_tree cmp l /\ no_dup_tree cmp r /\ + (forall (x: k). f x == (cmp dk x <> 0))) + (ensures no_dup_tree cmp (Node dk dv l r)) + = let nd_f : (k -> bool) = fun kk -> cmp dk kk <> 0 in + forall_keys_ext l f nd_f; + forall_keys_ext r f nd_f; + normalize_term_spec (no_dup_tree cmp (Node dk dv l r)) +#pop-options + +/// Helper: forall_keys is preserved through delete_avl +#push-options "--fuel 3 --ifuel 2 --z3rlimit 100" +let rebalance_forall_keys_f (#k: Type) (#v: Type) (t: tree k v) (f: k -> bool) + : Lemma (requires forall_keys t f) + (ensures forall_keys (rebalance_avl t) f) = () + +let rec forall_keys_delete_avl (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (key: k) (f: k -> bool) + : Lemma (requires forall_keys t f) + (ensures forall_keys (delete_avl cmp t key) f) + (decreases t) + = match t with + | Leaf -> () + | Node dk dv l r -> + if cmp dk key = 0 then ( + match l with + | Leaf -> () + | _ -> forall_keys_tree_max l f; + forall_keys_delete_avl cmp l (fst (tree_max l)) f; + rebalance_forall_keys_f (Node (fst (tree_max l)) (snd (tree_max l)) (delete_avl cmp l (fst (tree_max l))) r) f + ) else if cmp dk key < 0 then ( + forall_keys_delete_avl cmp r key f; + rebalance_forall_keys_f (Node dk dv l (delete_avl cmp r key)) f + ) else ( + forall_keys_delete_avl cmp l key f; + rebalance_forall_keys_f (Node dk dv (delete_avl cmp l key) r) f + ) +#pop-options + +/// Transitivity helper for cmp neq +#push-options "--fuel 1 --ifuel 1 --z3rlimit 40" +let neq_transitive (#k: Type) (cmp: cmp k) (m_key d_key k_key: k) + : Lemma (requires cmp d_key k_key >= 0 /\ cmp d_key k_key <> 0 /\ cmp m_key d_key > 0) + (ensures cmp m_key k_key <> 0) = () +#pop-options + +/// forall_keys neq via transitivity through BST ordering +#push-options "--fuel 2 --ifuel 1 --z3rlimit 60" +let rec forall_keys_neq_via_trans (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (d_key m_key: k) + : Lemma (requires forall_keys t (key_left cmp d_key) /\ + forall_keys t (fun kk -> cmp d_key kk <> 0) /\ + cmp m_key d_key > 0) + (ensures forall_keys t (fun kk -> cmp m_key kk <> 0)) + (decreases t) + = match t with + | Leaf -> () + | Node xk xv l r -> + neq_transitive cmp m_key d_key xk; + forall_keys_neq_via_trans cmp l d_key m_key; + forall_keys_neq_via_trans cmp r d_key m_key +#pop-options + +/// BST+no_dup (Node dk dv l r) with Node? r => forall_keys l (fun k -> cmp (tree_max r) k <> 0) +#push-options "--fuel 2 --ifuel 1 --z3rlimit 100" +let left_max_right_distinct (#k: Type) (#v: Type) (cmp: cmp k) (dk: k) (dv: v) (l r: tree k v) + : Lemma (requires is_bst cmp (Node dk dv l r) /\ no_dup_tree cmp (Node dk dv l r) /\ Node? r) + (ensures (let m = tree_max r in + forall_keys l (fun kk -> cmp (fst m) kk <> 0))) + = let m = tree_max r in + no_dup_tree_node cmp dk dv l r; + forall_keys_tree_max r (key_right cmp dk); + forall_keys_tree_max r (fun kk -> cmp dk kk <> 0); + forall_keys_neq_via_trans cmp l dk (fst m); + let f_local : (k -> bool) = fun kk -> cmp (fst m) kk <> 0 in + let f_ensures : (k -> bool) = fun kk -> cmp (fst (tree_max r)) kk <> 0 in + forall_keys_ext l f_local f_ensures +#pop-options + +/// After deleting tree_max, all remaining keys are distinct from it +/// Uses forall_keys_ext to bridge lambda representations across recursive calls +#push-options "--fuel 3 --ifuel 2 --z3rlimit 250" +let rec delete_max_neq (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (m_key: k) + : Lemma (requires Node? t /\ is_bst cmp t /\ no_dup_tree cmp t /\ m_key == fst (tree_max t)) + (ensures forall_keys (delete_avl cmp t m_key) (fun kk -> cmp m_key kk <> 0)) + (decreases t) + = let Node dk dv l r = t in + let f : (k -> bool) = fun kk -> cmp m_key kk <> 0 in + no_dup_tree_node cmp dk dv l r; + if Leaf? r then ( + let f_d : (k -> bool) = fun kk -> cmp dk kk <> 0 in + match l with + | Leaf -> () + | _ -> + forall_keys_ext l f_d f; + forall_keys_delete_avl cmp l (fst (tree_max l)) f; + forall_keys_tree_max l f; + rebalance_forall_keys_f (Node (fst (tree_max l)) (snd (tree_max l)) (delete_avl cmp l (fst (tree_max l))) Leaf) f + ) else ( + forall_keys_tree_max r (fun kk -> cmp dk kk <> 0); + forall_keys_tree_max r (key_right cmp dk); + delete_max_neq cmp r m_key; + forall_keys_neq_via_trans cmp l dk m_key; + let g : (k -> bool) = fun kk -> cmp m_key kk <> 0 in + forall_keys_ext l g f; + normalize_term_spec (forall_keys (Node dk dv l (delete_avl cmp r m_key)) f); + rebalance_forall_keys_f (Node dk dv l (delete_avl cmp r m_key)) f + ) + +let delete_max_keys_distinct (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) + : Lemma (requires is_bst cmp t /\ no_dup_tree cmp t /\ Node? t) + (ensures (let m = tree_max t in + forall_keys (delete_avl cmp t (fst m)) (fun kk -> cmp (fst m) kk <> 0))) + = delete_max_neq cmp t (fst (tree_max t)) +#pop-options + +/// delete_avl preserves no_dup_tree (with forall_keys tracking for the induction) +#push-options "--z3rlimit 1000 --fuel 2 --ifuel 2" +let rec delete_avl_no_dup_aux (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (key: k) (root: k) + : Lemma (requires is_bst cmp t /\ no_dup_tree cmp t) + (ensures ( + let res = delete_avl cmp t key in + no_dup_tree cmp res /\ + (forall_keys t (fun kk -> cmp root kk <> 0) ==> forall_keys res (fun kk -> cmp root kk <> 0)) + )) + (decreases t) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + no_dup_tree_node cmp nd_key nd_val left right; + if delta = 0 then ( + match left, right with + | Leaf, Leaf -> () + | Node _ _ _ _, Leaf -> () + | Leaf, Node _ _ _ _ -> () + | _, _ -> + let m = tree_max left in + delete_avl_no_dup_aux cmp left (fst m) nd_key; + delete_avl_no_dup_aux cmp left (fst m) (fst m); + delete_avl_no_dup_aux cmp left (fst m) root; + let new_left = delete_avl cmp left (fst m) in + let tmp = Node (fst m) (snd m) new_left right in + + delete_avl_proof_aux cmp left (fst m) nd_key; + delete_avl_proof_aux cmp left (fst m) (fst m); + tree_max_is_maximal cmp left; + forall_keys_tree_max left (key_left cmp nd_key); + forall_keys_trans right (key_right cmp nd_key) (key_right cmp (fst m)); + forall_keys_tree_max left (fun kk -> cmp nd_key kk <> 0); + + // no_dup for tmp: use forall_keys_ext to bridge lambda representations + delete_max_keys_distinct cmp left; + let f_neq : (k -> bool) = fun kk -> cmp (fst m) kk <> 0 in + let f_src : (k -> bool) = fun kk -> cmp (fst (tree_max left)) kk <> 0 in + forall_keys_ext new_left f_src f_neq; + bst_left_right_distinct cmp (fst m) nd_key right; + let f_bst : (k -> bool) = fun kk -> cmp (fst m) kk <> 0 in + forall_keys_ext right f_bst f_neq; + no_dup_tree_assemble cmp (fst m) (snd m) new_left right f_neq; + + normalize_term_spec (is_bst cmp tmp); + rebalance_bst cmp tmp; + rebalance_preserves_no_dup cmp tmp; + if forall_keys (Node nd_key nd_val left right) (fun kk -> cmp root kk <> 0) then ( + forall_keys_tree_max left (fun kk -> cmp root kk <> 0); + rebalance_preserves_forall_keys cmp tmp (fun kk -> cmp root kk <> 0) + ) + ) else if delta < 0 then ( + delete_avl_no_dup_aux cmp right key nd_key; + delete_avl_no_dup_aux cmp right key root; + let new_right = delete_avl cmp right key in + let tmp = Node nd_key nd_val left new_right in + delete_avl_proof_aux cmp right key nd_key; + let f_neq2 : (k -> bool) = fun kk -> cmp nd_key kk <> 0 in + forall_keys_delete_avl cmp right key f_neq2; + no_dup_tree_assemble cmp nd_key nd_val left new_right f_neq2; + normalize_term_spec (is_bst cmp tmp); + rebalance_bst cmp tmp; + rebalance_preserves_no_dup cmp tmp; + if forall_keys (Node nd_key nd_val left right) (fun kk -> cmp root kk <> 0) then + rebalance_preserves_forall_keys cmp tmp (fun kk -> cmp root kk <> 0) + ) else ( + delete_avl_no_dup_aux cmp left key nd_key; + delete_avl_no_dup_aux cmp left key root; + let new_left = delete_avl cmp left key in + let tmp = Node nd_key nd_val new_left right in + delete_avl_proof_aux cmp left key nd_key; + let f_neq3 : (k -> bool) = fun kk -> cmp nd_key kk <> 0 in + forall_keys_delete_avl cmp left key f_neq3; + no_dup_tree_assemble cmp nd_key nd_val new_left right f_neq3; + normalize_term_spec (is_bst cmp tmp); + rebalance_bst cmp tmp; + rebalance_preserves_no_dup cmp tmp; + if forall_keys (Node nd_key nd_val left right) (fun kk -> cmp root kk <> 0) then + rebalance_preserves_forall_keys cmp tmp (fun kk -> cmp root kk <> 0) + ) + +/// delete_avl preserves no_dup_tree +let delete_avl_preserves_no_dup_tree (#k: Type) (#v: Type) (cmp: cmp k) (t: tree k v) (key: k) + : Lemma (requires is_bst cmp t /\ no_dup_tree cmp t) + (ensures no_dup_tree cmp (delete_avl cmp t key)) + = match t with + | Leaf -> () + | Node nd_key _ _ _ -> delete_avl_no_dup_aux cmp t key nd_key + +/// Helper: insert_avl preserves forall_keys +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" +let rec insert_avl_preserves_forall_keys + (#k: Type) (#v: Type) + (cmp: cmp k) + (t: tree k v) + (key: k) + (val_: v) + (cond: k -> bool) + : Lemma + (requires forall_keys t cond /\ cond key) + (ensures forall_keys (insert_avl cmp t key val_) cond) + (decreases t) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + if delta >= 0 then ( + insert_avl_preserves_forall_keys cmp left key val_ cond; + rebalance_preserves_forall_keys cmp (Node nd_key nd_val (insert_avl cmp left key val_) right) cond + ) else ( + insert_avl_preserves_forall_keys cmp right key val_ cond; + rebalance_preserves_forall_keys cmp (Node nd_key nd_val left (insert_avl cmp right key val_)) cond + ) +#pop-options + +/// Helper: insert_avl preserves is_bst +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" +let rec insert_avl_preserves_bst + (#k: Type) (#v: Type) + (cmp: cmp k) + (t: tree k v) + (key: k) + (val_: v) + : Lemma + (requires is_bst cmp t) + (ensures is_bst cmp (insert_avl cmp t key val_)) + (decreases t) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + if delta >= 0 then ( + insert_avl_preserves_forall_keys cmp left key val_ (key_left cmp nd_key); + insert_avl_preserves_bst cmp left key val_; + normalize_term_spec (is_bst cmp (Node nd_key nd_val (insert_avl cmp left key val_) right)); + rebalance_bst cmp (Node nd_key nd_val (insert_avl cmp left key val_) right) + ) else ( + insert_avl_preserves_forall_keys cmp right key val_ (key_right cmp nd_key); + insert_avl_preserves_bst cmp right key val_; + normalize_term_spec (is_bst cmp (Node nd_key nd_val left (insert_avl cmp right key val_))); + rebalance_bst cmp (Node nd_key nd_val left (insert_avl cmp right key val_)) + ) +#pop-options +/// insert_avl preserves no_dup_tree +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" +let rec insert_avl_preserves_no_dup_tree + (#k: Type) (#v: Type) + (cmp: cmp k) + (t: tree k v) + (key: k) + (val_: v) + : Lemma + (requires is_bst cmp t /\ no_dup_tree cmp t /\ + forall_keys t (fun kk -> cmp key kk <> 0)) + (ensures no_dup_tree cmp (insert_avl cmp t key val_)) + (decreases t) + = match t with + | Leaf -> () + | Node nd_key nd_val left right -> + let delta = cmp nd_key key in + let f : (k -> bool) = fun kk -> cmp nd_key kk <> 0 in + no_dup_tree_node cmp nd_key nd_val left right; + forall_keys_ext left (fun kk -> cmp nd_key kk <> 0) f; + forall_keys_ext right (fun kk -> cmp nd_key kk <> 0) f; + + if delta >= 0 then ( + insert_avl_preserves_no_dup_tree cmp left key val_; + insert_avl_preserves_bst cmp left key val_; + insert_avl_preserves_forall_keys cmp left key val_ f; + insert_avl_preserves_forall_keys cmp left key val_ (key_left cmp nd_key); + let tmp = Node nd_key nd_val (insert_avl cmp left key val_) right in + normalize_term_spec (is_bst cmp tmp); + no_dup_tree_assemble cmp nd_key nd_val (insert_avl cmp left key val_) right f; + rebalance_bst cmp tmp; + rebalance_preserves_no_dup cmp tmp + ) else ( + insert_avl_preserves_no_dup_tree cmp right key val_; + insert_avl_preserves_bst cmp right key val_; + insert_avl_preserves_forall_keys cmp right key val_ f; + insert_avl_preserves_forall_keys cmp right key val_ (key_right cmp nd_key); + let tmp = Node nd_key nd_val left (insert_avl cmp right key val_) in + normalize_term_spec (is_bst cmp tmp); + no_dup_tree_assemble cmp nd_key nd_val left (insert_avl cmp right key val_) f; + rebalance_bst cmp tmp; + rebalance_preserves_no_dup cmp tmp + ) +#pop-options +#pop-options + +(** Strictly sorted sequences and BST no_dup_tree bridge *) + +/// Strictly sorted: each consecutive pair has compare < 0 (not <=) +let rec sorted_strict (#k: Type) (#v: Type) (compare: cmp k) (s: Seq.seq (k & v)) + : Tot bool (decreases Seq.length s) = + if Seq.length s <= 1 then true + else compare (fst (Seq.head s)) (fst (Seq.index s 1)) < 0 && sorted_strict compare (Seq.tail s) + +/// sorted_strict implies sorted +let rec sorted_strict_implies_sorted (#k: Type) (#v: Type) (compare: cmp k) (s: Seq.seq (k & v)) + : Lemma (requires sorted_strict compare s) + (ensures sorted compare s) + (decreases Seq.length s) = + if Seq.length s <= 1 then () + else sorted_strict_implies_sorted compare (Seq.tail s) + +#push-options "--fuel 2 --ifuel 1 --z3rlimit 40" + +/// In a strictly sorted sequence, all pairs are distinct under cmp +let rec sorted_strict_forall_neq (#k: Type) (#v: Type) (compare: cmp k) (s: Seq.seq (k & v)) (i j: nat) + : Lemma (requires sorted_strict compare s /\ i < j /\ j < Seq.length s) + (ensures compare (fst (Seq.index s i)) (fst (Seq.index s j)) < 0) + (decreases Seq.length s) = + if i = 0 then ( + if j = 1 then () + else sorted_strict_forall_neq compare (Seq.tail s) 0 (j - 1) + ) else + sorted_strict_forall_neq compare (Seq.tail s) (i - 1) (j - 1) + +#pop-options + +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" + +/// Helper: In a strictly sorted sequence, no element equals a given element at a different position +let rec sorted_strict_neq_all (#k: Type) (#v: Type) (compare: cmp k) (s: Seq.seq (k & v)) (d: (k & v)) (d_pos: nat) + : Lemma + (requires sorted_strict compare s /\ d_pos < Seq.length s /\ Seq.index s d_pos == d) + (ensures (forall (i:nat{i < Seq.length s /\ i <> d_pos}). compare (fst d) (fst (Seq.index s i)) <> 0)) + (decreases Seq.length s) = + let aux (i:nat{i < Seq.length s /\ i <> d_pos}) + : Lemma (compare (fst d) (fst (Seq.index s i)) <> 0) = + if i < d_pos then sorted_strict_forall_neq compare s i d_pos + else sorted_strict_forall_neq compare s d_pos i + in + Classical.forall_intro aux + +/// sorted_strict on tail +let sorted_strict_tail (#k: Type) (#v: Type) (compare: cmp k) (s: Seq.seq (k & v)) + : Lemma (requires sorted_strict compare s /\ Seq.length s > 0) + (ensures sorted_strict compare (Seq.tail s)) = () + +/// sorted_strict on the append components +let rec sorted_strict_append_left (#k: Type) (#v: Type) (compare: cmp k) (s1 s2: Seq.seq (k & v)) + : Lemma (requires sorted_strict compare (Seq.append s1 s2)) + (ensures sorted_strict compare s1) + (decreases Seq.length s1) = + if Seq.length s1 <= 1 then () + else ( + Seq.lemma_head_append s1 s2; + Seq.lemma_tail_append s1 s2; + assert (Seq.index (Seq.append s1 s2) 0 == Seq.index s1 0); + assert (Seq.index (Seq.append s1 s2) 1 == Seq.index s1 1); + sorted_strict_append_left compare (Seq.tail s1) s2 + ) + +let rec sorted_strict_append_right (#k: Type) (#v: Type) (compare: cmp k) (s1 s2: Seq.seq (k & v)) + : Lemma (requires sorted_strict compare (Seq.append s1 s2)) + (ensures sorted_strict compare s2) + (decreases Seq.length s1) = + if Seq.length s1 = 0 then + assert (Seq.equal (Seq.append s1 s2) s2) + else ( + Seq.lemma_tail_append s1 s2; + sorted_strict_append_right compare (Seq.tail s1) s2 + ) + +/// seq_forall on append components (reverse of seq_forall_append) +let rec seq_forall_append_inv (#k: Type) (#v: Type) (f: k -> bool) (s1 s2: Seq.seq (k & v)) + : Lemma (requires seq_forall f (Seq.append s1 s2)) + (ensures seq_forall f s1 /\ seq_forall f s2) + (decreases Seq.length s1) = + if Seq.length s1 = 0 then + assert (Seq.equal (Seq.append s1 s2) s2) + else ( + Seq.lemma_head_append s1 s2; + Seq.lemma_tail_append s1 s2; + seq_forall_append_inv f (Seq.tail s1) s2 + ) + +/// Reverse bridge: seq_forall on inorder implies forall_keys on tree +let rec inorder_forall_keys (#k: Type) (#v: Type) (t: tree k v) (cond: k -> bool) + : Lemma (requires seq_forall cond (inorder t)) + (ensures forall_keys t cond) + (decreases t) = + match t with + | Leaf -> () + | Node dk dv l r -> + // inorder t == append (inorder l) (cons d (inorder r)) + seq_forall_append_inv cond (inorder l) (Seq.cons (dk, dv) (inorder r)); + // Now: seq_forall cond (inorder l) /\ seq_forall cond (cons d (inorder r)) + // From cons: head is d so cond d, tail is inorder r + let dr = Seq.cons (dk, dv) (inorder r) in + assert (Seq.head dr == (dk, dv)); + assert (Seq.equal (Seq.tail dr) (inorder r)); + inorder_forall_keys l cond; + inorder_forall_keys r cond + +#pop-options + +#push-options "--fuel 2 --ifuel 2 --z3rlimit 100" + +/// All elements left of d in a sorted_strict sequence have compare d k <> 0. +/// Since sorted_strict_forall_neq gives compare k d < 0 for k before d, +/// the cmp axiom gives compare d k > 0, hence <> 0. +let rec sorted_strict_left_neq (#k: Type) (#v: Type) (compare: cmp k) (s1: Seq.seq (k & v)) (d: (k & v)) (s2: Seq.seq (k & v)) + : Lemma + (requires sorted_strict compare (Seq.append s1 (Seq.cons d s2))) + (ensures seq_forall (fun x -> compare (fst d) x <> 0) s1) + (decreases Seq.length s1) + = if Seq.length s1 = 0 then () + else begin + let h = Seq.head s1 in + let t = Seq.tail s1 in + let ds = Seq.cons d s2 in + let full = Seq.append s1 ds in + sorted_strict_forall_neq compare full 0 (Seq.length s1); + Seq.lemma_tail_append s1 ds; + sorted_strict_left_neq compare t d s2; + seq_forall_cons (fun x -> compare (fst d) x <> 0) (fst h, snd h) t + end + +/// All elements after d in a sorted_strict (d :: s) have compare d k <> 0. +/// Key: from sorted_strict (d :: h :: t), derive sorted_strict (d :: t) via +/// sorted_strict_forall_neq to get compare d (head t) < 0. +let rec sorted_strict_right_neq (#k: Type) (#v: Type) (compare: cmp k) (d: (k & v)) (s: Seq.seq (k & v)) + : Lemma + (requires sorted_strict compare (Seq.cons d s)) + (ensures seq_forall (fun x -> compare (fst d) x <> 0) s) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else begin + let h = Seq.head s in + let t = Seq.tail s in + let ds = Seq.cons d s in + // Establish tail(cons d s) == s + Seq.lemma_tail_append (Seq.create 1 d) s; + assert (Seq.equal (Seq.tail ds) s); + if Seq.length t = 0 then () + else begin + sorted_strict_tail compare ds; + sorted_strict_tail compare s; + sorted_strict_forall_neq compare ds 0 2; + let dt = Seq.cons d t in + Seq.lemma_head_append (Seq.create 1 d) t; + Seq.lemma_tail_append (Seq.create 1 d) t; + assert (Seq.head dt == d); + assert (Seq.equal (Seq.tail dt) t); + assert (Seq.index dt 1 == Seq.head t); + sorted_strict_right_neq compare d t + end; + seq_forall_cons (fun x -> compare (fst d) x <> 0) (fst h, snd h) t + end + +/// Extensionality for seq_forall: two predicates that agree on keys give the same result +let rec seq_forall_ext (#k: Type) (#v: Type) (f g: k -> bool) (s: Seq.seq (k & v)) + : Lemma (requires seq_forall f s /\ (forall (x:k). f x == g x)) + (ensures seq_forall g s) + (decreases Seq.length s) + = if Seq.length s = 0 then () + else seq_forall_ext f g (Seq.tail s) + +/// BST with strictly sorted inorder implies no_dup_tree +let rec bst_strict_sorted_no_dup (#k: Type) (#v: Type) (compare: cmp k) (t: tree k v) + : Lemma + (requires is_bst compare t /\ sorted_strict compare (inorder t)) + (ensures no_dup_tree compare t) + (decreases t) = + match t with + | Leaf -> () + | Node dk dv l r -> + let io_l = inorder l in + let io_r = inorder r in + let ds = Seq.cons (dk, dv) io_r in + // sorted_strict on sub-sequences + sorted_strict_append_left compare io_l ds; + sorted_strict_append_right compare io_l ds; + sorted_strict_tail compare ds; + assert (Seq.equal (Seq.tail ds) io_r); + bst_strict_sorted_no_dup compare l; + bst_strict_sorted_no_dup compare r; + // Build forall_keys l/r (fun x -> compare dk x <> 0) + let f : (k -> bool) = fun x -> compare dk x <> 0 in + let g : (k -> bool) = fun x -> compare (fst (dk, dv)) x <> 0 in + sorted_strict_left_neq compare io_l (dk, dv) io_r; + // sorted_strict_left_neq gives seq_forall g io_l, but we need seq_forall f io_l + // Since fst (dk, dv) = dk, g and f are extensionally equal + assert (forall (x:k). g x == f x); + seq_forall_ext g f io_l; + inorder_forall_keys l f; + sorted_strict_right_neq compare (dk, dv) io_r; + seq_forall_ext g f io_r; + inorder_forall_keys r f; + no_dup_tree_assemble compare dk dv l r f + +#pop-options diff --git a/lib/pulse/lib/Pulse.Lib.Vector.fst b/lib/pulse/lib/Pulse.Lib.Vector.fst new file mode 100644 index 000000000..149101005 --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.Vector.fst @@ -0,0 +1,320 @@ +(* + Copyright 2025 Microsoft Research + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*) + +module Pulse.Lib.Vector + +#lang-pulse + +open Pulse.Lib.Pervasives +module Seq = FStar.Seq +module SZ = FStar.SizeT +module A = Pulse.Lib.Array +module Box = Pulse.Lib.Box +open Pulse.Lib.Box + +/// Internal representation +noeq +type vector_internal (t:Type0) = { + arr: A.array t; + sz: SZ.t; + cap: SZ.t; + default_val: t; +} + +let vector t = box (vector_internal t) + +/// The is_vector predicate +let is_vector #t (v:vector t) (s:Seq.seq t) (cap:SZ.t) : slprop = + exists* (vi:vector_internal t) (buf:Seq.seq t). + pts_to v vi ** + A.pts_to vi.arr buf ** + pure ( + SZ.v vi.sz == Seq.length s /\ + SZ.v vi.cap == A.length vi.arr /\ + vi.cap == cap /\ + SZ.v vi.sz <= SZ.v vi.cap /\ + SZ.v vi.cap > 0 /\ + A.is_full_array vi.arr /\ + Seq.length buf == SZ.v vi.cap /\ + s `Seq.equal` Seq.slice buf 0 (SZ.v vi.sz) + ) + +/// Create +#push-options "--warn_error -288" +fn create (#t:Type0) (default:t) (n:SZ.t{SZ.v n > 0}) + returns v:vector t + ensures is_vector v (Seq.create (SZ.v n) default) n +{ + let arr = A.alloc default n; + A.pts_to_len arr; + let n' : SZ.t = n; + let vi = Mkvector_internal #t arr n' n' default; + let v = alloc vi; + rewrite (A.pts_to arr (Seq.create (SZ.v n) default)) + as (A.pts_to vi.arr (Seq.create (SZ.v n) default)); + fold (is_vector v (Seq.create (SZ.v n) default) n); + v +} +#pop-options + +/// Read element at index +fn at (#t:Type0) (v:vector t) (i:SZ.t) + (#s:erased (Seq.seq t){SZ.v i < Seq.length s}) (#cap:erased SZ.t) + preserves is_vector v s cap + returns x:t + ensures pure (x == Seq.index s (SZ.v i)) +{ + unfold (is_vector v s cap); + with vi buf. _; + + let vi_val = !v; + rewrite (A.pts_to vi.arr buf) as (A.pts_to vi_val.arr buf); + + A.pts_to_len vi_val.arr; + let x = A.op_Array_Access vi_val.arr i; + + rewrite (A.pts_to vi_val.arr buf) as (A.pts_to vi.arr buf); + fold (is_vector v s cap); + x +} + +/// Write element at index +fn set (#t:Type0) (v:vector t) (i:SZ.t) (x:t) + (#s:erased (Seq.seq t){SZ.v i < Seq.length s}) (#cap:erased SZ.t) + requires is_vector v s cap + ensures is_vector v (Seq.upd s (SZ.v i) x) cap +{ + unfold (is_vector v s cap); + with vi buf. _; + + let vi_val = !v; + rewrite (A.pts_to vi.arr buf) as (A.pts_to vi_val.arr buf); + + A.pts_to_len vi_val.arr; + A.op_Array_Assignment vi_val.arr i x; + with buf'. assert (A.pts_to vi_val.arr buf'); + + rewrite (A.pts_to vi_val.arr buf') as (A.pts_to vi.arr buf'); + fold (is_vector v (Seq.upd s (SZ.v i) x) cap) +} + +/// Get current size +fn size (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + preserves is_vector v s cap + returns n:SZ.t + ensures pure (SZ.v n == Seq.length s) +{ + unfold (is_vector v s cap); + with vi buf. _; + let vi_val = !v; + fold (is_vector v s cap); + vi_val.sz +} + +/// Get current capacity +fn capacity (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + preserves is_vector v s cap + returns n:SZ.t + ensures pure (n == reveal cap) +{ + unfold (is_vector v s cap); + with vi buf. _; + let vi_val = !v; + fold (is_vector v s cap); + vi_val.cap +} + +ghost fn size_bounded (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap + ensures is_vector v s cap ** pure (Seq.length s <= SZ.v cap) +{ + unfold (is_vector v s cap); + with vi buf. _; + fold (is_vector v s cap) +} + +/// Push back — append element, double capacity if full +#push-options "--warn_error -288 --z3rlimit_factor 2" +fn push_back (#t:Type0) (v:vector t) (x:t) + (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap ** + pure (Seq.length s < SZ.v cap \/ SZ.fits (SZ.v cap + SZ.v cap)) + ensures exists* (cap':SZ.t). + is_vector v (Seq.snoc s x) cap' ** + pure (SZ.v cap' >= Seq.length s + 1 /\ SZ.v cap' > 0 /\ + (SZ.v cap' == SZ.v cap \/ SZ.v cap' == SZ.v cap + SZ.v cap)) +{ + unfold (is_vector v s cap); + with vi buf. _; + + let vi_val = !v; + rewrite (A.pts_to vi.arr buf) as (A.pts_to vi_val.arr buf); + A.pts_to_len vi_val.arr; + + if SZ.lt vi_val.sz vi_val.cap + { + // No resize needed — just insert at position size + A.op_Array_Assignment vi_val.arr vi_val.sz x; + with buf'. assert (A.pts_to vi_val.arr buf'); + let new_vi = Mkvector_internal #t vi_val.arr (SZ.add vi_val.sz 1sz) vi_val.cap vi_val.default_val; + ( := ) v new_vi; + + rewrite (A.pts_to vi_val.arr buf') as (A.pts_to new_vi.arr buf'); + fold (is_vector v (Seq.snoc s x) cap); + () + } + else + { + // Resize: allocate double, copy, write new element, free old + let new_cap = SZ.add vi_val.cap vi_val.cap; + let new_arr = A.alloc vi_val.default_val new_cap; + A.pts_to_len new_arr; + + let _sq = A.memcpy_l vi_val.cap vi_val.arr new_arr; + + A.op_Array_Assignment new_arr vi_val.sz x; + with buf'. assert (A.pts_to new_arr buf'); + + A.free vi_val.arr; + + let new_vi = Mkvector_internal #t new_arr (SZ.add vi_val.sz 1sz) new_cap vi_val.default_val; + ( := ) v new_vi; + + rewrite (A.pts_to new_arr buf') as (A.pts_to new_vi.arr buf'); + fold (is_vector v (Seq.snoc s x) new_cap); + () + } +} +#pop-options + +/// Pop back — remove last element, halve capacity when size == floor(cap/2) +#push-options "--warn_error -288 --z3rlimit_factor 2" +fn pop_back (#t:Type0) (v:vector t) + (#s:erased (Seq.seq t){Seq.length s > 0}) (#cap:erased SZ.t) + requires is_vector v s cap + returns x:t + ensures exists* (cap':SZ.t). + is_vector v (Seq.slice s 0 (Seq.length s - 1)) cap' ** + pure (x == Seq.index s (Seq.length s - 1) /\ + SZ.v cap' >= Seq.length s - 1 /\ SZ.v cap' > 0 /\ + (Seq.length s - 1 < SZ.v cap' \/ SZ.fits (SZ.v cap' + SZ.v cap'))) +{ + unfold (is_vector v s cap); + with vi buf. _; + + let vi_val = !v; + rewrite (A.pts_to vi.arr buf) as (A.pts_to vi_val.arr buf); + A.pts_to_len vi_val.arr; + + let last_idx = SZ.sub vi_val.sz 1sz; + let x = A.op_Array_Access vi_val.arr last_idx; + + let new_sz = last_idx; + let half_cap = SZ.div vi_val.cap 2sz; + let should_shrink = SZ.gt half_cap 0sz && SZ.eq new_sz half_cap; + + if should_shrink + { + // Shrink: allocate half, copy surviving elements, free old + let new_arr = A.alloc vi_val.default_val half_cap; + A.pts_to_len new_arr; + + let _sq = A.memcpy_l new_sz vi_val.arr new_arr; + + A.free vi_val.arr; + + let new_vi = Mkvector_internal #t new_arr new_sz half_cap vi_val.default_val; + ( := ) v new_vi; + + with buf_new. assert (A.pts_to new_arr buf_new); + rewrite (A.pts_to new_arr buf_new) as (A.pts_to new_vi.arr buf_new); + fold (is_vector v (Seq.slice s 0 (Seq.length s - 1)) half_cap); + x + } + else + { + // No shrink — just decrement size + let new_vi = Mkvector_internal #t vi_val.arr new_sz vi_val.cap vi_val.default_val; + ( := ) v new_vi; + + rewrite (A.pts_to vi_val.arr buf) as (A.pts_to new_vi.arr buf); + fold (is_vector v (Seq.slice s 0 (Seq.length s - 1)) cap); + x + } +} +#pop-options + +/// Resize +#push-options "--warn_error -288 --z3rlimit_factor 2" +fn resize (#t:Type0) (v:vector t) (new_size:SZ.t{SZ.v new_size > 0}) + (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap + ensures exists* (s':Seq.seq t) (cap':SZ.t). + is_vector v s' cap' ** + pure (Seq.length s' == SZ.v new_size /\ + SZ.v cap' >= SZ.v new_size /\ SZ.v cap' > 0 /\ + (forall (i:nat). i < Seq.length s /\ i < SZ.v new_size ==> + Seq.index s' i == Seq.index s i)) +{ + unfold (is_vector v s cap); + with vi buf. _; + + let vi_val = !v; + rewrite (A.pts_to vi.arr buf) as (A.pts_to vi_val.arr buf); + A.pts_to_len vi_val.arr; + + if SZ.lte new_size vi_val.cap + { + let ns : SZ.t = new_size; + let new_vi = Mkvector_internal #t vi_val.arr ns vi_val.cap vi_val.default_val; + ( := ) v new_vi; + rewrite (A.pts_to vi_val.arr buf) as (A.pts_to new_vi.arr buf); + fold (is_vector v (Seq.slice buf 0 (SZ.v new_size)) cap); + () + } + else + { + let new_arr = A.alloc vi_val.default_val new_size; + A.pts_to_len new_arr; + let _sq = A.memcpy_l vi_val.sz vi_val.arr new_arr; + A.free vi_val.arr; + let ns : SZ.t = new_size; + let new_vi = Mkvector_internal #t new_arr ns ns vi_val.default_val; + ( := ) v new_vi; + with buf'. assert (A.pts_to new_arr buf'); + rewrite (A.pts_to new_arr buf') as (A.pts_to new_vi.arr buf'); + fold (is_vector v (Seq.slice buf' 0 (SZ.v new_size)) new_size); + () + } +} +#pop-options + +/// Free +#push-options "--warn_error -288" +fn free (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap +{ + unfold (is_vector v s cap); + with vi buf. _; + + let vi_val = !v; + rewrite (A.pts_to vi.arr buf) as (A.pts_to vi_val.arr buf); + + A.free vi_val.arr; + Box.free v; + () +} +#pop-options diff --git a/lib/pulse/lib/Pulse.Lib.Vector.fsti b/lib/pulse/lib/Pulse.Lib.Vector.fsti new file mode 100644 index 000000000..1807c54b8 --- /dev/null +++ b/lib/pulse/lib/Pulse.Lib.Vector.fsti @@ -0,0 +1,114 @@ +(* + Copyright 2025 Microsoft Research + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*) + +(** + A dynamically-resizable vector for Pulse. + + Doubles capacity on push_back when full. + Halves capacity on pop_back when size == floor(capacity / 2). + Backed by a flat array with a stored default value for unused slots. +*) + +module Pulse.Lib.Vector + +#lang-pulse + +open Pulse.Lib.Pervasives +module Seq = FStar.Seq +module SZ = FStar.SizeT + +/// Abstract vector type +val vector (t:Type0) : Type0 + +/// Predicate relating a vector to its logical contents and capacity +val is_vector (#t:Type0) ([@@@mkey]v:vector t) (s:Seq.seq t) (cap:SZ.t) : slprop + +/// Create a new vector with n elements all set to default. +/// Capacity is initially n. Requires n > 0. +fn create (#t:Type0) (default:t) (n:SZ.t{SZ.v n > 0}) + returns v:vector t + ensures is_vector v (Seq.create (SZ.v n) default) n + +/// Read element at index i. +/// Requires: i < size +fn at (#t:Type0) (v:vector t) (i:SZ.t) + (#s:erased (Seq.seq t){SZ.v i < Seq.length s}) (#cap:erased SZ.t) + preserves is_vector v s cap + returns x:t + ensures pure (x == Seq.index s (SZ.v i)) + +/// Write element at index i. +/// Requires: i < size +fn set (#t:Type0) (v:vector t) (i:SZ.t) (x:t) + (#s:erased (Seq.seq t){SZ.v i < Seq.length s}) (#cap:erased SZ.t) + requires is_vector v s cap + ensures is_vector v (Seq.upd s (SZ.v i) x) cap + +/// Get the current number of elements +fn size (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + preserves is_vector v s cap + returns n:SZ.t + ensures pure (SZ.v n == Seq.length s) + +/// Get the current capacity +fn capacity (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + preserves is_vector v s cap + returns n:SZ.t + ensures pure (n == cap) + +/// Extract the fact that size <= capacity (always true, but is_vector is abstract) +ghost fn size_bounded (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap + ensures is_vector v s cap ** pure (Seq.length s <= SZ.v cap) + +/// Append element to end. Doubles capacity if full. +/// Precondition: either there is room, or doubling is representable. +fn push_back (#t:Type0) (v:vector t) (x:t) + (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap ** + pure (Seq.length s < SZ.v cap \/ SZ.fits (SZ.v cap + SZ.v cap)) + ensures exists* (cap':SZ.t). + is_vector v (Seq.snoc s x) cap' ** + pure (SZ.v cap' >= Seq.length s + 1 /\ SZ.v cap' > 0 /\ + (SZ.v cap' == SZ.v cap \/ SZ.v cap' == SZ.v cap + SZ.v cap)) + +/// Remove and return the last element. Halves capacity when size == floor(cap/2). +/// Requires: vector is non-empty +fn pop_back (#t:Type0) (v:vector t) + (#s:erased (Seq.seq t){Seq.length s > 0}) (#cap:erased SZ.t) + requires is_vector v s cap + returns x:t + ensures exists* (cap':SZ.t). + is_vector v (Seq.slice s 0 (Seq.length s - 1)) cap' ** + pure (x == Seq.index s (Seq.length s - 1) /\ + SZ.v cap' >= Seq.length s - 1 /\ SZ.v cap' > 0 /\ + (Seq.length s - 1 < SZ.v cap' \/ SZ.fits (SZ.v cap' + SZ.v cap'))) + +/// Resize the vector to new_size elements. +/// Preserves the first min(old_size, new_size) elements. +fn resize (#t:Type0) (v:vector t) (new_size:SZ.t{SZ.v new_size > 0}) + (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap + ensures exists* (s':Seq.seq t) (cap':SZ.t). + is_vector v s' cap' ** + pure (Seq.length s' == SZ.v new_size /\ + SZ.v cap' >= SZ.v new_size /\ SZ.v cap' > 0 /\ + (forall (i:nat). i < Seq.length s /\ i < SZ.v new_size ==> + Seq.index s' i == Seq.index s i)) + +/// Free the vector and its backing storage +fn free (#t:Type0) (v:vector t) (#s:erased (Seq.seq t)) (#cap:erased SZ.t) + requires is_vector v s cap diff --git a/src/extraction/ExtractPulse.fst b/src/extraction/ExtractPulse.fst index 129332a05..103123e13 100644 --- a/src/extraction/ExtractPulse.fst +++ b/src/extraction/ExtractPulse.fst @@ -149,6 +149,10 @@ let pulse_translate_expr : translate_expr_t = fun env e -> when string_of_mlpath p = "Pulse.Lib.Vec.op_Array_Access" -> EBufRead (cb e, cb i) + | MLE_App ({ expr = MLE_TApp({ expr = MLE_Name p }, _) }, [ e; i; _p; _w ]) + when string_of_mlpath p = "Pulse.Lib.Array.PtsTo.op_Array_Access" -> + EBufRead (cb e, cb i) + | MLE_App ({ expr = MLE_TApp({ expr = MLE_Name p }, _) }, [ e; i; _p; _w; _m ]) when string_of_mlpath p = "Pulse.Lib.Array.Core.mask_read" -> EBufRead (cb e, cb i) @@ -157,6 +161,10 @@ let pulse_translate_expr : translate_expr_t = fun env e -> when string_of_mlpath p = "Pulse.Lib.Vec.op_Array_Assignment" -> EBufWrite (cb e, cb i, cb v) + | MLE_App ({ expr = MLE_TApp({ expr = MLE_Name p }, _) }, [ e; i; v; _s ]) + when string_of_mlpath p = "Pulse.Lib.Array.PtsTo.op_Array_Assignment" -> + EBufWrite (cb e, cb i, cb v) + | MLE_App ({ expr = MLE_TApp({ expr = MLE_Name p }, _) }, [ e; i; v; _; _ ]) when string_of_mlpath p = "Pulse.Lib.Array.Core.mask_write" -> EBufWrite (cb e, cb i, cb v)