diff --git a/freq_freqs.txt b/freq_freqs.txt
new file mode 100644
index 0000000..0f4374e
--- /dev/null
+++ b/freq_freqs.txt
@@ -0,0 +1,12 @@
+1 32
+2 20
+3 10
+4 3
+5 1
+6 2
+7 1
+8 1
+9 1
+10 2
+12 1
+26 1
diff --git a/src/com/owoga/prhyme/corpus/db.clj b/src/com/owoga/prhyme/corpus/db.clj
index b63c25e..26ff7a9 100644
--- a/src/com/owoga/prhyme/corpus/db.clj
+++ b/src/com/owoga/prhyme/corpus/db.clj
@@ -1,2 +1,7 @@
 (ns com.owoga.prhyme.corpus.db
   (:require [integrant.core :as ig]))
+
+(defn tokens->db
+  "Stub: not implemented yet; returns nil."
+  [tokens]
+  )
diff --git a/src/com/owoga/prhyme/data/tpt.clj b/src/com/owoga/prhyme/data/tpt.clj
index b65a504..9cf7e4b 100644
--- a/src/com/owoga/prhyme/data/tpt.clj
+++ b/src/com/owoga/prhyme/data/tpt.clj
@@ -46,6 +46,38 @@
                    (bit-shift-left n 7))
                 (inc i))))))
 
