summaryrefslogtreecommitdiffstats
path: root/stdlib/map.ml
diff options
context:
space:
mode:
Diffstat (limited to 'stdlib/map.ml')
-rw-r--r--stdlib/map.ml116
1 files changed, 111 insertions, 5 deletions
diff --git a/stdlib/map.ml b/stdlib/map.ml
index a159e7aef..3d9597aa0 100644
--- a/stdlib/map.ml
+++ b/stdlib/map.ml
@@ -25,16 +25,28 @@ module type S =
type +'a t
val empty: 'a t
val is_empty: 'a t -> bool
+ val mem: key -> 'a t -> bool
val add: key -> 'a -> 'a t -> 'a t
- val find: key -> 'a t -> 'a
+ val singleton: key -> 'a -> 'a t
val remove: key -> 'a t -> 'a t
- val mem: key -> 'a t -> bool
+ val merge: (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t
+ val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
+ val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
val iter: (key -> 'a -> unit) -> 'a t -> unit
+ val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
+ val for_all: (key -> 'a -> bool) -> 'a t -> bool
+ val exists: (key -> 'a -> bool) -> 'a t -> bool
+ val filter: (key -> 'a -> bool) -> 'a t -> 'a t
+ val partition: (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
+ val cardinal: 'a t -> int
+ val bindings: 'a t -> (key * 'a) list
+ val min_binding: 'a t -> (key * 'a)
+ val max_binding: 'a t -> (key * 'a)
+ val choose: 'a t -> (key * 'a)
+ val split: key -> 'a t -> 'a t * 'a option * 'a t
+ val find: key -> 'a t -> 'a
val map: ('a -> 'b) -> 'a t -> 'b t
val mapi: (key -> 'a -> 'b) -> 'a t -> 'b t
- val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
- val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
- val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
end
module Make(Ord: OrderedType) = struct
@@ -53,6 +65,8 @@ module Make(Ord: OrderedType) = struct
let hl = height l and hr = height r in
Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))
+ let singleton x d = Node(Empty, x, d, Empty, 1)
+
let bal l x d r =
let hl = match l with Empty -> 0 | Node(_,_,_,_,h) -> h in
let hr = match r with Empty -> 0 | Node(_,_,_,_,h) -> h in
@@ -119,6 +133,11 @@ module Make(Ord: OrderedType) = struct
| Node(Empty, x, d, r, _) -> (x, d)
| Node(l, x, d, r, _) -> min_binding l
+ let rec max_binding = function
+ Empty -> raise Not_found
+ | Node(l, x, d, Empty, _) -> (x, d)
+ | Node(l, x, d, r, _) -> max_binding r
+
let rec remove_min_binding = function
Empty -> invalid_arg "Map.remove_min_elt"
| Node(Empty, x, d, r, _) -> r
@@ -173,6 +192,80 @@ module Make(Ord: OrderedType) = struct
| Node(l, v, d, r, _) ->
fold f r (f v d (fold f l accu))
+ let rec for_all p = function
+ Empty -> true
+ | Node(l, v, d, r, _) -> p v d && for_all p l && for_all p r
+
+ let rec exists p = function
+ Empty -> false
+ | Node(l, v, d, r, _) -> p v d || exists p l || exists p r
+
+ let filter p s =
+ let rec filt accu = function
+ | Empty -> accu
+ | Node(l, v, d, r, _) ->
+ filt (filt (if p v d then add v d accu else accu) l) r in
+ filt Empty s
+
+ let partition p s =
+ let rec part (t, f as accu) = function
+ | Empty -> accu
+ | Node(l, v, d, r, _) ->
+ part (part (if p v d then (add v d t, f) else (t, add v d f)) l) r in
+ part (Empty, Empty) s
+
+ (* Same as create and bal, but no assumptions are made on the
+ relative heights of l and r. *)
+
+ let rec join l v d r =
+ match (l, r) with
+ (Empty, _) -> add v d r
+ | (_, Empty) -> add v d l
+ | (Node(ll, lv, ld, lr, lh), Node(rl, rv, rd, rr, rh)) ->
+ if lh > rh + 2 then bal ll lv ld (join lr v d r) else
+ if rh > lh + 2 then bal (join l v d rl) rv rd rr else
+ create l v d r
+
+ (* Merge two trees l and r into one.
+ All elements of l must precede the elements of r.
+ No assumption on the heights of l and r. *)
+
+ let concat t1 t2 =
+ match (t1, t2) with
+ (Empty, t) -> t
+ | (t, Empty) -> t
+ | (_, _) ->
+ let (x, d) = min_binding t2 in
+ join t1 x d (remove_min_binding t2)
+
+ let concat_or_join t1 v d t2 =
+ match d with
+ | Some d -> join t1 v d t2
+ | None -> concat t1 t2
+
+ let rec split x = function
+ Empty ->
+ (Empty, None, Empty)
+ | Node(l, v, d, r, _) ->
+ let c = Ord.compare x v in
+ if c = 0 then (l, Some d, r)
+ else if c < 0 then
+ let (ll, pres, rl) = split x l in (ll, pres, join rl v d r)
+ else
+ let (lr, pres, rr) = split x r in (join l v d lr, pres, rr)
+
+ let rec merge f s1 s2 =
+ match (s1, s2) with
+ (Empty, Empty) -> Empty
+ | (Node (l1, v1, d1, r1, h1), _) when h1 >= height s2 ->
+ let (l2, d2, r2) = split v1 s2 in
+ concat_or_join (merge f l1 l2) v1 (f v1 (Some d1) d2) (merge f r1 r2)
+ | (_, Node (l2, v2, d2, r2, h2)) ->
+ let (l1, d1, r1) = split v2 s1 in
+ concat_or_join (merge f l1 l2) v2 (f v2 d1 (Some d2)) (merge f r1 r2)
+ | _ ->
+ assert false
+
type 'a enumeration = End | More of key * 'a * 'a t * 'a enumeration
let rec cons_enum m e =
@@ -205,4 +298,17 @@ module Make(Ord: OrderedType) = struct
equal_aux (cons_enum r1 e1) (cons_enum r2 e2)
in equal_aux (cons_enum m1 End) (cons_enum m2 End)
+ let rec cardinal = function
+ Empty -> 0
+ | Node(l, _, _, r, _) -> cardinal l + 1 + cardinal r
+
+ let rec bindings_aux accu = function
+ Empty -> accu
+ | Node(l, v, d, r, _) -> bindings_aux ((v, d) :: bindings_aux accu r) l
+
+ let bindings s =
+ bindings_aux [] s
+
+ let choose = min_binding
+
end