diff --git a/Batteries/Data/List/Basic.lean b/Batteries/Data/List/Basic.lean index fc2cecae44..ca8d6ca04b 100644 --- a/Batteries/Data/List/Basic.lean +++ b/Batteries/Data/List/Basic.lean @@ -117,8 +117,14 @@ splitAt 2 [a, b, c] = ([a, b], [c]) ``` -/ def splitAt (n : Nat) (l : List α) : List α × List α := go l n [] where - /-- Auxiliary for `splitAt`: `splitAt.go xs n acc = (acc.reverse ++ take n xs, drop n xs)`. -/ + /-- + Auxiliary for `splitAt`: + `splitAt.go l xs n acc = (acc.reverse ++ take n xs, drop n xs)` if `n < xs.length`, + and `(l, [])` otherwise. + -/ go : List α → Nat → List α → List α × List α + | [], _, _ => (l, []) -- This branch ensures the pointer equality of the result with the input + -- without any runtime branching cost. | x :: xs, n+1, acc => go xs n (x :: acc) | xs, _, acc => (acc.reverse, xs) diff --git a/Batteries/Data/List/Lemmas.lean b/Batteries/Data/List/Lemmas.lean index 9e5f4e3ace..74d44af907 100644 --- a/Batteries/Data/List/Lemmas.lean +++ b/Batteries/Data/List/Lemmas.lean @@ -192,18 +192,21 @@ theorem get?_set_of_lt' (a : α) {m n} (l : List α) (h : m < length l) : /-! ### splitAt -/ theorem splitAt_go (n : Nat) (l acc : List α) : - splitAt.go l n acc = (acc.reverse ++ l.take n, l.drop n) := by - induction l generalizing n acc with + splitAt.go l xs n acc = + if n < xs.length then (acc.reverse ++ xs.take n, xs.drop n) else (l, []) := by + induction xs generalizing n acc with | nil => simp [splitAt.go] | cons x xs ih => cases n with | zero => simp [splitAt.go] | succ n => rw [splitAt.go, take_succ_cons, drop_succ_cons, ih n (x :: acc), - reverse_cons, append_assoc, singleton_append] + reverse_cons, append_assoc, singleton_append, length_cons] + simp only [Nat.succ_lt_succ_iff] theorem splitAt_eq (n : Nat) (l : List α) : splitAt n l = (l.take n, l.drop n) := by rw [splitAt, splitAt_go, reverse_nil, nil_append] + split <;> simp_all [take_of_length_le, drop_of_length_le] /-! ### eraseP -/