diff options
Diffstat (limited to 'stdlib/map.ml')
-rw-r--r-- | stdlib/map.ml | 116 |
1 files changed, 111 insertions, 5 deletions
diff --git a/stdlib/map.ml b/stdlib/map.ml index a159e7aef..3d9597aa0 100644 --- a/stdlib/map.ml +++ b/stdlib/map.ml @@ -25,16 +25,28 @@ module type S = type +'a t val empty: 'a t val is_empty: 'a t -> bool + val mem: key -> 'a t -> bool val add: key -> 'a -> 'a t -> 'a t - val find: key -> 'a t -> 'a + val singleton: key -> 'a -> 'a t val remove: key -> 'a t -> 'a t - val mem: key -> 'a t -> bool + val merge: (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t + val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool val iter: (key -> 'a -> unit) -> 'a t -> unit + val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b + val for_all: (key -> 'a -> bool) -> 'a t -> bool + val exists: (key -> 'a -> bool) -> 'a t -> bool + val filter: (key -> 'a -> bool) -> 'a t -> 'a t + val partition: (key -> 'a -> bool) -> 'a t -> 'a t * 'a t + val cardinal: 'a t -> int + val bindings: 'a t -> (key * 'a) list + val min_binding: 'a t -> (key * 'a) + val max_binding: 'a t -> (key * 'a) + val choose: 'a t -> (key * 'a) + val split: key -> 'a t -> 'a t * 'a option * 'a t + val find: key -> 'a t -> 'a val map: ('a -> 'b) -> 'a t -> 'b t val mapi: (key -> 'a -> 'b) -> 'a t -> 'b t - val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b - val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int - val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool end module Make(Ord: OrderedType) = struct @@ -53,6 +65,8 @@ module Make(Ord: OrderedType) = struct let hl = height l and hr = height r in Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1)) + let singleton x d = Node(Empty, x, d, Empty, 1) + let bal l x d r = let hl = match l with Empty -> 0 | Node(_,_,_,_,h) -> h in let hr = match r with Empty -> 0 | Node(_,_,_,_,h) -> h in @@ -119,6 +133,11 @@ module Make(Ord: OrderedType) = struct | Node(Empty, x, d, r, _) -> (x, d) | Node(l, x, d, r, _) -> min_binding l + let rec max_binding = function + Empty -> raise Not_found + | Node(l, x, d, Empty, _) -> (x, d) + | Node(l, x, d, r, _) -> max_binding r + let rec remove_min_binding = function Empty -> invalid_arg "Map.remove_min_elt" | Node(Empty, x, d, r, _) -> r @@ -173,6 +192,80 @@ module Make(Ord: OrderedType) = struct | Node(l, v, d, r, _) -> fold f r (f v d (fold f l accu)) + let rec for_all p = function + Empty -> true + | Node(l, v, d, r, _) -> p v d && for_all p l && for_all p r + + let rec exists p = function + Empty -> false + | Node(l, v, d, r, _) -> p v d || exists p l || exists p r + + let filter p s = + let rec filt accu = function + | Empty -> accu + | Node(l, v, d, r, _) -> + filt (filt (if p v d then add v d accu else accu) l) r in + filt Empty s + + let partition p s = + let rec part (t, f as accu) = function + | Empty -> accu + | Node(l, v, d, r, _) -> + part (part (if p v d then (add v d t, f) else (t, add v d f)) l) r in + part (Empty, Empty) s + + (* Same as create and bal, but no assumptions are made on the + relative heights of l and r. *) + + let rec join l v d r = + match (l, r) with + (Empty, _) -> add v d r + | (_, Empty) -> add v d l + | (Node(ll, lv, ld, lr, lh), Node(rl, rv, rd, rr, rh)) -> + if lh > rh + 2 then bal ll lv ld (join lr v d r) else + if rh > lh + 2 then bal (join l v d rl) rv rd rr else + create l v d r + + (* Merge two trees l and r into one. + All elements of l must precede the elements of r. + No assumption on the heights of l and r. *) + + let concat t1 t2 = + match (t1, t2) with + (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> + let (x, d) = min_binding t2 in + join t1 x d (remove_min_binding t2) + + let concat_or_join t1 v d t2 = + match d with + | Some d -> join t1 v d t2 + | None -> concat t1 t2 + + let rec split x = function + Empty -> + (Empty, None, Empty) + | Node(l, v, d, r, _) -> + let c = Ord.compare x v in + if c = 0 then (l, Some d, r) + else if c < 0 then + let (ll, pres, rl) = split x l in (ll, pres, join rl v d r) + else + let (lr, pres, rr) = split x r in (join l v d lr, pres, rr) + + let rec merge f s1 s2 = + match (s1, s2) with + (Empty, Empty) -> Empty + | (Node (l1, v1, d1, r1, h1), _) when h1 >= height s2 -> + let (l2, d2, r2) = split v1 s2 in + concat_or_join (merge f l1 l2) v1 (f v1 (Some d1) d2) (merge f r1 r2) + | (_, Node (l2, v2, d2, r2, h2)) -> + let (l1, d1, r1) = split v2 s1 in + concat_or_join (merge f l1 l2) v2 (f v2 d1 (Some d2)) (merge f r1 r2) + | _ -> + assert false + type 'a enumeration = End | More of key * 'a * 'a t * 'a enumeration let rec cons_enum m e = @@ -205,4 +298,17 @@ module Make(Ord: OrderedType) = struct equal_aux (cons_enum r1 e1) (cons_enum r2 e2) in equal_aux (cons_enum m1 End) (cons_enum m2 End) + let rec cardinal = function + Empty -> 0 + | Node(l, _, _, r, _) -> cardinal l + 1 + cardinal r + + let rec bindings_aux accu = function + Empty -> accu + | Node(l, v, d, r, _) -> bindings_aux ((v, d) :: bindings_aux accu r) l + + let bindings s = + bindings_aux [] s + + let choose = min_binding + end |