summaryrefslogtreecommitdiffstats
path: root/stdlib/set.ml
diff options
context:
space:
mode:
Diffstat (limited to 'stdlib/set.ml')
-rw-r--r--stdlib/set.ml122
1 files changed, 70 insertions, 52 deletions
diff --git a/stdlib/set.ml b/stdlib/set.ml
index c93be2393..2404c5385 100644
--- a/stdlib/set.ml
+++ b/stdlib/set.ml
@@ -62,21 +62,22 @@ module Make(Ord: OrderedType) =
Empty -> 0
| Node(_, _, _, h) -> h
- (* Creates a new node with left son l, value x and right son r.
+ (* Creates a new node with left son l, value v and right son r.
+ We must have all elements of l < v < all elements of r.
l and r must be balanced and | height l - height r | <= 2.
Inline expansion of height for better speed. *)
- let create l x r =
+ let create l v r =
let hl = match l with Empty -> 0 | Node(_,_,_,h) -> h in
let hr = match r with Empty -> 0 | Node(_,_,_,h) -> h in
- Node(l, x, r, (if hl >= hr then hl + 1 else hr + 1))
+ Node(l, v, r, (if hl >= hr then hl + 1 else hr + 1))
(* Same as create, but performs one step of rebalancing if necessary.
- Assumes l and r balanced.
+ Assumes l and r balanced and | height l - height r | <= 3.
Inline expansion of create for better speed in the most frequent case
where no rebalancing is required. *)
- let bal l x r =
+ let bal l v r =
let hl = match l with Empty -> 0 | Node(_,_,_,h) -> h in
let hr = match r with Empty -> 0 | Node(_,_,_,h) -> h in
if hl > hr + 2 then begin
@@ -84,70 +85,104 @@ module Make(Ord: OrderedType) =
Empty -> invalid_arg "Set.bal"
| Node(ll, lv, lr, _) ->
if height ll >= height lr then
- create ll lv (create lr x r)
+ create ll lv (create lr v r)
else begin
match lr with
Empty -> invalid_arg "Set.bal"
| Node(lrl, lrv, lrr, _)->
- create (create ll lv lrl) lrv (create lrr x r)
+ create (create ll lv lrl) lrv (create lrr v r)
end
end else if hr > hl + 2 then begin
match r with
Empty -> invalid_arg "Set.bal"
| Node(rl, rv, rr, _) ->
if height rr >= height rl then
- create (create l x rl) rv rr
+ create (create l v rl) rv rr
else begin
match rl with
Empty -> invalid_arg "Set.bal"
| Node(rll, rlv, rlr, _) ->
- create (create l x rll) rlv (create rlr rv rr)
+ create (create l v rll) rlv (create rlr rv rr)
end
end else
- Node(l, x, r, (if hl >= hr then hl + 1 else hr + 1))
+ Node(l, v, r, (if hl >= hr then hl + 1 else hr + 1))
- (* Same as bal, but repeat rebalancing until the final result
- is balanced. *)
+ (* Insertion of one element *)
- let rec join l x r =
- match bal l x r with
- Empty -> invalid_arg "Set.join"
- | Node(l', x', r', _) as t' ->
- let d = height l' - height r' in
- if d < -2 || d > 2 then join l' x' r' else t'
+ let rec add x = function
+ Empty -> Node(Empty, x, Empty, 1)
+ | Node(l, v, r, _) as t ->
+ let c = Ord.compare x v in
+ if c = 0 then t else
+ if c < 0 then bal (add x l) v r else bal l v (add x r)
+
+ (* Same as create and bal, but no assumptions are made on the
+ relative heights of l and r. *)
+
+ let rec join l v r =
+ match (l, r) with
+ (Empty, _) -> add v r
+ | (_, Empty) -> add v l
+ | (Node(ll, lv, lr, lh), Node(rl, rv, rr, rh)) ->
+ if lh > rh + 2 then bal ll lv (join lr v r) else
+ if rh > lh + 2 then bal (join l v rl) rv rr else
+ create l v r
+
+ (* Smallest and greatest element of a set *)
+
+ let rec min_elt = function
+ Empty -> raise Not_found
+ | Node(Empty, v, r, _) -> v
+ | Node(l, v, r, _) -> min_elt l
+
+ let rec max_elt = function
+ Empty -> raise Not_found
+ | Node(l, v, Empty, _) -> v
+ | Node(l, v, r, _) -> max_elt r
+
+ (* Remove the smallest element of the given set *)
+
+ let rec remove_min_elt = function
+ Empty -> invalid_arg "Set.remove_min_elt"
+ | Node(Empty, v, r, _) -> r
+ | Node(l, v, r, _) -> bal (remove_min_elt l) v r
(* Merge two trees l and r into one.
All elements of l must precede the elements of r.
- Assumes | height l - height r | <= 2. *)
+ Assume | height l - height r | <= 2. *)
- let rec merge t1 t2 =
+ let merge t1 t2 =
match (t1, t2) with
(Empty, t) -> t
| (t, Empty) -> t
- | (Node(l1, v1, r1, h1), Node(l2, v2, r2, h2)) ->
- bal l1 v1 (bal (merge r1 l2) v2 r2)
+ | (_, _) -> bal t1 (min_elt t2) (remove_min_elt t2)
- (* Same as merge, but does not assume anything about l and r. *)
+ (* Merge two trees l and r into one.
+ All elements of l must precede the elements of r.
+ No assumption on the heights of l and r. *)
- let rec concat t1 t2 =
+ let concat t1 t2 =
match (t1, t2) with
(Empty, t) -> t
| (t, Empty) -> t
- | (Node(l1, v1, r1, h1), Node(l2, v2, r2, h2)) ->
- join l1 v1 (join (concat r1 l2) v2 r2)
+ | (_, _) -> join t1 (min_elt t2) (remove_min_elt t2)
- (* Splitting *)
+ (* Splitting. split x s returns a triple (l, present, r) where
+ - l is the set of elements of s that are < x
+ - r is the set of elements of s that are > x
+ - present is false if s contains no element equal to x,
+ or true if s contains an element equal to x. *)
let rec split x = function
Empty ->
- (Empty, None, Empty)
+ (Empty, false, Empty)
| Node(l, v, r, _) ->
let c = Ord.compare x v in
- if c = 0 then (l, Some v, r)
+ if c = 0 then (l, true, r)
else if c < 0 then
- let (ll, vl, rl) = split x l in (ll, vl, join rl v r)
+ let (ll, pres, rl) = split x l in (ll, pres, join rl v r)
else
- let (lr, vr, rr) = split x r in (join l v lr, vr, rr)
+ let (lr, pres, rr) = split x r in (join l v lr, pres, rr)
(* Implementation of the set operations *)
@@ -161,13 +196,6 @@ module Make(Ord: OrderedType) =
let c = Ord.compare x v in
c = 0 || mem x (if c < 0 then l else r)
- let rec add x = function
- Empty -> Node(Empty, x, Empty, 1)
- | Node(l, v, r, _) as t ->
- let c = Ord.compare x v in
- if c = 0 then t else
- if c < 0 then bal (add x l) v r else bal l v (add x r)
-
let singleton x = Node(Empty, x, Empty, 1)
let rec remove x = function
@@ -199,9 +227,9 @@ module Make(Ord: OrderedType) =
| (t1, Empty) -> Empty
| (Node(l1, v1, r1, _), t2) ->
match split v1 t2 with
- (l2, None, r2) ->
+ (l2, false, r2) ->
concat (inter l1 l2) (inter r1 r2)
- | (l2, Some _, r2) ->
+ | (l2, true, r2) ->
join (inter l1 l2) v1 (inter r1 r2)
let rec diff s1 s2 =
@@ -210,9 +238,9 @@ module Make(Ord: OrderedType) =
| (t1, Empty) -> t1
| (Node(l1, v1, r1, _), t2) ->
match split v1 t2 with
- (l2, None, r2) ->
+ (l2, false, r2) ->
join (diff l1 l2) v1 (diff r1 r2)
- | (l2, Some _, r2) ->
+ | (l2, true, r2) ->
concat (diff l1 l2) (diff r1 r2)
let rec compare_aux l1 l2 =
@@ -293,16 +321,6 @@ module Make(Ord: OrderedType) =
let elements s =
elements_aux [] s
- let rec min_elt = function
- Empty -> raise Not_found
- | Node(Empty, v, r, _) -> v
- | Node(l, v, r, _) -> min_elt l
-
- let rec max_elt = function
- Empty -> raise Not_found
- | Node(l, v, Empty, _) -> v
- | Node(l, v, r, _) -> max_elt r
-
let choose = min_elt
end