Performance gains with type hints

main
Eric Ihli 3 years ago
parent f57ea59c22
commit 2356df1095

@ -4,7 +4,7 @@
<packaging>jar</packaging>
<groupId>com.owoga</groupId>
<artifactId>tightly-packed-trie</artifactId>
<version>0.2.2</version>
<version>0.2.3</version>
<name>tightly-packed-trie</name>
<scm>
<connection>scm:git:git://github.com/eihli/clj-tightly-packed-trie.git</connection>

@ -3,8 +3,7 @@
[com.owoga.tightly-packed-trie.encoding :as encoding]
[clojure.java.io :as io]
[clojure.string :as string]
[com.owoga.tightly-packed-trie.bit-manip :as bm]
[clojure.zip :as zip])
[com.owoga.tightly-packed-trie.bit-manip :as bm])
(:import (java.io ByteArrayOutputStream ByteArrayInputStream
DataOutputStream DataInputStream)))
@ -38,39 +37,8 @@
(.limit ~byte-buffer original-limit#)
(.position ~byte-buffer original-position#)))))
(defn -trie->depth-first-post-order-traversable-zipperable-vector
[path node decode-value-fn]
(vec
(map
(fn [child]
[(-trie->depth-first-post-order-traversable-zipperable-vector
(conj path (.key child))
child
decode-value-fn)
(wrap-byte-buffer
(.byte-buffer child)
(.limit (.byte-buffer child) (.limit child))
(.position (.byte-buffer child) (.address child))
(clojure.lang.MapEntry.
(conj path (.key child))
(decode-value-fn (.byte-buffer child))))])
(trie/children node))))
(defn trie->depth-first-post-order-traversable-zipperable-vector
[path node decode-value-fn]
(let [byte-buffer (.byte-buffer node)
val (wrap-byte-buffer
byte-buffer
(.limit byte-buffer (.limit node))
(.position byte-buffer (.address node))
(decode-value-fn byte-buffer))]
[(-trie->depth-first-post-order-traversable-zipperable-vector
path
node
decode-value-fn)
(clojure.lang.MapEntry. path val)]))
(defn rewind-to-key [bb stop]
(defn rewind-to-key [^java.nio.ByteBuffer bb
^Integer stop]
(loop []
(let [current (.get bb (.position bb))
previous (.get bb (dec (.position bb)))]
@ -81,18 +49,11 @@
(do (.position bb (dec (.position bb)))
(recur))))))
(defn forward-to-key [bb stop]
(loop []
(if (or (= stop (.position bb))
(and (encoding/key-byte? (.get bb (.position bb)))
(encoding/offset-byte?
(.get bb (inc (.position bb))))))
bb
(do (.position bb (inc (.position bb)))
(recur)))))
(defn find-key-in-index
[bb target-key max-address not-found]
[^java.nio.ByteBuffer bb
^Integer target-key
^Integer max-address
not-found]
(.limit bb max-address)
(let [key
(loop [previous-key nil
@ -100,9 +61,9 @@
max-position max-address]
(if (zero? (- max-position min-position))
not-found
(let [mid-position (+ min-position (quot (- max-position min-position) 2))]
(let [^Integer mid-position (+ min-position (quot (- max-position min-position) 2))]
(.position bb mid-position)
(let [bb (rewind-to-key bb min-position)
(let [^java.nio.ByteBuffer bb (rewind-to-key bb min-position)
current-key
(encoding/decode-number-from-tightly-packed-trie-index bb)]
(cond
@ -144,14 +105,13 @@
{:id value
:count freq}))
(defn -value [trie value-decode-fn]
(wrap-byte-buffer
(.byte-buffer trie)
(.limit (.byte-buffer trie) (.limit trie))
(.position (.byte-buffer trie) (.address trie))
(value-decode-fn (.byte-buffer trie))))
(declare -value)
(deftype TightlyPackedTrie [byte-buffer key address limit value-decode-fn]
(deftype TightlyPackedTrie [^java.nio.ByteBuffer byte-buffer
^Integer key
^Integer address
^Integer limit
value-decode-fn]
trie/ITrie
(lookup [self ks]
(wrap-byte-buffer
@ -183,8 +143,8 @@
(.position byte-buffer address)
(let [val (value-decode-fn byte-buffer)
size-of-index (encoding/decode byte-buffer)]
(.limit byte-buffer (+ (.position byte-buffer)
size-of-index))
(.limit byte-buffer ^Integer (+ (.position byte-buffer)
size-of-index))
(loop [children []]
(if (= (.position byte-buffer) (.limit byte-buffer))
children
@ -202,7 +162,7 @@
clojure.lang.ILookup
(valAt [self ks]
(if-let [node (trie/lookup self ks)]
(if-let [^TightlyPackedTrie node (trie/lookup self ks)]
(-value node value-decode-fn)
nil))
(valAt [self ks not-found]
@ -214,7 +174,9 @@
clojure.lang.Seqable
(seq [trie]
(let [step (fn step [path [[node & nodes] & stack] [parent & parents]]
(let [step (fn step [^clojure.lang.PersistentList path
[[^TightlyPackedTrie node & nodes] & stack]
[^TightlyPackedTrie parent & parents]]
(cond
node
(step (conj path (.key node))
@ -225,11 +187,11 @@
(lazy-seq
(cons (clojure.lang.MapEntry.
(rest path)
(let [byte-buffer (.byte-buffer parent)]
(let [^java.nio.ByteBuffer byte-buffer (.byte-buffer parent)]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer (.limit parent))
(.position byte-buffer (.address parent))
(.limit byte-buffer ^Integer (.limit parent))
(.position byte-buffer ^Integer (.address parent))
(value-decode-fn byte-buffer))))
(step (pop path)
stack
@ -237,6 +199,14 @@
:else nil))]
(step [] (list (list trie)) '()))))
(defn -value [^TightlyPackedTrie trie value-decode-fn]
(let [^java.nio.ByteBuffer byte-buffer (.byte-buffer trie)]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer ^Integer (.limit trie))
(.position byte-buffer ^Integer (.address trie))
(value-decode-fn byte-buffer))))
(defmethod print-method TightlyPackedTrie [trie ^java.io.Writer w]
(print-method (into {} trie) w))
@ -244,17 +214,19 @@
(print-ctor trie (fn [o w] (print-dup (into {} trie) w)) w))
(defn tightly-packed-trie
[trie value-encode-fn value-decode-fn]
(let [baos (ByteArrayOutputStream.)]
[^TightlyPackedTrie trie
value-encode-fn
value-decode-fn]
(let [^ByteArrayOutputStream baos (ByteArrayOutputStream.)]
(loop [nodes (seq trie)
current-offset 8
previous-depth 0
child-indexes []]
(let [current-node (first nodes)
(let [^TightlyPackedTrie current-node (first nodes)
current-depth (count (first current-node))]
(cond
(empty? nodes)
(let [child-index (last child-indexes)
(let [^clojure.lang.PersistentVector child-index (last child-indexes)
child-index-baos (ByteArrayOutputStream.)
_ (->> child-index
(run!
@ -269,12 +241,12 @@
child-index-byte-array (.toByteArray child-index-baos)
size-of-child-index (encoding/encode (count child-index-byte-array))
root-address current-offset
value (value-encode-fn 0)]
value #^bytes (value-encode-fn 0)]
(.write baos value)
(.write baos size-of-child-index)
(.write baos child-index-byte-array)
(let [ba (.toByteArray baos)
byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))]
(let [#^bytes ba (.toByteArray baos)
^java.nio.ByteBuffer byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))]
(do (.putLong byte-buf root-address)
(.put byte-buf ba)
(.rewind byte-buf)
@ -289,7 +261,7 @@
;; Process index of children.
(> previous-depth current-depth)
(let [[k v] (first nodes)
value (value-encode-fn v)
value #^bytes (value-encode-fn v)
child-index (last child-indexes)
child-index-baos (ByteArrayOutputStream.)
_ (->> child-index
@ -322,7 +294,7 @@
;; Start keeping track of new children index
:else
(let [[k v] (first nodes)
value (value-encode-fn v)
value #^bytes (value-encode-fn v)
size-of-child-index (encoding/encode 0)
child-indexes (into child-indexes
(vec (repeat (- current-depth previous-depth) [])))
@ -350,7 +322,7 @@
(with-open [i (io/input-stream filepath)
baos (ByteArrayOutputStream.)]
(io/copy i baos)
(let [byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))]
(let [^java.nio.ByteBuffer byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))]
(.rewind byte-buffer)
(->TightlyPackedTrie
byte-buffer

@ -95,12 +95,12 @@
-> 0001010 1001011
-> 00010101001011
-> 1355 (As a long...)"
[num-significant-bits & bytes]
(reduce
(fn [a b]
(bit-or b (bit-shift-left a num-significant-bits)))
0
bytes))
(#^bytes [num-significant-bits & bytes]
(reduce
(fn [a b]
(bit-or b (bit-shift-left a num-significant-bits)))
0
bytes)))
(comment
(let [b1 (bitstring->int "0110110")

@ -30,12 +30,12 @@
To decode: if the flag bit is not set, read the next byte and
concat the last 7 bits of the current byte to
the last 7 bits of the next byte."
[n]
(loop [b (list (bit-set (mod n 0x80) 7))
n (quot n 0x80)]
(if (zero? n)
(byte-array b)
(recur (cons (mod n 0x80) b) (quot n 0x80)))))
(#^bytes [n]
(loop [b (list (bit-set (mod n 0x80) 7))
n (quot n 0x80)]
(if (zero? n)
(byte-array b)
(recur (cons (mod n 0x80) b) (quot n 0x80))))))
(comment
(->> [0 1 2 127 128 129]
@ -52,8 +52,8 @@
(defn decode
"Decode one variable-length-encoded number from a ByteBuffer,
advancing the buffer's position to the byte following the encoded number."
[byte-buffer]
(loop [bytes (list (.get byte-buffer))]
^Integer [^java.nio.ByteBuffer byte-buffer]
(loop [bytes (list ^Byte (.get byte-buffer))]
(if (bit-test (first bytes) 7)
(->> (cons (bit-clear (first bytes) 7) (rest bytes))
reverse
@ -75,29 +75,27 @@
(def offset-byte? (complement key-byte?))
(defn encode-key-to-tightly-packed-trie-index
[n]
#^bytes [n]
(->> n encode (map #(bit-set % 7)) byte-array))
(defn encode-offset-to-tightly-packed-trie-index
[n]
#^bytes [n]
(->> n encode (map #(bit-clear % 7)) byte-array))
(defn decode-number-from-tightly-packed-trie-index
([byte-buffer]
(let [first-byte (.get byte-buffer)
continue? (fn []
(and (.hasRemaining byte-buffer)
(= (key-byte? (.get byte-buffer (.position byte-buffer)))
(key-byte? first-byte))))]
(loop [bytes [first-byte]]
(if (continue?)
(recur (conj bytes (.get byte-buffer)))
(->> bytes
(map (partial bit-and 0xFF))
(map #(bit-clear % 7))
(apply (partial bm/combine-significant-bits 7))))))))
(bm/to-binary-string 0xff)
[^java.nio.ByteBuffer byte-buffer]
(let [first-byte (.get byte-buffer)
continue? (fn []
(and (.hasRemaining byte-buffer)
(= (key-byte? (.get byte-buffer (.position byte-buffer)))
(key-byte? first-byte))))]
(loop [bytes [first-byte]]
(if (continue?)
(recur (conj bytes (.get byte-buffer)))
(->> bytes
(map (partial bit-and 0xFF))
(map #(bit-clear % 7))
(apply (partial bm/combine-significant-bits 7)))))))
(comment
(let [byte-buffer (java.nio.ByteBuffer/allocate 64)]

@ -1,37 +1,16 @@
(ns com.owoga.trie)
(declare ->Trie)
(defn -without
[trie [k & ks]]
(if k
(if-let [next-trie (get (.children- trie) k)]
(let [next-trie-without (-without next-trie ks)
new-trie (->Trie (.key trie)
(.value trie)
(if next-trie-without
(assoc (.children- trie) k next-trie-without)
(dissoc (.children- trie) k)))]
(if (and (empty? new-trie)
(nil? (.value new-trie)))
nil
new-trie)))
(if (seq (.children- trie))
(->Trie
(.key trie)
nil
(.children- trie))
nil)))
(declare -without)
(defprotocol ITrie
(children [self] "Immediate children of a node.")
(lookup [self ks] "Return node at key."))
(lookup [self ^clojure.lang.PersistentList ks] "Return node at key."))
(deftype Trie [key value children-]
(deftype Trie [key value ^clojure.lang.PersistentTreeMap children-]
ITrie
(children [trie]
(map
(fn [[k child]]
(fn [[k ^Trie child]]
(Trie. k
(.value child)
(.children- child)))
@ -54,7 +33,7 @@
clojure.lang.ILookup
(valAt [trie k]
(if-let [node (lookup trie k)]
(if-let [^Trie node (lookup trie k)]
(.value node)
nil))
@ -65,18 +44,18 @@
(cons [trie entry]
(cond
(instance? Trie (second entry))
(assoc trie (first entry) (.value (second entry)))
(assoc trie (first entry) (.value ^Trie (second entry)))
:else
(assoc trie (first entry) (second entry))))
(empty [trie]
(Trie. key nil (sorted-map)))
(equiv [trie o]
(and (= (.value trie)
(.value o))
(.value ^Trie o))
(= (.children- trie)
(.children- o))
(.children- ^Trie o))
(= (.key trie)
(.key o))))
(.key ^Trie o))))
clojure.lang.Associative
(assoc [trie opath ovalue]
@ -103,15 +82,15 @@
java.lang.Iterable
(iterator [trie]
(.iterator (seq trie)))
(.iterator ^clojure.lang.LazySeq (seq trie)))
clojure.lang.Counted
(count [trie]
(count (seq trie)))
clojure.lang.Seqable
(seq [trie]
(let [step (fn step [path [[node & nodes] & stack] [parent & parents]]
(seq ^clojure.lang.LazySeq [trie]
(let [step (fn step [path [[^Trie node & nodes] & stack] [^Trie parent & parents]]
(cond
node
(step (conj path (.key node))
@ -129,6 +108,27 @@
:else nil))]
(step [] (list (list trie)) '()))))
(defn -without
[^Trie trie [k & ks]]
(if k
(if-let [next-trie (get (.children- trie) k)]
(let [next-trie-without (-without next-trie ks)
^Trie new-trie (->Trie (.key trie)
(.value trie)
(if next-trie-without
(assoc (.children- trie) k next-trie-without)
(dissoc (.children- trie) k)))]
(if (and (empty? new-trie)
(nil? (.value new-trie)))
nil
new-trie)))
(if (seq (.children- trie))
(->Trie
(.key trie)
nil
(.children- trie))
nil)))
(defmethod print-method Trie [trie ^java.io.Writer w]
(print-method (into {} trie) w))

@ -6,14 +6,16 @@
[com.owoga.tightly-packed-trie.encoding :as encoding]
[com.owoga.tightly-packed-trie.bit-manip :as bm]))
(defn value-encode-fn [v]
(if (or (= v ::tpt/root)
(nil? v))
(encode/encode 0)
(encode/encode v)))
(set! *warn-on-reflection* true)
(defn value-decode-fn [byte-buffer]
(let [v (encode/decode byte-buffer)]
(defn value-encode-fn
(#^bytes [^Integer v]
(if (nil? v)
(encode/encode 0)
(encode/encode v))))
(defn value-decode-fn [^java.nio.ByteBuffer byte-buffer]
(let [^Integer v (encode/decode byte-buffer)]
(if (zero? v)
nil
v)))

Loading…
Cancel
Save