From c1f20a8f8c78e687c3df9dc4fb16a4596bf6166a Mon Sep 17 00:00:00 2001 From: Eric Ihli Date: Mon, 12 Apr 2021 19:27:48 -0500 Subject: [PATCH] Working with tpt lib --- deps.edn | 4 +- dev/examples/tpt.clj | 159 +++++++++++++++++----------- src/com/owoga/prhyme/rhyme_trie.clj | 18 ++-- src/com/owoga/prhyme/util/math.clj | 57 +++++++++- 4 files changed, 161 insertions(+), 77 deletions(-) diff --git a/deps.edn b/deps.edn index 0d2a69c..d1e6de1 100644 --- a/deps.edn +++ b/deps.edn @@ -16,9 +16,9 @@ net.sf.sociaal/freetts {:mvn/version "1.2.2"} enlive/enlive {:mvn/version "1.1.6"} integrant/integrant {:mvn/version "0.8.0"} - com.owoga/tightly-packed-trie {:local/root "/home/eihli/code/tightly-packed-trie/TightlyPackedTrie.jar"} org.clojure/data.fressian {:mvn/version "1.0.0"} com.taoensso/nippy {:mvn/version "3.0.0"} - com.taoensso/timbre {:mvn/version "4.10.0"}} + com.taoensso/timbre {:mvn/version "4.10.0"} + com.owoga/tightly-packed-trie {:mvn/version "0.2.0"}} :aliases {:dev {:extra-paths ["test" "examples" "dev"] :extra-deps {}}}} diff --git a/dev/examples/tpt.clj b/dev/examples/tpt.clj index f94d24e..363bf5a 100644 --- a/dev/examples/tpt.clj +++ b/dev/examples/tpt.clj @@ -1,7 +1,10 @@ (ns examples.tpt (:require [clojure.string :as string] [clojure.java.io :as io] - [com.owoga.tightly-packed-trie.core :as tpt] + [com.owoga.trie :as trie] + [com.owoga.tightly-packed-trie :as tpt] + [com.owoga.trie.math :as math] + [com.owoga.tightly-packed-trie.encoding :as encoding] [com.owoga.prhyme.util :as util] [com.owoga.prhyme.data.dictionary :as dict] [clojure.zip :as zip])) @@ -49,8 +52,6 @@ (map second) (map string/lower-case))) -(def database (atom {})) - (defn process-files-for-trie "Expects an entire song, lines seperated by \n." [files] @@ -84,23 +85,6 @@ (map #(partition n 1 %)) (apply concat)))) -(defn make-trie - ([] (tpt/->Trie - (fn update-fn [prev cur] - (if (nil? prev) - {:value (last cur) - :count 1} - (-> prev - (update :count (fnil inc 0)) - (assoc :value (last cur))))) - (sorted-map))) - ([& ks] - (reduce - (fn [t k] - (conj t k)) - (make-trie) - ks))) - (defn n-to-m-grams "Exclusive of m, similar to range." [n m text] @@ -112,47 +96,37 @@ :else (recur (inc i) (cons (text->ngrams text i) r))))) - (defn prep-ngram-for-trie "The tpt/trie expects values conjed into an ngram to be of format '(k1 k2 k3 value)." [ngram] - (concat ngram (list ngram))) + (clojure.lang.MapEntry. (vec ngram) ngram)) (defn create-trie-from-texts [texts] (->> texts (map #(n-to-m-grams 1 4 %)) (apply concat) (map prep-ngram-for-trie) - (apply make-trie))) - -(defn trie->seq-of-nodes [trie] - (->> trie - tpt/as-vec - zip/vector-zip - (iterate zip/next) - (take-while (complement zip/end?)) - (map zip/node) - (filter map?))) + (reduce + (fn [trie [k v]] + (let [existing (or (get trie k) {:count 0 :value (last v)})] + (conj trie [k (update existing :count inc)]))) + (trie/make-trie)))) (defn seq-of-nodes->sorted-by-count "Sorted first by the rank of the ngram, lowest ranks first. Sorted second by the frequency of the ngram, highest frequencies first. This is the order that you'd populate a mapping of keys to IDs." - [nodes] - (->> nodes - (map (comp first seq)) - (map (fn [[k v]] - (vector (:value v) (:count v)))) - ;; root node and padded starts - (remove (comp nil? second)) - (sort-by #(vector (count (first %)) - (- (second %)))))) + [trie] + (->> trie + trie/children + (map #(get % [])) + (sort-by :count) + reverse)) (defn trie->database [trie] (let [sorted-keys - (->> (trie->seq-of-nodes trie) - seq-of-nodes->sorted-by-count)] + (seq-of-nodes->sorted-by-count trie)] (loop [sorted-keys sorted-keys database {} i 1] @@ -161,36 +135,55 @@ (recur (rest sorted-keys) (-> database - (assoc (first (first sorted-keys)) - {:count (second (first sorted-keys)) - :id i}) - (assoc i (first (first sorted-keys)))) + (assoc i {:count (:count (first sorted-keys)) + :value (:value (first sorted-keys))}) + (assoc (:value (first sorted-keys)) i)) (inc i)))))) -(defn transform-trie->ids [trie database] - (let [transform-p #(map? (zip/node %)) - transform-f - (fn tf [loc] - (zip/edit - loc - (fn [node] - (let [[k v] (first (seq node))] - {(get-in database [(list k) :id] (if (= k :root) :root)) - (assoc v :value (get-in database [(list k) :count] 0))}))))] - (tpt/transform trie (tpt/visitor-filter transform-p transform-f)))) - -(defonce trie +(def trie (let [texts (->> (dark-corpus-file-seq 500 500) (map slurp))] (create-trie-from-texts texts))) -(defonce trie-database +(def trie-database (trie->database trie)) +(defn encode-fn [v] + (let [{:keys [count value]} v] + (if (and (number? v) (not (zero? v))) + (byte-array + (concat (encoding/encode (trie-database value)) + (encoding/encode count))) + (encoding/encode 0)))) + +(defn decode-fn [byte-buffer] + (let [v (encoding/decode byte-buffer)] + (if (and (number? v) (zero? v)) + nil + (trie-database v)))) + +(comment + (def tight-ready-trie + (->> trie + (map (fn [[k v]] + (let [k (map #(get trie-database %) k)] + [k v]))) + (into (trie/make-trie)))) + ) + (def tightly-packed-trie - (let [trie-with-ids (transform-trie->ids trie trie-database) - tightly-packed-trie (tpt/tightly-packed-trie trie-with-ids)] - tightly-packed-trie)) + (let [tight-ready-trie + (->> trie + (map (fn [[k v]] + (let [k (map #(get trie-database %) k)] + [k v]))) + (into (trie/make-trie))) + tightly-packed-trie + (tpt/tightly-packed-trie + tight-ready-trie + encode-fn + decode-fn)] + tight-ready-trie)) (defn key-get-in-tpt [tpt db ks] (let [id (map #(get-in db [(list %) :id]) ks) @@ -205,6 +198,44 @@ {ks (assoc v :value (get db id))})) (comment + (trie/lookup tightly-packed-trie [1 28 9]) + + + (def example-story + (loop [generated-text [(get trie-database "")] + i 0] + (if (> i 100) + generated-text + (let [node (loop [i 3] + (let [node (trie/lookup + tightly-packed-trie + (vec (take-last i generated-text)))] + (cond + (nil? node) (recur (dec i)) + (< i 0) (throw (Exception. "Error")) + (seq (trie/children node)) node + :else (recur (dec i)))))] + (recur + (conj + generated-text + (->> node + trie/children + (map #(get % [])) + (remove nil?) + (math/weighted-selection :count) + :value + (get trie-database))) + (inc i)))))) + + (->> example-story + (map #(get-in trie-database [% :value])) + (concat) + (string/join " ") + (#(string/replace % #" ([\.,\?])" "$1")) + ((fn [txt] + (string/replace txt #"(^|\. |\? )([a-z])" (fn [[a b c]] + (str b (.toUpperCase c))))))) + (key-get-in-tpt tightly-packed-trie trie-database diff --git a/src/com/owoga/prhyme/rhyme_trie.clj b/src/com/owoga/prhyme/rhyme_trie.clj index b5189ba..c18bd9c 100644 --- a/src/com/owoga/prhyme/rhyme_trie.clj +++ b/src/com/owoga/prhyme/rhyme_trie.clj @@ -398,7 +398,7 @@ (assoc m k (apply f (get m k) args)))))] (up m ks f args)))) -(defprotocol ITrie +(defprotocol ITrieP (as-map [this] "Map that underlies trie.") (as-vec [this] "Depth-first post-order vector.") (as-byte-array [this] "Tightly-packed byte-array.") @@ -406,22 +406,22 @@ ;; Seq offers a depth-first post-order traversal ;; with children ordered by key. -(deftype Trie [trie] - ITrie +(deftype TrieP [trie] + ITrieP (as-map [_] trie) (as-vec [_] (map-trie->seq-trie trie)) (as-byte-array [self] (->> (transform self (visitor-filter #(map? (zip/node %)) pack-index)) as-vec vec-trie->map-trie - (Trie.))) + (TrieP.))) (transform [self f] (->> self as-vec zip/vector-zip (zip-visitor f) (vec-trie->map-trie) - (Trie.))) + (TrieP.))) clojure.lang.ILookup (valAt [_ k] @@ -443,7 +443,7 @@ (let [path (cons :root (interleave (repeat :children) (butlast o))) id (last o) node (get-in trie path)] - (Trie. + (TrieP. (update-in-sorted trie path @@ -454,13 +454,13 @@ (-> prev (assoc :value id) ; Assert value same? (update :count (fnil inc 0))))))))) - (empty [_] (Trie. {})) + (empty [_] (TrieP. {})) (equiv [_ o] - (and (isa? (class o) Trie) + (and (isa? (class o) TrieP) (= (as-map o) trie)))) (defn trie - ([] (->Trie (sorted-map))) + ([] (->TrieP (sorted-map))) ([& entries] (reduce (fn [t entry] diff --git a/src/com/owoga/prhyme/util/math.clj b/src/com/owoga/prhyme/util/math.clj index 4605335..dde5b41 100644 --- a/src/com/owoga/prhyme/util/math.clj +++ b/src/com/owoga/prhyme/util/math.clj @@ -1,10 +1,8 @@ ;; Fast weighted random selection thanks to the Vose algorithm. ;; https://gist.github.com/ghadishayban/a26cc402958ef3c7ce61 - (ns com.owoga.prhyme.util.math (:import clojure.lang.PersistentQueue)) - ;; Vose's alias method ;; http://www.keithschwarz.com/darts-dice-coins/ (defprotocol Rand @@ -338,3 +336,58 @@ (apply + sgts)]) ) + +(defn sgt-with-counts [rs nrs] + (assert (and (not-empty nrs) (not-empty rs)) + "frequencies and frequency-of-frequencies can't be empty") + (let [l (count rs) + N (apply + (map #(apply * %) (map vector rs nrs))) + p0 (/ (first nrs) N) + zrs (average-consecutives rs nrs) + log-rs (map #(Math/log %) rs) + log-zrs (map #(Math/log %) zrs) + lm (least-squares-linear-regression log-rs log-zrs) + lgts (map lm rs) + estimations (loop [coll rs + lgt? false + e (estimator lm rs zrs) + estimations []] + (cond + (empty? coll) estimations + :else + (let [[estimation lgt?] (e (first coll) lgt?)] + (recur + (rest coll) + lgt? + e + (conj estimations estimation))))) + N* (apply + (map #(apply * %) (map vector nrs estimations))) + probs (cons + (float p0) + (map #(* (- 1 p0) (/ % N*)) estimations)) + sum-probs (apply + probs)] + [(cons 0 rs) + (map #(/ % sum-probs) probs) + estimations + lgts])) + + + +(defn discount-coefficient-map + "The probability of an unseen (Nr0) n-gram is Nr1/N. + We then have to adjust the probability of Nr1 down from the maximum-likelihood + estimate of Nr1 (which was Nr1/N) to something else. + + The size of this adjustment is captured by the discount coefficient." + [frequency->frequency-of-frequency] + (let [[xs ys] ((juxt keys vals) frequency->frequency-of-frequency) + sgt (into (sorted-map) (apply map vector (sgt xs ys)))] + + (into + (sorted-map) + (map + (fn [[r nr nr*]] + [r (/ nr* nr)]) + (map vector xs ys (vals sgt)))))) + +(discount-coefficient-map )