From 2356df10959de4347921bf1ed5da2cd045d88591 Mon Sep 17 00:00:00 2001 From: Eric Ihli Date: Wed, 14 Apr 2021 13:22:55 -0500 Subject: [PATCH] Performance gains with type hints --- pom.xml | 2 +- src/com/owoga/tightly_packed_trie.clj | 116 +++++++----------- .../owoga/tightly_packed_trie/bit_manip.clj | 12 +- .../owoga/tightly_packed_trie/encoding.clj | 48 ++++---- src/com/owoga/trie.clj | 66 +++++----- test/tightly_packed_trie_test.clj | 16 +-- 6 files changed, 116 insertions(+), 144 deletions(-) diff --git a/pom.xml b/pom.xml index e32f161..a266b4c 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ jar com.owoga tightly-packed-trie - 0.2.2 + 0.2.3 tightly-packed-trie scm:git:git://github.com/eihli/clj-tightly-packed-trie.git diff --git a/src/com/owoga/tightly_packed_trie.clj b/src/com/owoga/tightly_packed_trie.clj index dbb5f8b..b2c23de 100644 --- a/src/com/owoga/tightly_packed_trie.clj +++ b/src/com/owoga/tightly_packed_trie.clj @@ -3,8 +3,7 @@ [com.owoga.tightly-packed-trie.encoding :as encoding] [clojure.java.io :as io] [clojure.string :as string] - [com.owoga.tightly-packed-trie.bit-manip :as bm] - [clojure.zip :as zip]) + [com.owoga.tightly-packed-trie.bit-manip :as bm]) (:import (java.io ByteArrayOutputStream ByteArrayInputStream DataOutputStream DataInputStream))) @@ -38,39 +37,8 @@ (.limit ~byte-buffer original-limit#) (.position ~byte-buffer original-position#))))) -(defn -trie->depth-first-post-order-traversable-zipperable-vector - [path node decode-value-fn] - (vec - (map - (fn [child] - [(-trie->depth-first-post-order-traversable-zipperable-vector - (conj path (.key child)) - child - decode-value-fn) - (wrap-byte-buffer - (.byte-buffer child) - (.limit (.byte-buffer child) (.limit child)) - (.position (.byte-buffer child) (.address child)) - (clojure.lang.MapEntry. - (conj path (.key child)) - (decode-value-fn (.byte-buffer child))))]) - (trie/children node)))) - -(defn trie->depth-first-post-order-traversable-zipperable-vector - [path node decode-value-fn] - (let [byte-buffer (.byte-buffer node) - val (wrap-byte-buffer - byte-buffer - (.limit byte-buffer (.limit node)) - (.position byte-buffer (.address node)) - (decode-value-fn byte-buffer))] - [(-trie->depth-first-post-order-traversable-zipperable-vector - path - node - decode-value-fn) - (clojure.lang.MapEntry. path val)])) - -(defn rewind-to-key [bb stop] +(defn rewind-to-key [^java.nio.ByteBuffer bb + ^Integer stop] (loop [] (let [current (.get bb (.position bb)) previous (.get bb (dec (.position bb)))] @@ -81,18 +49,11 @@ (do (.position bb (dec (.position bb))) (recur)))))) -(defn forward-to-key [bb stop] - (loop [] - (if (or (= stop (.position bb)) - (and (encoding/key-byte? (.get bb (.position bb))) - (encoding/offset-byte? - (.get bb (inc (.position bb)))))) - bb - (do (.position bb (inc (.position bb))) - (recur))))) - (defn find-key-in-index - [bb target-key max-address not-found] + [^java.nio.ByteBuffer bb + ^Integer target-key + ^Integer max-address + not-found] (.limit bb max-address) (let [key (loop [previous-key nil @@ -100,9 +61,9 @@ max-position max-address] (if (zero? (- max-position min-position)) not-found - (let [mid-position (+ min-position (quot (- max-position min-position) 2))] + (let [^Integer mid-position (+ min-position (quot (- max-position min-position) 2))] (.position bb mid-position) - (let [bb (rewind-to-key bb min-position) + (let [^java.nio.ByteBuffer bb (rewind-to-key bb min-position) current-key (encoding/decode-number-from-tightly-packed-trie-index bb)] (cond @@ -144,14 +105,13 @@ {:id value :count freq})) -(defn -value [trie value-decode-fn] - (wrap-byte-buffer - (.byte-buffer trie) - (.limit (.byte-buffer trie) (.limit trie)) - (.position (.byte-buffer trie) (.address trie)) - (value-decode-fn (.byte-buffer trie)))) +(declare -value) -(deftype TightlyPackedTrie [byte-buffer key address limit value-decode-fn] +(deftype TightlyPackedTrie [^java.nio.ByteBuffer byte-buffer + ^Integer key + ^Integer address + ^Integer limit + value-decode-fn] trie/ITrie (lookup [self ks] (wrap-byte-buffer @@ -183,8 +143,8 @@ (.position byte-buffer address) (let [val (value-decode-fn byte-buffer) size-of-index (encoding/decode byte-buffer)] - (.limit byte-buffer (+ (.position byte-buffer) - size-of-index)) + (.limit byte-buffer ^Integer (+ (.position byte-buffer) + size-of-index)) (loop [children []] (if (= (.position byte-buffer) (.limit byte-buffer)) children @@ -202,7 +162,7 @@ clojure.lang.ILookup (valAt [self ks] - (if-let [node (trie/lookup self ks)] + (if-let [^TightlyPackedTrie node (trie/lookup self ks)] (-value node value-decode-fn) nil)) (valAt [self ks not-found] @@ -214,7 +174,9 @@ clojure.lang.Seqable (seq [trie] - (let [step (fn step [path [[node & nodes] & stack] [parent & parents]] + (let [step (fn step [^clojure.lang.PersistentList path + [[^TightlyPackedTrie node & nodes] & stack] + [^TightlyPackedTrie parent & parents]] (cond node (step (conj path (.key node)) @@ -225,11 +187,11 @@ (lazy-seq (cons (clojure.lang.MapEntry. (rest path) - (let [byte-buffer (.byte-buffer parent)] + (let [^java.nio.ByteBuffer byte-buffer (.byte-buffer parent)] (wrap-byte-buffer byte-buffer - (.limit byte-buffer (.limit parent)) - (.position byte-buffer (.address parent)) + (.limit byte-buffer ^Integer (.limit parent)) + (.position byte-buffer ^Integer (.address parent)) (value-decode-fn byte-buffer)))) (step (pop path) stack @@ -237,6 +199,14 @@ :else nil))] (step [] (list (list trie)) '())))) +(defn -value [^TightlyPackedTrie trie value-decode-fn] + (let [^java.nio.ByteBuffer byte-buffer (.byte-buffer trie)] + (wrap-byte-buffer + byte-buffer + (.limit byte-buffer ^Integer (.limit trie)) + (.position byte-buffer ^Integer (.address trie)) + (value-decode-fn byte-buffer)))) + (defmethod print-method TightlyPackedTrie [trie ^java.io.Writer w] (print-method (into {} trie) w)) @@ -244,17 +214,19 @@ (print-ctor trie (fn [o w] (print-dup (into {} trie) w)) w)) (defn tightly-packed-trie - [trie value-encode-fn value-decode-fn] - (let [baos (ByteArrayOutputStream.)] + [^TightlyPackedTrie trie + value-encode-fn + value-decode-fn] + (let [^ByteArrayOutputStream baos (ByteArrayOutputStream.)] (loop [nodes (seq trie) current-offset 8 previous-depth 0 child-indexes []] - (let [current-node (first nodes) + (let [^TightlyPackedTrie current-node (first nodes) current-depth (count (first current-node))] (cond (empty? nodes) - (let [child-index (last child-indexes) + (let [^clojure.lang.PersistentVector child-index (last child-indexes) child-index-baos (ByteArrayOutputStream.) _ (->> child-index (run! @@ -269,12 +241,12 @@ child-index-byte-array (.toByteArray child-index-baos) size-of-child-index (encoding/encode (count child-index-byte-array)) root-address current-offset - value (value-encode-fn 0)] + value #^bytes (value-encode-fn 0)] (.write baos value) (.write baos size-of-child-index) (.write baos child-index-byte-array) - (let [ba (.toByteArray baos) - byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))] + (let [#^bytes ba (.toByteArray baos) + ^java.nio.ByteBuffer byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))] (do (.putLong byte-buf root-address) (.put byte-buf ba) (.rewind byte-buf) @@ -289,7 +261,7 @@ ;; Process index of children. (> previous-depth current-depth) (let [[k v] (first nodes) - value (value-encode-fn v) + value #^bytes (value-encode-fn v) child-index (last child-indexes) child-index-baos (ByteArrayOutputStream.) _ (->> child-index @@ -322,7 +294,7 @@ ;; Start keeping track of new children index :else (let [[k v] (first nodes) - value (value-encode-fn v) + value #^bytes (value-encode-fn v) size-of-child-index (encoding/encode 0) child-indexes (into child-indexes (vec (repeat (- current-depth previous-depth) []))) @@ -350,7 +322,7 @@ (with-open [i (io/input-stream filepath) baos (ByteArrayOutputStream.)] (io/copy i baos) - (let [byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))] + (let [^java.nio.ByteBuffer byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))] (.rewind byte-buffer) (->TightlyPackedTrie byte-buffer diff --git a/src/com/owoga/tightly_packed_trie/bit_manip.clj b/src/com/owoga/tightly_packed_trie/bit_manip.clj index 2b7959a..9996676 100644 --- a/src/com/owoga/tightly_packed_trie/bit_manip.clj +++ b/src/com/owoga/tightly_packed_trie/bit_manip.clj @@ -95,12 +95,12 @@ -> 0001010 1001011 -> 00010101001011 -> 1355 (As a long...)" - [num-significant-bits & bytes] - (reduce - (fn [a b] - (bit-or b (bit-shift-left a num-significant-bits))) - 0 - bytes)) + (#^bytes [num-significant-bits & bytes] + (reduce + (fn [a b] + (bit-or b (bit-shift-left a num-significant-bits))) + 0 + bytes))) (comment (let [b1 (bitstring->int "0110110") diff --git a/src/com/owoga/tightly_packed_trie/encoding.clj b/src/com/owoga/tightly_packed_trie/encoding.clj index 1baa47b..c22f721 100644 --- a/src/com/owoga/tightly_packed_trie/encoding.clj +++ b/src/com/owoga/tightly_packed_trie/encoding.clj @@ -30,12 +30,12 @@ To decode: if the flag bit is not set, read the next byte and concat the last 7 bits of the current byte to the last 7 bits of the next byte." - [n] - (loop [b (list (bit-set (mod n 0x80) 7)) - n (quot n 0x80)] - (if (zero? n) - (byte-array b) - (recur (cons (mod n 0x80) b) (quot n 0x80))))) + (#^bytes [n] + (loop [b (list (bit-set (mod n 0x80) 7)) + n (quot n 0x80)] + (if (zero? n) + (byte-array b) + (recur (cons (mod n 0x80) b) (quot n 0x80)))))) (comment (->> [0 1 2 127 128 129] @@ -52,8 +52,8 @@ (defn decode "Decode one variable-length-encoded number from a ByteBuffer, advancing the buffer's position to the byte following the encoded number." - [byte-buffer] - (loop [bytes (list (.get byte-buffer))] + ^Integer [^java.nio.ByteBuffer byte-buffer] + (loop [bytes (list ^Byte (.get byte-buffer))] (if (bit-test (first bytes) 7) (->> (cons (bit-clear (first bytes) 7) (rest bytes)) reverse @@ -75,29 +75,27 @@ (def offset-byte? (complement key-byte?)) (defn encode-key-to-tightly-packed-trie-index - [n] + #^bytes [n] (->> n encode (map #(bit-set % 7)) byte-array)) (defn encode-offset-to-tightly-packed-trie-index - [n] + #^bytes [n] (->> n encode (map #(bit-clear % 7)) byte-array)) (defn decode-number-from-tightly-packed-trie-index - ([byte-buffer] - (let [first-byte (.get byte-buffer) - continue? (fn [] - (and (.hasRemaining byte-buffer) - (= (key-byte? (.get byte-buffer (.position byte-buffer))) - (key-byte? first-byte))))] - (loop [bytes [first-byte]] - (if (continue?) - (recur (conj bytes (.get byte-buffer))) - (->> bytes - (map (partial bit-and 0xFF)) - (map #(bit-clear % 7)) - (apply (partial bm/combine-significant-bits 7)))))))) - -(bm/to-binary-string 0xff) + [^java.nio.ByteBuffer byte-buffer] + (let [first-byte (.get byte-buffer) + continue? (fn [] + (and (.hasRemaining byte-buffer) + (= (key-byte? (.get byte-buffer (.position byte-buffer))) + (key-byte? first-byte))))] + (loop [bytes [first-byte]] + (if (continue?) + (recur (conj bytes (.get byte-buffer))) + (->> bytes + (map (partial bit-and 0xFF)) + (map #(bit-clear % 7)) + (apply (partial bm/combine-significant-bits 7))))))) (comment (let [byte-buffer (java.nio.ByteBuffer/allocate 64)] diff --git a/src/com/owoga/trie.clj b/src/com/owoga/trie.clj index 02290f6..34a0595 100644 --- a/src/com/owoga/trie.clj +++ b/src/com/owoga/trie.clj @@ -1,37 +1,16 @@ (ns com.owoga.trie) -(declare ->Trie) - -(defn -without - [trie [k & ks]] - (if k - (if-let [next-trie (get (.children- trie) k)] - (let [next-trie-without (-without next-trie ks) - new-trie (->Trie (.key trie) - (.value trie) - (if next-trie-without - (assoc (.children- trie) k next-trie-without) - (dissoc (.children- trie) k)))] - (if (and (empty? new-trie) - (nil? (.value new-trie))) - nil - new-trie))) - (if (seq (.children- trie)) - (->Trie - (.key trie) - nil - (.children- trie)) - nil))) +(declare -without) (defprotocol ITrie (children [self] "Immediate children of a node.") - (lookup [self ks] "Return node at key.")) + (lookup [self ^clojure.lang.PersistentList ks] "Return node at key.")) -(deftype Trie [key value children-] +(deftype Trie [key value ^clojure.lang.PersistentTreeMap children-] ITrie (children [trie] (map - (fn [[k child]] + (fn [[k ^Trie child]] (Trie. k (.value child) (.children- child))) @@ -54,7 +33,7 @@ clojure.lang.ILookup (valAt [trie k] - (if-let [node (lookup trie k)] + (if-let [^Trie node (lookup trie k)] (.value node) nil)) @@ -65,18 +44,18 @@ (cons [trie entry] (cond (instance? Trie (second entry)) - (assoc trie (first entry) (.value (second entry))) + (assoc trie (first entry) (.value ^Trie (second entry))) :else (assoc trie (first entry) (second entry)))) (empty [trie] (Trie. key nil (sorted-map))) (equiv [trie o] (and (= (.value trie) - (.value o)) + (.value ^Trie o)) (= (.children- trie) - (.children- o)) + (.children- ^Trie o)) (= (.key trie) - (.key o)))) + (.key ^Trie o)))) clojure.lang.Associative (assoc [trie opath ovalue] @@ -103,15 +82,15 @@ java.lang.Iterable (iterator [trie] - (.iterator (seq trie))) + (.iterator ^clojure.lang.LazySeq (seq trie))) clojure.lang.Counted (count [trie] (count (seq trie))) clojure.lang.Seqable - (seq [trie] - (let [step (fn step [path [[node & nodes] & stack] [parent & parents]] + (seq ^clojure.lang.LazySeq [trie] + (let [step (fn step [path [[^Trie node & nodes] & stack] [^Trie parent & parents]] (cond node (step (conj path (.key node)) @@ -129,6 +108,27 @@ :else nil))] (step [] (list (list trie)) '())))) +(defn -without + [^Trie trie [k & ks]] + (if k + (if-let [next-trie (get (.children- trie) k)] + (let [next-trie-without (-without next-trie ks) + ^Trie new-trie (->Trie (.key trie) + (.value trie) + (if next-trie-without + (assoc (.children- trie) k next-trie-without) + (dissoc (.children- trie) k)))] + (if (and (empty? new-trie) + (nil? (.value new-trie))) + nil + new-trie))) + (if (seq (.children- trie)) + (->Trie + (.key trie) + nil + (.children- trie)) + nil))) + (defmethod print-method Trie [trie ^java.io.Writer w] (print-method (into {} trie) w)) diff --git a/test/tightly_packed_trie_test.clj b/test/tightly_packed_trie_test.clj index 04bacbc..bad1a05 100644 --- a/test/tightly_packed_trie_test.clj +++ b/test/tightly_packed_trie_test.clj @@ -6,14 +6,16 @@ [com.owoga.tightly-packed-trie.encoding :as encoding] [com.owoga.tightly-packed-trie.bit-manip :as bm])) -(defn value-encode-fn [v] - (if (or (= v ::tpt/root) - (nil? v)) - (encode/encode 0) - (encode/encode v))) +(set! *warn-on-reflection* true) -(defn value-decode-fn [byte-buffer] - (let [v (encode/decode byte-buffer)] +(defn value-encode-fn + (#^bytes [^Integer v] + (if (nil? v) + (encode/encode 0) + (encode/encode v)))) + +(defn value-decode-fn [^java.nio.ByteBuffer byte-buffer] + (let [^Integer v (encode/decode byte-buffer)] (if (zero? v) nil v)))