Performance gains with type hints

main
Eric Ihli 4 years ago
parent f57ea59c22
commit 2356df1095

@ -4,7 +4,7 @@
<packaging>jar</packaging> <packaging>jar</packaging>
<groupId>com.owoga</groupId> <groupId>com.owoga</groupId>
<artifactId>tightly-packed-trie</artifactId> <artifactId>tightly-packed-trie</artifactId>
<version>0.2.2</version> <version>0.2.3</version>
<name>tightly-packed-trie</name> <name>tightly-packed-trie</name>
<scm> <scm>
<connection>scm:git:git://github.com/eihli/clj-tightly-packed-trie.git</connection> <connection>scm:git:git://github.com/eihli/clj-tightly-packed-trie.git</connection>

@ -3,8 +3,7 @@
[com.owoga.tightly-packed-trie.encoding :as encoding] [com.owoga.tightly-packed-trie.encoding :as encoding]
[clojure.java.io :as io] [clojure.java.io :as io]
[clojure.string :as string] [clojure.string :as string]
[com.owoga.tightly-packed-trie.bit-manip :as bm] [com.owoga.tightly-packed-trie.bit-manip :as bm])
[clojure.zip :as zip])
(:import (java.io ByteArrayOutputStream ByteArrayInputStream (:import (java.io ByteArrayOutputStream ByteArrayInputStream
DataOutputStream DataInputStream))) DataOutputStream DataInputStream)))
@ -38,39 +37,8 @@
(.limit ~byte-buffer original-limit#) (.limit ~byte-buffer original-limit#)
(.position ~byte-buffer original-position#))))) (.position ~byte-buffer original-position#)))))
(defn -trie->depth-first-post-order-traversable-zipperable-vector (defn rewind-to-key [^java.nio.ByteBuffer bb
[path node decode-value-fn] ^Integer stop]
(vec
(map
(fn [child]
[(-trie->depth-first-post-order-traversable-zipperable-vector
(conj path (.key child))
child
decode-value-fn)
(wrap-byte-buffer
(.byte-buffer child)
(.limit (.byte-buffer child) (.limit child))
(.position (.byte-buffer child) (.address child))
(clojure.lang.MapEntry.
(conj path (.key child))
(decode-value-fn (.byte-buffer child))))])
(trie/children node))))
(defn trie->depth-first-post-order-traversable-zipperable-vector
[path node decode-value-fn]
(let [byte-buffer (.byte-buffer node)
val (wrap-byte-buffer
byte-buffer
(.limit byte-buffer (.limit node))
(.position byte-buffer (.address node))
(decode-value-fn byte-buffer))]
[(-trie->depth-first-post-order-traversable-zipperable-vector
path
node
decode-value-fn)
(clojure.lang.MapEntry. path val)]))
(defn rewind-to-key [bb stop]
(loop [] (loop []
(let [current (.get bb (.position bb)) (let [current (.get bb (.position bb))
previous (.get bb (dec (.position bb)))] previous (.get bb (dec (.position bb)))]
@ -81,18 +49,11 @@
(do (.position bb (dec (.position bb))) (do (.position bb (dec (.position bb)))
(recur)))))) (recur))))))
(defn forward-to-key [bb stop]
(loop []
(if (or (= stop (.position bb))
(and (encoding/key-byte? (.get bb (.position bb)))
(encoding/offset-byte?
(.get bb (inc (.position bb))))))
bb
(do (.position bb (inc (.position bb)))
(recur)))))
(defn find-key-in-index (defn find-key-in-index
[bb target-key max-address not-found] [^java.nio.ByteBuffer bb
^Integer target-key
^Integer max-address
not-found]
(.limit bb max-address) (.limit bb max-address)
(let [key (let [key
(loop [previous-key nil (loop [previous-key nil
@ -100,9 +61,9 @@
max-position max-address] max-position max-address]
(if (zero? (- max-position min-position)) (if (zero? (- max-position min-position))
not-found not-found
(let [mid-position (+ min-position (quot (- max-position min-position) 2))] (let [^Integer mid-position (+ min-position (quot (- max-position min-position) 2))]
(.position bb mid-position) (.position bb mid-position)
(let [bb (rewind-to-key bb min-position) (let [^java.nio.ByteBuffer bb (rewind-to-key bb min-position)
current-key current-key
(encoding/decode-number-from-tightly-packed-trie-index bb)] (encoding/decode-number-from-tightly-packed-trie-index bb)]
(cond (cond
@ -144,14 +105,13 @@
{:id value {:id value
:count freq})) :count freq}))
(defn -value [trie value-decode-fn] (declare -value)
(wrap-byte-buffer
(.byte-buffer trie)
(.limit (.byte-buffer trie) (.limit trie))
(.position (.byte-buffer trie) (.address trie))
(value-decode-fn (.byte-buffer trie))))
(deftype TightlyPackedTrie [byte-buffer key address limit value-decode-fn] (deftype TightlyPackedTrie [^java.nio.ByteBuffer byte-buffer
^Integer key
^Integer address
^Integer limit
value-decode-fn]
trie/ITrie trie/ITrie
(lookup [self ks] (lookup [self ks]
(wrap-byte-buffer (wrap-byte-buffer
@ -183,7 +143,7 @@
(.position byte-buffer address) (.position byte-buffer address)
(let [val (value-decode-fn byte-buffer) (let [val (value-decode-fn byte-buffer)
size-of-index (encoding/decode byte-buffer)] size-of-index (encoding/decode byte-buffer)]
(.limit byte-buffer (+ (.position byte-buffer) (.limit byte-buffer ^Integer (+ (.position byte-buffer)
size-of-index)) size-of-index))
(loop [children []] (loop [children []]
(if (= (.position byte-buffer) (.limit byte-buffer)) (if (= (.position byte-buffer) (.limit byte-buffer))
@ -202,7 +162,7 @@
clojure.lang.ILookup clojure.lang.ILookup
(valAt [self ks] (valAt [self ks]
(if-let [node (trie/lookup self ks)] (if-let [^TightlyPackedTrie node (trie/lookup self ks)]
(-value node value-decode-fn) (-value node value-decode-fn)
nil)) nil))
(valAt [self ks not-found] (valAt [self ks not-found]
@ -214,7 +174,9 @@
clojure.lang.Seqable clojure.lang.Seqable
(seq [trie] (seq [trie]
(let [step (fn step [path [[node & nodes] & stack] [parent & parents]] (let [step (fn step [^clojure.lang.PersistentList path
[[^TightlyPackedTrie node & nodes] & stack]
[^TightlyPackedTrie parent & parents]]
(cond (cond
node node
(step (conj path (.key node)) (step (conj path (.key node))
@ -225,11 +187,11 @@
(lazy-seq (lazy-seq
(cons (clojure.lang.MapEntry. (cons (clojure.lang.MapEntry.
(rest path) (rest path)
(let [byte-buffer (.byte-buffer parent)] (let [^java.nio.ByteBuffer byte-buffer (.byte-buffer parent)]
(wrap-byte-buffer (wrap-byte-buffer
byte-buffer byte-buffer
(.limit byte-buffer (.limit parent)) (.limit byte-buffer ^Integer (.limit parent))
(.position byte-buffer (.address parent)) (.position byte-buffer ^Integer (.address parent))
(value-decode-fn byte-buffer)))) (value-decode-fn byte-buffer))))
(step (pop path) (step (pop path)
stack stack
@ -237,6 +199,14 @@
:else nil))] :else nil))]
(step [] (list (list trie)) '())))) (step [] (list (list trie)) '()))))
(defn -value [^TightlyPackedTrie trie value-decode-fn]
(let [^java.nio.ByteBuffer byte-buffer (.byte-buffer trie)]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer ^Integer (.limit trie))
(.position byte-buffer ^Integer (.address trie))
(value-decode-fn byte-buffer))))
(defmethod print-method TightlyPackedTrie [trie ^java.io.Writer w] (defmethod print-method TightlyPackedTrie [trie ^java.io.Writer w]
(print-method (into {} trie) w)) (print-method (into {} trie) w))
@ -244,17 +214,19 @@
(print-ctor trie (fn [o w] (print-dup (into {} trie) w)) w)) (print-ctor trie (fn [o w] (print-dup (into {} trie) w)) w))
(defn tightly-packed-trie (defn tightly-packed-trie
[trie value-encode-fn value-decode-fn] [^TightlyPackedTrie trie
(let [baos (ByteArrayOutputStream.)] value-encode-fn
value-decode-fn]
(let [^ByteArrayOutputStream baos (ByteArrayOutputStream.)]
(loop [nodes (seq trie) (loop [nodes (seq trie)
current-offset 8 current-offset 8
previous-depth 0 previous-depth 0
child-indexes []] child-indexes []]
(let [current-node (first nodes) (let [^TightlyPackedTrie current-node (first nodes)
current-depth (count (first current-node))] current-depth (count (first current-node))]
(cond (cond
(empty? nodes) (empty? nodes)
(let [child-index (last child-indexes) (let [^clojure.lang.PersistentVector child-index (last child-indexes)
child-index-baos (ByteArrayOutputStream.) child-index-baos (ByteArrayOutputStream.)
_ (->> child-index _ (->> child-index
(run! (run!
@ -269,12 +241,12 @@
child-index-byte-array (.toByteArray child-index-baos) child-index-byte-array (.toByteArray child-index-baos)
size-of-child-index (encoding/encode (count child-index-byte-array)) size-of-child-index (encoding/encode (count child-index-byte-array))
root-address current-offset root-address current-offset
value (value-encode-fn 0)] value #^bytes (value-encode-fn 0)]
(.write baos value) (.write baos value)
(.write baos size-of-child-index) (.write baos size-of-child-index)
(.write baos child-index-byte-array) (.write baos child-index-byte-array)
(let [ba (.toByteArray baos) (let [#^bytes ba (.toByteArray baos)
byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))] ^java.nio.ByteBuffer byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))]
(do (.putLong byte-buf root-address) (do (.putLong byte-buf root-address)
(.put byte-buf ba) (.put byte-buf ba)
(.rewind byte-buf) (.rewind byte-buf)
@ -289,7 +261,7 @@
;; Process index of children. ;; Process index of children.
(> previous-depth current-depth) (> previous-depth current-depth)
(let [[k v] (first nodes) (let [[k v] (first nodes)
value (value-encode-fn v) value #^bytes (value-encode-fn v)
child-index (last child-indexes) child-index (last child-indexes)
child-index-baos (ByteArrayOutputStream.) child-index-baos (ByteArrayOutputStream.)
_ (->> child-index _ (->> child-index
@ -322,7 +294,7 @@
;; Start keeping track of new children index ;; Start keeping track of new children index
:else :else
(let [[k v] (first nodes) (let [[k v] (first nodes)
value (value-encode-fn v) value #^bytes (value-encode-fn v)
size-of-child-index (encoding/encode 0) size-of-child-index (encoding/encode 0)
child-indexes (into child-indexes child-indexes (into child-indexes
(vec (repeat (- current-depth previous-depth) []))) (vec (repeat (- current-depth previous-depth) [])))
@ -350,7 +322,7 @@
(with-open [i (io/input-stream filepath) (with-open [i (io/input-stream filepath)
baos (ByteArrayOutputStream.)] baos (ByteArrayOutputStream.)]
(io/copy i baos) (io/copy i baos)
(let [byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))] (let [^java.nio.ByteBuffer byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))]
(.rewind byte-buffer) (.rewind byte-buffer)
(->TightlyPackedTrie (->TightlyPackedTrie
byte-buffer byte-buffer

@ -95,12 +95,12 @@
-> 0001010 1001011 -> 0001010 1001011
-> 00010101001011 -> 00010101001011
-> 1355 (As a long...)" -> 1355 (As a long...)"
[num-significant-bits & bytes] (#^bytes [num-significant-bits & bytes]
(reduce (reduce
(fn [a b] (fn [a b]
(bit-or b (bit-shift-left a num-significant-bits))) (bit-or b (bit-shift-left a num-significant-bits)))
0 0
bytes)) bytes)))
(comment (comment
(let [b1 (bitstring->int "0110110") (let [b1 (bitstring->int "0110110")

@ -30,12 +30,12 @@
To decode: if the flag bit is not set, read the next byte and To decode: if the flag bit is not set, read the next byte and
concat the last 7 bits of the current byte to concat the last 7 bits of the current byte to
the last 7 bits of the next byte." the last 7 bits of the next byte."
[n] (#^bytes [n]
(loop [b (list (bit-set (mod n 0x80) 7)) (loop [b (list (bit-set (mod n 0x80) 7))
n (quot n 0x80)] n (quot n 0x80)]
(if (zero? n) (if (zero? n)
(byte-array b) (byte-array b)
(recur (cons (mod n 0x80) b) (quot n 0x80))))) (recur (cons (mod n 0x80) b) (quot n 0x80))))))
(comment (comment
(->> [0 1 2 127 128 129] (->> [0 1 2 127 128 129]
@ -52,8 +52,8 @@
(defn decode (defn decode
"Decode one variable-length-encoded number from a ByteBuffer, "Decode one variable-length-encoded number from a ByteBuffer,
advancing the buffer's position to the byte following the encoded number." advancing the buffer's position to the byte following the encoded number."
[byte-buffer] ^Integer [^java.nio.ByteBuffer byte-buffer]
(loop [bytes (list (.get byte-buffer))] (loop [bytes (list ^Byte (.get byte-buffer))]
(if (bit-test (first bytes) 7) (if (bit-test (first bytes) 7)
(->> (cons (bit-clear (first bytes) 7) (rest bytes)) (->> (cons (bit-clear (first bytes) 7) (rest bytes))
reverse reverse
@ -75,15 +75,15 @@
(def offset-byte? (complement key-byte?)) (def offset-byte? (complement key-byte?))
(defn encode-key-to-tightly-packed-trie-index (defn encode-key-to-tightly-packed-trie-index
[n] #^bytes [n]
(->> n encode (map #(bit-set % 7)) byte-array)) (->> n encode (map #(bit-set % 7)) byte-array))
(defn encode-offset-to-tightly-packed-trie-index (defn encode-offset-to-tightly-packed-trie-index
[n] #^bytes [n]
(->> n encode (map #(bit-clear % 7)) byte-array)) (->> n encode (map #(bit-clear % 7)) byte-array))
(defn decode-number-from-tightly-packed-trie-index (defn decode-number-from-tightly-packed-trie-index
([byte-buffer] [^java.nio.ByteBuffer byte-buffer]
(let [first-byte (.get byte-buffer) (let [first-byte (.get byte-buffer)
continue? (fn [] continue? (fn []
(and (.hasRemaining byte-buffer) (and (.hasRemaining byte-buffer)
@ -95,9 +95,7 @@
(->> bytes (->> bytes
(map (partial bit-and 0xFF)) (map (partial bit-and 0xFF))
(map #(bit-clear % 7)) (map #(bit-clear % 7))
(apply (partial bm/combine-significant-bits 7)))))))) (apply (partial bm/combine-significant-bits 7)))))))
(bm/to-binary-string 0xff)
(comment (comment
(let [byte-buffer (java.nio.ByteBuffer/allocate 64)] (let [byte-buffer (java.nio.ByteBuffer/allocate 64)]

@ -1,37 +1,16 @@
(ns com.owoga.trie) (ns com.owoga.trie)
(declare ->Trie) (declare -without)
(defn -without
[trie [k & ks]]
(if k
(if-let [next-trie (get (.children- trie) k)]
(let [next-trie-without (-without next-trie ks)
new-trie (->Trie (.key trie)
(.value trie)
(if next-trie-without
(assoc (.children- trie) k next-trie-without)
(dissoc (.children- trie) k)))]
(if (and (empty? new-trie)
(nil? (.value new-trie)))
nil
new-trie)))
(if (seq (.children- trie))
(->Trie
(.key trie)
nil
(.children- trie))
nil)))
(defprotocol ITrie (defprotocol ITrie
(children [self] "Immediate children of a node.") (children [self] "Immediate children of a node.")
(lookup [self ks] "Return node at key.")) (lookup [self ^clojure.lang.PersistentList ks] "Return node at key."))
(deftype Trie [key value children-] (deftype Trie [key value ^clojure.lang.PersistentTreeMap children-]
ITrie ITrie
(children [trie] (children [trie]
(map (map
(fn [[k child]] (fn [[k ^Trie child]]
(Trie. k (Trie. k
(.value child) (.value child)
(.children- child))) (.children- child)))
@ -54,7 +33,7 @@
clojure.lang.ILookup clojure.lang.ILookup
(valAt [trie k] (valAt [trie k]
(if-let [node (lookup trie k)] (if-let [^Trie node (lookup trie k)]
(.value node) (.value node)
nil)) nil))
@ -65,18 +44,18 @@
(cons [trie entry] (cons [trie entry]
(cond (cond
(instance? Trie (second entry)) (instance? Trie (second entry))
(assoc trie (first entry) (.value (second entry))) (assoc trie (first entry) (.value ^Trie (second entry)))
:else :else
(assoc trie (first entry) (second entry)))) (assoc trie (first entry) (second entry))))
(empty [trie] (empty [trie]
(Trie. key nil (sorted-map))) (Trie. key nil (sorted-map)))
(equiv [trie o] (equiv [trie o]
(and (= (.value trie) (and (= (.value trie)
(.value o)) (.value ^Trie o))
(= (.children- trie) (= (.children- trie)
(.children- o)) (.children- ^Trie o))
(= (.key trie) (= (.key trie)
(.key o)))) (.key ^Trie o))))
clojure.lang.Associative clojure.lang.Associative
(assoc [trie opath ovalue] (assoc [trie opath ovalue]
@ -103,15 +82,15 @@
java.lang.Iterable java.lang.Iterable
(iterator [trie] (iterator [trie]
(.iterator (seq trie))) (.iterator ^clojure.lang.LazySeq (seq trie)))
clojure.lang.Counted clojure.lang.Counted
(count [trie] (count [trie]
(count (seq trie))) (count (seq trie)))
clojure.lang.Seqable clojure.lang.Seqable
(seq [trie] (seq ^clojure.lang.LazySeq [trie]
(let [step (fn step [path [[node & nodes] & stack] [parent & parents]] (let [step (fn step [path [[^Trie node & nodes] & stack] [^Trie parent & parents]]
(cond (cond
node node
(step (conj path (.key node)) (step (conj path (.key node))
@ -129,6 +108,27 @@
:else nil))] :else nil))]
(step [] (list (list trie)) '())))) (step [] (list (list trie)) '()))))
(defn -without
[^Trie trie [k & ks]]
(if k
(if-let [next-trie (get (.children- trie) k)]
(let [next-trie-without (-without next-trie ks)
^Trie new-trie (->Trie (.key trie)
(.value trie)
(if next-trie-without
(assoc (.children- trie) k next-trie-without)
(dissoc (.children- trie) k)))]
(if (and (empty? new-trie)
(nil? (.value new-trie)))
nil
new-trie)))
(if (seq (.children- trie))
(->Trie
(.key trie)
nil
(.children- trie))
nil)))
(defmethod print-method Trie [trie ^java.io.Writer w] (defmethod print-method Trie [trie ^java.io.Writer w]
(print-method (into {} trie) w)) (print-method (into {} trie) w))

@ -6,14 +6,16 @@
[com.owoga.tightly-packed-trie.encoding :as encoding] [com.owoga.tightly-packed-trie.encoding :as encoding]
[com.owoga.tightly-packed-trie.bit-manip :as bm])) [com.owoga.tightly-packed-trie.bit-manip :as bm]))
(defn value-encode-fn [v] (set! *warn-on-reflection* true)
(if (or (= v ::tpt/root)
(nil? v)) (defn value-encode-fn
(#^bytes [^Integer v]
(if (nil? v)
(encode/encode 0) (encode/encode 0)
(encode/encode v))) (encode/encode v))))
(defn value-decode-fn [byte-buffer] (defn value-decode-fn [^java.nio.ByteBuffer byte-buffer]
(let [v (encode/decode byte-buffer)] (let [^Integer v (encode/decode byte-buffer)]
(if (zero? v) (if (zero? v)
nil nil
v))) v)))

Loading…
Cancel
Save