You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

345 lines
13 KiB
Clojure

(ns com.owoga.tightly-packed-trie
(:require [com.owoga.trie :as trie]
[com.owoga.tightly-packed-trie.encoding :as encoding]
[clojure.java.io :as io]
[clojure.string :as string]
[com.owoga.tightly-packed-trie.bit-manip :as bm])
(:import (java.io ByteArrayOutputStream ByteArrayInputStream
DataOutputStream DataInputStream)))
#_(set! *warn-on-reflection* true)
;; A trie data structure that can be converted to
;; a contiguous array of bytes while maintaining
;; efficient lookups.
;;
;; A regular Clojure hash-map can be used as a trie,
;; but there's a lot of memory overhead that comes along
;; with hash-maps.
;;
;; To conveniently get the benefit of packing a trie into a contiguous array of
;; bytes, there are a few restrictions.
;;
;; Everything must be numeric IDs. Keys must be numeric IDs.
;; Values must be numeric IDs. Everything must be a number.
;;
;; This lets us encode everything as variable-length-encoded byte arrays.
;;
;; To maximize efficiency, your most common keys should have the
;; smallest IDs.
(defmacro wrap-byte-buffer
"Saves the position and limit of a byte buffer, runs body,
returns byte buffer to original position and limit."
[byte-buffer & body]
`(let [original-position# (.position ~byte-buffer)
original-limit# (.limit ~byte-buffer)]
(try (do ~@body)
(finally
(.limit ~byte-buffer original-limit#)
(.position ~byte-buffer original-position#)))))
(defn rewind-to-key [^java.nio.ByteBuffer bb
^Integer stop]
(loop []
(let [current (.get bb (.position bb))
previous (.get bb (dec (.position bb)))]
(if (or (= stop (.position bb))
(and (encoding/key-byte? current)
(encoding/offset-byte? previous)))
bb
(do (.position bb (dec (.position bb)))
(recur))))))
(defn find-key-in-index
[^java.nio.ByteBuffer bb
^Integer target-key
^Integer max-address
not-found]
(.limit bb max-address)
(let [key
(loop [previous-key nil
min-position (.position bb)
max-position max-address]
(if (zero? (- max-position min-position))
not-found
(let [^Integer mid-position (+ min-position (quot (- max-position min-position) 2))]
(.position bb mid-position)
(let [^java.nio.ByteBuffer bb (rewind-to-key bb min-position)
current-key
(encoding/decode-number-from-tightly-packed-trie-index bb)]
(cond
(= current-key target-key)
(encoding/decode-number-from-tightly-packed-trie-index bb)
(= current-key previous-key)
(do
(encoding/decode-number-from-tightly-packed-trie-index bb)
(let [final-key (encoding/decode-number-from-tightly-packed-trie-index bb)]
(if (= target-key final-key)
(encoding/decode-number-from-tightly-packed-trie-index bb)
(throw (Exception. "Key not found.")))))
(< current-key target-key)
;; Chew the next decoded number. It's a useless offset.
(do
(encoding/decode-number-from-tightly-packed-trie-index bb)
(recur
current-key
(.position bb)
max-position))
(> current-key target-key)
;; This could also be rewound.
(do
(rewind-to-key bb min-position)
(recur
current-key
min-position
(.position bb))))))))]
(.limit bb (.capacity bb))
key))
(defn tightly-packed-trie-node-value
[byte-buffer]
(let [value (encoding/decode byte-buffer)
freq (encoding/decode byte-buffer)]
{:id value
:count freq}))
(declare -value)
(declare children-)
(deftype TightlyPackedTrie [^java.nio.ByteBuffer byte-buffer
^Integer key
^Integer address
^Integer limit
value-decode-fn]
trie/ITrie
(lookup [self ks]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer limit)
(.position byte-buffer address)
(if (empty? ks)
self
(let [val (value-decode-fn byte-buffer)
size-of-index (encoding/decode byte-buffer)
offset (find-key-in-index
byte-buffer
(first ks)
(+ (.position byte-buffer) size-of-index)
:not-found)]
(if (= offset :not-found)
nil
(let [child (TightlyPackedTrie.
byte-buffer
(first ks)
(- address offset)
(.capacity byte-buffer)
value-decode-fn)]
(trie/lookup child (rest ks))))))))
(children [self]
(children- byte-buffer address limit value-decode-fn))
clojure.lang.ILookup
(valAt [self ks]
(if-let [^TightlyPackedTrie node (trie/lookup self ks)]
(-value node value-decode-fn)
nil))
(valAt [self ks not-found]
(or (get self ks) not-found))
clojure.lang.Counted
(count [trie]
(count (seq trie)))
clojure.lang.Seqable
(seq [trie]
(let [step (fn step [^clojure.lang.PersistentList path
[[^TightlyPackedTrie node & nodes] & stack]
[^TightlyPackedTrie parent & parents]]
(cond
node
(step (conj path (.key node))
(into (into stack (list nodes))
(list (trie/children node)))
(cons node (cons parent parents)))
(and parent (not= 0 (.key parent)))
(lazy-seq
(cons (clojure.lang.MapEntry.
(rest path)
(let [^java.nio.ByteBuffer byte-buffer (.byte-buffer parent)]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer ^Integer (.limit parent))
(.position byte-buffer ^Integer (.address parent))
(value-decode-fn byte-buffer))))
(step (pop path)
stack
parents)))
:else nil))]
(step [] (list (list trie)) '()))))
(defn -value [^TightlyPackedTrie trie value-decode-fn]
(let [^java.nio.ByteBuffer byte-buffer (.byte-buffer trie)]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer ^Integer (.limit trie))
(.position byte-buffer ^Integer (.address trie))
(value-decode-fn byte-buffer))))
(defn children- [^java.nio.ByteBuffer
byte-buffer
^Integer address
^Integer limit
value-decode-fn]
(wrap-byte-buffer
byte-buffer
(.limit byte-buffer limit)
(.position byte-buffer address)
(let [val (value-decode-fn byte-buffer)
size-of-index (encoding/decode byte-buffer)]
(.limit byte-buffer ^Integer (+ (.position byte-buffer)
size-of-index))
(loop [children []]
(if (= (.position byte-buffer) (.limit byte-buffer))
children
(let [child-key (encoding/decode-number-from-tightly-packed-trie-index byte-buffer)
child-offset (encoding/decode-number-from-tightly-packed-trie-index byte-buffer)]
(recur
(conj
children
(TightlyPackedTrie.
byte-buffer
child-key
(- address child-offset)
(.capacity byte-buffer)
value-decode-fn)))))))))
(defmethod print-method TightlyPackedTrie [trie ^java.io.Writer w]
(print-method (into {} trie) w))
(defmethod print-dup TightlyPackedTrie [trie ^java.io.Writer w]
(print-ctor trie (fn [o w] (print-dup (into {} trie) w)) w))
(defn tightly-packed-trie
[^TightlyPackedTrie trie
value-encode-fn
value-decode-fn]
(let [^ByteArrayOutputStream baos (ByteArrayOutputStream.)]
(loop [nodes (seq trie)
current-offset 8
previous-depth 0
child-indexes []]
(let [^trie/Trie current-node (first nodes)
current-depth (count (first current-node))]
(cond
(empty? nodes)
(let [^clojure.lang.PersistentVector child-index (last child-indexes)
child-index-baos (ByteArrayOutputStream.)
_ (->> child-index
(run!
(fn [[key offset]]
(.write
child-index-baos
(encoding/encode-key-to-tightly-packed-trie-index key))
(.write
child-index-baos
(encoding/encode-offset-to-tightly-packed-trie-index
(- current-offset offset))))))
child-index-byte-array (.toByteArray child-index-baos)
size-of-child-index (encoding/encode (count child-index-byte-array))
root-address current-offset
value #^bytes (value-encode-fn 0)]
(.write baos value)
(.write baos size-of-child-index)
(.write baos child-index-byte-array)
(let [#^bytes ba (.toByteArray baos)
^java.nio.ByteBuffer byte-buf (java.nio.ByteBuffer/allocate (+ 8 (count ba)))]
(do (.putLong byte-buf root-address)
(.put byte-buf ba)
(.rewind byte-buf)
(->TightlyPackedTrie
byte-buf
0
(.getLong byte-buf)
(.capacity byte-buf)
value-decode-fn))))
;; Gone up from depth to a parent.
;; Process index of children.
(> previous-depth current-depth)
(let [[k v] (first nodes)
value #^bytes (value-encode-fn v)
child-index (last child-indexes)
child-index-baos (ByteArrayOutputStream.)
_ (->> child-index
(run!
(fn [[key offset]]
(.write
child-index-baos
(encoding/encode-key-to-tightly-packed-trie-index key))
(.write
child-index-baos
(encoding/encode-offset-to-tightly-packed-trie-index
(- current-offset offset))))))
child-index-byte-array (.toByteArray child-index-baos)
size-of-child-index (encoding/encode (count child-index-byte-array))
current-index (last (pop child-indexes))]
(.write baos value)
(.write baos size-of-child-index)
(.write baos child-index-byte-array)
(recur (rest nodes)
(+ current-offset
(count value)
(count size-of-child-index)
(count child-index-byte-array))
current-depth
(conj (pop (pop child-indexes))
(conj current-index
[(last k)
current-offset]))))
;; Down or even in depth to children
;; Start keeping track of new children index
:else
(let [[k v] (first nodes)
value #^bytes (value-encode-fn v)
size-of-child-index (encoding/encode 0)
child-indexes (into child-indexes
(vec (repeat (- current-depth previous-depth) [])))
current-child-index (last child-indexes)]
(.write baos value)
(.write baos size-of-child-index)
(recur (rest nodes)
(+ current-offset
(count value)
(count size-of-child-index))
current-depth
(conj (pop child-indexes)
(conj current-child-index
[(last k)
current-offset])))))))))
;; TODO: Shared "save" interface for Trie?
(defn save-tightly-packed-trie-to-file
[filepath trie]
(with-open [o (java.io.FileOutputStream. filepath)]
(.rewind (.byte-buffer trie))
(.write (.getChannel o)
(.byte-buffer trie))))
(defn load-tightly-packed-trie-from-file
[filepath value-decode-fn]
(with-open [i (io/input-stream filepath)
baos (ByteArrayOutputStream.)]
(io/copy i baos)
(let [^java.nio.ByteBuffer byte-buffer (java.nio.ByteBuffer/wrap (.toByteArray baos))]
(.rewind byte-buffer)
(->TightlyPackedTrie
byte-buffer
0
(.getLong byte-buffer)
(.capacity byte-buffer)
value-decode-fn))))