+(defn vb-decode
+  "Lazily decodes every variable-byte encoded number in byte array `ba`,
+  starting at byte offset `i` (default 0). Returns a lazy seq of values,
+  empty once `i` reaches the end of `ba`. Expects `vb-decode-1` to return
+  `[value byte-count]` for the number at the start of the array it gets."
+  ([ba]
+   (vb-decode ba 0))
+  ([ba i]
+   (lazy-seq
+    ;; `ba` is never shrunk; only the offset advances, so `i` is always
+    ;; compared against the full length and no trailing value is lost.
+    (when (< i (count ba))
+      (let [[value byte-count] (vb-decode-1 (byte-array (drop i ba)))]
+        (cons value (vb-decode ba (+ i byte-count))))))))
+
+(comment
+  (let [n1 0
+        n2 1
+        n3 127
+        n4 128
+        n5 257
+        n6 9876543210
+        baos (java.io.ByteArrayOutputStream.)]
+    (->> [n1 n2 n3 n4 n5 n6]
+         (map vb-encode)
+         (run! #(.writeBytes baos %)))
+    (let [ba (.toByteArray baos)]
+      (vb-decode ba)))
+
+  ;; => (0 1 127 128 257 9876543210)
+  )
+
 (def dictionary ["hi" "my" "name" "is" "what"])
 
 (defn slurp-bytes [x]
diff --git a/src/com/owoga/prhyme/generation/simple_good_turing.clj b/src/com/owoga/prhyme/generation/simple_good_turing.clj
index 0381889..5276424 100644
--- a/src/com/owoga/prhyme/generation/simple_good_turing.clj
+++ b/src/com/owoga/prhyme/generation/simple_good_turing.clj
@@ -14,6 +14,12 @@
   commas, periods, and newlines."
   #"(?s).*?([a-zA-Z\d]+(?:['\-]?[a-zA-Z]+)?|,|\.|\n)")
 
+(defn pad-tokens
+  "Pads `tokens` with empty-string boundary markers for `n`-gram extraction:
+  `(max 1 (dec n))` markers in front and one at the end."
+  [tokens n]
+  (concat (repeat (max 1 (dec n)) "") tokens [""]))
+
 (defn tokenize-line
   [line]
   (->> line
diff --git a/src/com/owoga/prhyme/rhyme_trie.clj b/src/com/owoga/prhyme/rhyme_trie.clj
new file mode 100644
index 0000000..51223d2
--- /dev/null
+++ b/src/com/owoga/prhyme/rhyme_trie.clj
@@ -0,0 +1,334 @@
+(ns com.owoga.prhyme.rhyme-trie
+  (:require [clojure.java.io :as io]
+            [clojure.walk :as walk]
+            [clojure.zip :as zip])
+  (:import (java.io ByteArrayOutputStream ByteArrayInputStream
+                    DataOutputStream DataInputStream)))
+
+(defn branch?
+  "Returns the `:children` of the first value in `node` when `node` is a map,
+  otherwise a falsey value."
+  [node]
+  (and (map? node)
+       (:children (first (vals node)))))
+
+(defn children
+  "Returns one single-entry map per entry in the `:children` of the first
+  value in `node`."
+  [node]
+  (map (partial apply hash-map) (seq (:children (first (vals node))))))
+
+(defn without-children
+  "Returns a map of `node`'s first key to its value minus `:children`."
+  [node]
+  {(first (keys node))
+   (dissoc (get node (first (keys node))) :children)})
+
+(defn map-trie->seq-trie
+  "Recursively converts a single-entry map node into the vector
+  `[child-vectors node-without-children]`."
+  [trie]
+  [(vec (map map-trie->seq-trie (children trie)))
+   (without-children trie)])
+
+(comment
+  (let [m {:root
+           {:children
+            {"T"
+             {:children
+              {"A" {:children
+                    {"T" {:value "TAT", :count 1}
+                     "S" {:value "SAT" :count 1}}
+                    :value "AT"
+                    :count 1},
+               "U" {:children {"T" {:value "TUT", :count 1}}}}}}}}]
+    (let [z (zip/vector-zip (map-trie->seq-trie m))]
+      (->> z
+           (iterate zip/next)
+           (take-while (complement zip/end?))
+           (map zip/node))))
+
+  )
+
+(defn vec-trie->map-trie
+  "Inverse of `map-trie->seq-trie`: rebuilds the nested map from
+  `[child-vectors node]`. Every rebuilt node gets a `:children` map, empty
+  for leaves."
+  [trie]
+  (let [children (first trie)
+        parent (second trie)
+        [parent-key parent-val] (first (seq parent))]
+    {parent-key (assoc parent-val :children (into {} (map vec-trie->map-trie children)))}))
+
+(comment
+  (let [vect [[[[[[[[] {"T" {:value "TAT", :count 1}}]] {"A" {:value "AT", :count 1}}]
+               [[[[] {"T" {:value "TUT", :count 1}}]] {"U" {}}]]
+              {"T" {}}]]
+            {:root {}}]]
+    (vec-trie->map-trie vect))
+
+  )
+
+(comment
+  (let [v1 '("T" "A" "T" "TAT")
+        v2 '("T" "U" "T" "TUT")
+        v3 '("T" "A" "AT")
+        t1 (trie v1)
+        t2 (trie v2)
+        t3 (trie v1 v2 v3)
+        vect (as-vec t3)]
+    vect)
+
+  )
+
+(defn parent?
+  "True when `node`'s first element is a vector and its second is a map."
+  [node]
+  (and (vector? (first node))
+       (map? (second node))))
+
+(defn child-seq
+  "Returns a lazy seq of the zipper locs beneath the left sibling of `loc`
+  (first child, then each right sibling), or `'()` when `loc` has no left
+  sibling or that sibling has no children."
+  [loc]
+  (if (and (zip/left loc)
+           (zip/down (zip/left loc)))
+    ((fn inner [child]
+       (if child
+         (lazy-seq
+          (cons child
+                (inner (zip/right child))))
+         nil))
+     (->> loc zip/left zip/down))
+    '()))
+
+(defn zip-visitor
+  "Walks `zipper` with `zip/next`, replacing each loc with `(visitor loc)`,
+  and returns the root of the result once the end is reached."
+  [visitor zipper]
+  (loop [zipper zipper]
+    (if (zip/end? zipper)
+      (zip/root zipper)
+      (recur (zip/next (visitor zipper))))))
+
+(comment
+  (let [m {:root
+           {:children
+            {"T"
+             {:children
+              {"A" {:children
+                    {"T" {:value "TAT", :count 1}
+                     "S" {:value "SAT" :count 1}}
+                    :value "AT"
+                    :count 1},
+               "U" {:children {"T" {:value "TUT", :count 1}}}}}}}}]
+    (let [z (zip/vector-zip (map-trie->seq-trie m))
+          pred (fn [loc]
+                 (map? (zip/node loc)))]
+      (zip-visitor
+       (fn [loc]
+         (if (pred loc)
+           (zip/edit
+            loc
+            (fn [node]
+              (let [[k v] (first (seq node))
+                    children-counts (->> (child-seq loc)
+                                         (map zip/node)
+                                         (map (comp :count second first seq second)))]
+                (if (not-empty children-counts)
+                  (update-in node [k :count] (partial apply (fnil + 0)) children-counts)
+                  node))))
+           loc))
+       z)))
+
+  )
+
+(defprotocol ITrie
+  (as-map [this] "Map that underlies trie.")
+  (as-vec [this] "Depth-first post-order vector")
+  (transform [this f] "Depth-first post-order apply each function to each node."))
+
+;; Seq offers a depth-first post-order traversal
+;; with children ordered by key.
+(deftype Trie [trie]
+  ITrie
+  (as-map [_] trie)
+  (as-vec [_] (map-trie->seq-trie trie))
+  (transform [self f]
+    (->> self
+         as-vec
+         zip/vector-zip
+         (zip-visitor f)
+         (vec-trie->map-trie)
+         (Trie.)))
+
+  clojure.lang.IPersistentCollection
+  ;; Yields every node map that carries a :value.
+  (seq [self]
+    (->> self
+         as-vec
+         zip/vector-zip
+         (iterate zip/next)
+         (take-while (complement zip/end?))
+         (map zip/node)
+         (filter map?)
+         (filter (comp :value second first seq))))
+  ;; `o` is a key path whose last element is the value to store; an existing
+  ;; node at that path has its :value replaced and its :count incremented.
+  (cons [_ o]
+    (let [path (cons :root (interleave (repeat :children) (butlast o)))
+          id (last o)]
+      (Trie.
+       (update-in
+        trie
+        path
+        (fn [prev]
+          (if (nil? prev)
+            {:value id
+             :count 1}
+            (-> prev
+                (assoc :value id) ; Assert value same?
+                (update :count (fnil inc 0)))))))))
+  (empty [_] (Trie. {}))
+  (equiv [_ o]
+    (and (instance? Trie o)
+         (= (as-map o) trie))))
+
+(defn trie
+  "Builds a Trie from `entries`; each entry is a seq of path keys followed by
+  the value stored at that path. With no entries, returns an empty Trie."
+  ([] (->Trie {}))
+  ([& entries]
+   (reduce
+    (fn [t entry]
+      (conj t entry))
+    (trie)
+    entries)))
+
+(comment
+  (let [v1 '("T" "A" "T" "TAT")
+        v2 '("T" "U" "T" "TUT")
+        v3 '("T" "A" "AT")
+        t1 (trie v1)
+        t2 (trie v2)
+        t3 (trie v1 v2 v3)]
+    (seq t3))
+
+  (let [v1 '("T" "A" "T" "TAT")
+        v2 '("T" "U" "T" "TUT")
+        v3 '("T" "A" "AT")
+        t1 (trie v1)
+        t2 (trie v2)
+        t3 (trie v1 v2 v3)
+        pred (fn [loc]
+               (map? (zip/node loc)))]
+    (transform
+     t3
+     (fn [loc]
+       (if (pred loc)
+         (zip/edit
+          loc
+          (fn [node]
+            (let [[k v] (first (seq node))
+                  children-counts (->> (child-seq loc)
+                                       (map zip/node)
+                                       (map (comp :count second first seq second)))]
+              (if (not-empty children-counts)
+                (update-in node [k :count] (partial apply (fnil + 0)) children-counts)
+                node))))
+         loc))))
+
+  )
+
+(defn vec->trie
+  "Returns the zipper path (seq of ancestor nodes) to every map node in the
+  vector trie `v`, in `zip/next` order.
+  NOTE(review): despite the name, no Trie is built yet."
+  [v]
+  (let [zipper (zip/vector-zip v)]
+    (->> zipper
+         (iterate zip/next)
+         (take-while (complement zip/end?))
+         (filter (comp map? zip/node))
+         #_(map #(concat (zip/path %) [(->> % zip/node keys first)
+                                       (->> % zip/node vals first :value)]))
+         (map zip/path))))
+
+(comment
+  (let [v1 '("T" "A" "T" "TAT")
+        v2 '("T" "U" "T" "TUT")
+        v3 '("T" "A" "AT")
+        t1 (trie v1)
+        t2 (trie v2)
+        t3 (trie v1 v2 v3)
+        vect (as-vec t3)]
+    (vec->trie vect))
+
+  )
+
+
+(defn write-node
+  "Stub: not implemented yet; returns nil."
+  [baos node])
+
+(defn write-index
+  "Stub: not implemented yet; returns nil."
+  [baos children])
+
+(defn pack-index-to-children
+  "Writes each `[index-key byte-address]` pair of `children` to an in-memory
+  stream and returns the accumulated bytes.
+  NOTE(review): if these are ints, `.write` keeps only the low-order byte of
+  each -- confirm keys and addresses fit in one byte."
+  [children]
+  (let [baos (ByteArrayOutputStream.)]
+    (run!
+     (fn [[index-key byte-address]]
+       (.write baos index-key)
+       (.write baos byte-address))
+     children)
+    (.toByteArray baos)))
+
+(defn node->byte-array
+  "Returns bytes for `node-value`, then the byte length of the packed child
+  index, then the index itself. `index-key` is currently unused."
+  [index-key node-value children]
+  (let [baos (ByteArrayOutputStream.)
+        child-index (pack-index-to-children children)]
+    (.write baos node-value)
+    (.write baos (count child-index))
+    (.writeBytes baos child-index)
+    (.toByteArray baos)))
+
+
+(defn tpt
+  "Walks nested-list `trie` bottom-up, turning each `(key value children)`
+  node into `(key value (count children) children)`. A node is a non-empty
+  seq whose first element is not a seq."
+  [trie]
+  (let [node? (fn [x]
+                (and (seq? x)
+                     (not-empty x)
+                     (not (seq? (first x)))))
+        transform (fn [x]
+                    (if (node? x)
+                      (let [[index-key node-value children] x]
+                        (list index-key node-value (count children) children))
+                      x))]
+    (walk/postwalk transform trie)))
+
+
+(comment
+  (let [trie '(b 3 ())]
+    (tpt trie))
+
+  (let [trie '(nil 20 ((a 17 ())))]
+    (tpt trie))
+
+  (let [trie '(nil 20 ((a 17 ((a 10 ())
+                              (b 7 ())))
+                       (b 3 ())))]
+    (tpt trie))
+
+  )
diff --git a/test/com/owoga/prhyme/generation/simple_good_turing_test.clj b/test/com/owoga/prhyme/generation/simple_good_turing_test.clj
index 140f8ec..21817c7 100644
--- a/test/com/owoga/prhyme/generation/simple_good_turing_test.clj
+++ b/test/com/owoga/prhyme/generation/simple_good_turing_test.clj
@@ -12,6 +12,11 @@
   (with-open [reader (io/reader (io/resource "dark-corpus-test.txt"))]
     (->> (line-seq reader) doall)))
 
+(def test-sentence (first test-corpus))
+
+(def test-tokens
+  (sgt/pad-tokens (sgt/tokenize-line test-sentence) 1))
+
 (def train-trie
   (sgt/lines->trie train-corpus 3))
 
@@ -31,3 +36,54 @@
 (deftest simple-good-turing
   (testing "accuracy"))
 
+(partition 2 1 test-tokens)
+;; => (("" "three")
+;;     ("three" "years")
+;;     ("years" "passed")
+;;     ("passed" "since")
+;;     ("since" "it")
+;;     ("it" "began")
+;;     ("began" ""))
+
+(defn ngram-perplexity
+  "Returns a fn of `tokens` that partitions them into `n`-grams (step 1) and
+  maps each n-gram to the base-2 log of `(model vocab n-gram)`.
+  NOTE(review): despite the name this yields per-n-gram log values, not a
+  single perplexity figure."
+  [model vocab n]
+  (fn [tokens]
+    (->> tokens
+         (partition n 1)
+         (map #(model vocab %))
+         (map #(/ (Math/log %)
+                  (Math/log 2))))))
+
+(def unigram-perplexity (ngram-perplexity sgt-model vocab 1))
+(def bigram-perplexity (ngram-perplexity sgt-model vocab 2))
+(unigram-perplexity ["" "you're" "a" "dweller" ""])
+;; => (-2.3988984034800693
+;;     -9.075314877722885
+;;     -4.843244892573044
+;;     -11.69197410313705
+;;     -2.3988984034800693)
+;; => -30.408330680393117
+(bigram-perplexity ["" "you're" "a" "dweller" ""])
+;; => (-8.228033550648288 -9.95473739539021 -10.916015756609784 -13.731947650675696)
+;; => -42.830734353323976
+(sgt-model vocab '("it" "began"))
+(->> test-corpus
+     (map sgt/tokenize-line)
+     (map #(sgt/pad-tokens % 1))
+     (map #(partition 2 1 %))
+     (map
+      (fn [tokens]
+        (map #(sgt-model vocab %) tokens)))
+     (take 10)
+     (map
+      (fn [tokens]
+        (map #(/ (Math/log %)
+                 (Math/log 2))
+             tokens))))
+
+(/ (Math/log 10)
+   (Math/log 2))