diff options
Diffstat (limited to 'stdlib/set.ml')
-rw-r--r-- | stdlib/set.ml | 122 |
1 files changed, 70 insertions, 52 deletions
diff --git a/stdlib/set.ml b/stdlib/set.ml index c93be2393..2404c5385 100644 --- a/stdlib/set.ml +++ b/stdlib/set.ml @@ -62,21 +62,22 @@ module Make(Ord: OrderedType) = Empty -> 0 | Node(_, _, _, h) -> h - (* Creates a new node with left son l, value x and right son r. + (* Creates a new node with left son l, value v and right son r. + We must have all elements of l < v < all elements of r. l and r must be balanced and | height l - height r | <= 2. Inline expansion of height for better speed. *) - let create l x r = + let create l v r = let hl = match l with Empty -> 0 | Node(_,_,_,h) -> h in let hr = match r with Empty -> 0 | Node(_,_,_,h) -> h in - Node(l, x, r, (if hl >= hr then hl + 1 else hr + 1)) + Node(l, v, r, (if hl >= hr then hl + 1 else hr + 1)) (* Same as create, but performs one step of rebalancing if necessary. - Assumes l and r balanced. + Assumes l and r balanced and | height l - height r | <= 3. Inline expansion of create for better speed in the most frequent case where no rebalancing is required. *) - let bal l x r = + let bal l v r = let hl = match l with Empty -> 0 | Node(_,_,_,h) -> h in let hr = match r with Empty -> 0 | Node(_,_,_,h) -> h in if hl > hr + 2 then begin @@ -84,70 +85,104 @@ module Make(Ord: OrderedType) = Empty -> invalid_arg "Set.bal" | Node(ll, lv, lr, _) -> if height ll >= height lr then - create ll lv (create lr x r) + create ll lv (create lr v r) else begin match lr with Empty -> invalid_arg "Set.bal" | Node(lrl, lrv, lrr, _)-> - create (create ll lv lrl) lrv (create lrr x r) + create (create ll lv lrl) lrv (create lrr v r) end end else if hr > hl + 2 then begin match r with Empty -> invalid_arg "Set.bal" | Node(rl, rv, rr, _) -> if height rr >= height rl then - create (create l x rl) rv rr + create (create l v rl) rv rr else begin match rl with Empty -> invalid_arg "Set.bal" | Node(rll, rlv, rlr, _) -> - create (create l x rll) rlv (create rlr rv rr) + create (create l v rll) rlv (create rlr rv rr) end end else - Node(l, x, r, (if hl >= hr then hl + 1 else hr + 1)) + Node(l, v, r, (if hl >= hr then hl + 1 else hr + 1)) - (* Same as bal, but repeat rebalancing until the final result - is balanced. *) + (* Insertion of one element *) - let rec join l x r = - match bal l x r with - Empty -> invalid_arg "Set.join" - | Node(l', x', r', _) as t' -> - let d = height l' - height r' in - if d < -2 || d > 2 then join l' x' r' else t' + let rec add x = function + Empty -> Node(Empty, x, Empty, 1) + | Node(l, v, r, _) as t -> + let c = Ord.compare x v in + if c = 0 then t else + if c < 0 then bal (add x l) v r else bal l v (add x r) + + (* Same as create and bal, but no assumptions are made on the + relative heights of l and r. *) + + let rec join l v r = + match (l, r) with + (Empty, _) -> add v r + | (_, Empty) -> add v l + | (Node(ll, lv, lr, lh), Node(rl, rv, rr, rh)) -> + if lh > rh + 2 then bal ll lv (join lr v r) else + if rh > lh + 2 then bal (join l v rl) rv rr else + create l v r + + (* Smallest and greatest element of a set *) + + let rec min_elt = function + Empty -> raise Not_found + | Node(Empty, v, r, _) -> v + | Node(l, v, r, _) -> min_elt l + + let rec max_elt = function + Empty -> raise Not_found + | Node(l, v, Empty, _) -> v + | Node(l, v, r, _) -> max_elt r + + (* Remove the smallest element of the given set *) + + let rec remove_min_elt = function + Empty -> invalid_arg "Set.remove_min_elt" + | Node(Empty, v, r, _) -> r + | Node(l, v, r, _) -> bal (remove_min_elt l) v r (* Merge two trees l and r into one. All elements of l must precede the elements of r. - Assumes | height l - height r | <= 2. *) + Assume | height l - height r | <= 2. *) - let rec merge t1 t2 = + let merge t1 t2 = match (t1, t2) with (Empty, t) -> t | (t, Empty) -> t - | (Node(l1, v1, r1, h1), Node(l2, v2, r2, h2)) -> - bal l1 v1 (bal (merge r1 l2) v2 r2) + | (_, _) -> bal t1 (min_elt t2) (remove_min_elt t2) - (* Same as merge, but does not assume anything about l and r. *) + (* Merge two trees l and r into one. + All elements of l must precede the elements of r. + No assumption on the heights of l and r. *) - let rec concat t1 t2 = + let concat t1 t2 = match (t1, t2) with (Empty, t) -> t | (t, Empty) -> t - | (Node(l1, v1, r1, h1), Node(l2, v2, r2, h2)) -> - join l1 v1 (join (concat r1 l2) v2 r2) + | (_, _) -> join t1 (min_elt t2) (remove_min_elt t2) - (* Splitting *) + (* Splitting. split x s returns a triple (l, present, r) where + - l is the set of elements of s that are < x + - r is the set of elements of s that are > x + - present is false if s contains no element equal to x, + or true if s contains an element equal to x. *) let rec split x = function Empty -> - (Empty, None, Empty) + (Empty, false, Empty) | Node(l, v, r, _) -> let c = Ord.compare x v in - if c = 0 then (l, Some v, r) + if c = 0 then (l, true, r) else if c < 0 then - let (ll, vl, rl) = split x l in (ll, vl, join rl v r) + let (ll, pres, rl) = split x l in (ll, pres, join rl v r) else - let (lr, vr, rr) = split x r in (join l v lr, vr, rr) + let (lr, pres, rr) = split x r in (join l v lr, pres, rr) (* Implementation of the set operations *) @@ -161,13 +196,6 @@ module Make(Ord: OrderedType) = let c = Ord.compare x v in c = 0 || mem x (if c < 0 then l else r) - let rec add x = function - Empty -> Node(Empty, x, Empty, 1) - | Node(l, v, r, _) as t -> - let c = Ord.compare x v in - if c = 0 then t else - if c < 0 then bal (add x l) v r else bal l v (add x r) - let singleton x = Node(Empty, x, Empty, 1) let rec remove x = function @@ -199,9 +227,9 @@ module Make(Ord: OrderedType) = | (t1, Empty) -> Empty | (Node(l1, v1, r1, _), t2) -> match split v1 t2 with - (l2, None, r2) -> + (l2, false, r2) -> concat (inter l1 l2) (inter r1 r2) - | (l2, Some _, r2) -> + | (l2, true, r2) -> join (inter l1 l2) v1 (inter r1 r2) let rec diff s1 s2 = @@ -210,9 +238,9 @@ module Make(Ord: OrderedType) = | (t1, Empty) -> t1 | (Node(l1, v1, r1, _), t2) -> match split v1 t2 with - (l2, None, r2) -> + (l2, false, r2) -> join (diff l1 l2) v1 (diff r1 r2) - | (l2, Some _, r2) -> + | (l2, true, r2) -> concat (diff l1 l2) (diff r1 r2) let rec compare_aux l1 l2 = @@ -293,16 +321,6 @@ module Make(Ord: OrderedType) = let elements s = elements_aux [] s - let rec min_elt = function - Empty -> raise Not_found - | Node(Empty, v, r, _) -> v - | Node(l, v, r, _) -> min_elt l - - let rec max_elt = function - Empty -> raise Not_found - | Node(l, v, Empty, _) -> v - | Node(l, v, r, _) -> max_elt r - let choose = min_elt end |