From f0e03d76a25aa2acbd041909cec9a89b4e7ff93c Mon Sep 17 00:00:00 2001
From: Eric Ihli
Date: Wed, 30 Dec 2020 14:44:09 -0800
Subject: [PATCH] Progress towards Katz back-off

---
 dev/examples/scratch.clj                      |  52 ++++++-
 .../prhyme/generation/simple_good_turing.clj  | 172 +++++++++++++++++-
 2 files changed, 220 insertions(+), 4 deletions(-)

diff --git a/dev/examples/scratch.clj b/dev/examples/scratch.clj
index 4fdead2..bb1897d 100644
--- a/dev/examples/scratch.clj
+++ b/dev/examples/scratch.clj
@@ -222,13 +222,61 @@
             (-> (add-to-trie-1 acc 1 tokens)
                 (add-to-trie-1 2 tokens)
                 (add-to-trie-1 3 tokens)))
-          {}))))
+          {})
+         ((fn [trie]
+            (assoc
+             trie
+             :count
+             (->> trie
+                  (map second)
+                  (map :count)
+                  (apply +))))))))
 
   (->> (get-in trie ["you're" "my"])
        (remove (fn [[k _]] (= :count k))))
 
   (def r*s (sgt/trie->r*s trie))
-
+  r*s
+  (get-in trie ["you're" "my"])
+
+  (get-in r*s [2 :r*s])
+
+
+  (get-in trie ["my" "us"])
+
+  (get-in {:a 1} '())
+  (sgt/katz-alpha
+   trie
+   r*s
+   ["you're" "my" "lady"]
+   (sgt/katz-beta trie r*s ["you're" "my" "lady"]))
+
+  (sgt/alpha trie r*s ["eat" "my"] 2)
+  (get-in trie ["you're" "my" "lady"])
+  (sgt/katz-estimator trie r*s 0 ["you're" "my" "head"])
+;; => 0.1067916992217116
+  (sgt/katz-estimator trie r*s 0 ["you're" "my" "lady"])
+;; => 0.016222893164898698
+  (sgt/katz-estimator trie r*s 0 ["you're" "my" "fooball"])
+;; => 9.223367982725652E-6
+  (float (/ 1 27))
+  (get-in trie ["eat" "my"])
+  (sgt/sum-of-betas trie r*s ["you're" "my"])
+  (sgt/katz-beta trie r*s ["you're" "my" "lady"])
+  (get-in trie ["eat" "my" "heart"])
+  (get-in trie ["my" "heart"])
+  (sgt/katz-smoothing trie r*s ["eat" "my" "heart"] 5)
+  (sgt/prob-observed-ngram trie r*s ["eat"])
+
+  (->> ["pathe" "way"] (get trie) (map :count))
+  (sgt/mle trie ["you're"])
+
+  (let [words ["eat" "my"]
+        r (get-in trie (concat words [:count]) 0)
+        flattened (sgt/filter-trie-to-ngrams trie 2)]
+    (count flattened))
+
+  (get-in r*s [2])
   (def probs (->> (range 1 4)
                   (map #(vector %
                                 (filter-trie-to-ngrams trie %)))
diff --git a/src/com/owoga/prhyme/generation/simple_good_turing.clj b/src/com/owoga/prhyme/generation/simple_good_turing.clj
index a6a81fc..c0a863f 100644
--- a/src/com/owoga/prhyme/generation/simple_good_turing.clj
+++ b/src/com/owoga/prhyme/generation/simple_good_turing.clj
@@ -1,7 +1,10 @@
-(ns com.owoga.prhyme.generation.simple-good-turing)
+(ns com.owoga.prhyme.generation.simple-good-turing
+  (:require [clojure.set]))
 
 ;; Pythons NLTK is a great resource for this.
 ;; https://github.com/nltk/nltk/blob/2.0.4/nltk/probability.py
+;;
+;; Useful to check out commit 3c8a25379 and look at nltk/model/ngram.py
 
 
 (defn least-squares-log-log-linear-regression
@@ -279,7 +282,7 @@
               :nrs nrs
               :zrs zrs
               :lm lm
-              :r*s (r-stars rs zrs lm)}]))
+              :r*s (into (sorted-map) (map vector rs (r-stars rs zrs lm)))}]))
         ngram-rs-nrs-map))))
 
 ;; zrs (average-consecutives rs nrs)
@@ -309,6 +312,171 @@
       :else
       (* 0.4 (stupid-backoff trie probs (rest words))))))
 
+(defn mle
+  "Maximum-likelihood estimate of the last word of `words` given the words
+  before it: count(words) / count(butlast words).
+
+  Returns 0 when the history has no positive :count, rather than throwing on
+  a nil or zero denominator."
+  [trie words]
+  (let [r (get-in trie (concat words [:count]) 0)
+        q (get-in trie (concat (butlast words) [:count]) 0)]
+    (if (zero? q)
+      0
+      (/ r q))))
+
+(declare katz-beta-alpha)
+
+(defn katz-estimator
+  "Back-off estimate for the n-gram `words`.
+
+  An n-gram with a positive count r gets d * r / r-1, where r-1 is the count
+  of its history and d = r*/r with r* read from `r*s` at [n :r*s r].
+  Otherwise the estimate for (rest words) is used, scaled by the ratio of
+  `katz-beta-alpha` for `words` to `katz-beta-alpha` for (rest words)."
+  [trie r*s k words]
+  (let [r (get-in trie (concat words [:count]) 0)]
+    (if (> r 0)
+      (let [n (count words)
+            r* (get-in r*s [n :r*s r])
+            r-1 (get-in trie (concat (butlast words) [:count]) 1)
+            d (/ r* r)]
+        (* d (/ r r-1)))
+      (let [alpha (/ (katz-beta-alpha trie r*s k words)
+                     (katz-beta-alpha trie r*s k (rest words)))]
+        (* alpha
+           (katz-estimator
+            trie
+            r*s
+            k
+            (rest words)))))))
+
+(defn katz-beta-alpha
+  "One minus the summed `katz-estimator` of every continuation of
+  (butlast words) whose count is greater than `k`."
+  [trie r*s k words]
+  (let [ngrams (->> (get-in trie (butlast words))
+                    (remove #(= :count (first %)))
+                    (filter (fn [[_ v]] (> (:count v) k)))
+                    (map first)
+                    (map #(concat (butlast words) [%]))
+                    (map #(katz-estimator trie r*s k %))
+                    (apply +))]
+    (- 1 ngrams)))
+
+(defn katz-alpha
+  "Divides the left-over mass `b` by one minus the summed `katz-estimator` of
+  the continuations of (rest (butlast words)).
+
+  `k` is the count threshold handed to `katz-estimator`; the four-argument
+  arity defaults it to 0."
+  ([trie r*s words b]
+   (katz-alpha trie r*s words b 0))
+  ([trie r*s words b k]
+   (let [history (rest (butlast words))
+         denom (->> (get-in trie history)
+                    (remove #(= :count (first %)))
+                    (map first)
+                    ;; Wrap the key in a vector: `concat` on a bare string
+                    ;; would splice in its characters.
+                    (map #(concat history [%]))
+                    (map #(katz-estimator trie r*s k %))
+                    (apply +)
+                    (- 1))]
+     (/ b denom))))
+
+(defn beta
+  "Estimate of the sum of conditional probabilities of all words wₘ which never
+  followed wₘ₋₁."
+  [trie r*s words]
+  (let [n (count words)
+        r (get-in trie (concat words [:count]) 0)]
+    (if (zero? r)
+      1
+      ;; The discount is only computed for seen n-grams; r* / r with r = 0
+      ;; would throw.
+      (let [r* (get-in r*s [n :r*s r])
+            d (/ r* r)]
+        (* d (/ r (get-in trie (concat (butlast words) [:count]))))))))
+
+(defn sum-of-betas
+  "One minus the summed `beta` of every continuation of `words` in `trie`."
+  [trie r*s words]
+  (let [ngrams (->> (get-in trie words)
+                    (remove #(= (first %) :count))
+                    (map first)
+                    (map #(concat words [%])))]
+    (- 1 (->> ngrams
+              (map (partial beta trie r*s))
+              (apply +)))))
+
+(defn prob-observed-ngram
+  "Exploratory helper. Returns a vector of the summed counts of the words seen
+  after `words`, the summed top-level counts of the words never seen after
+  `words`, and up to five of those unseen words. `r*s` is not used yet."
+  [trie r*s words]
+  (let [observed (->> (get-in trie words)
+                      (remove #(= :count (first %))))
+        dictionary (->> trie
+                        (remove #(= :count (first %)))
+                        (map first))
+        unobserved (clojure.set/difference
+                    (into #{} dictionary)
+                    (into #{} (map first observed)))
+        ;; The likelihood of an observed ngram
+        ;; is 1 - N₁/N, so the sum of observed counts
+        ;; needs to be normalized by that.
+        sum-of-observed-counts (->> observed
+                                    (map second)
+                                    (map :count)
+                                    (apply +))
+        ;; The likelihood of an unobserved ngram
+        ;; is N₁/N, so the sum of unobserved counts
+        ;; needs to be normalized with that.
+        sum-of-unobserved-counts (->> unobserved
+                                      (map #(get trie %))
+                                      (map :count)
+                                      (apply +))]
+    [sum-of-observed-counts sum-of-unobserved-counts (take 5 unobserved)]))
+
+(declare katz-smoothing)
+
+(defn alpha
+  "Ratio of `sum-of-betas` for `words` to the summed `katz-smoothing` of the
+  lower-order n-grams built from the continuations of (butlast words) whose
+  count is at most `k`."
+  [trie r*s words k]
+  (let [beta-mass (sum-of-betas trie r*s words)
+        denom (->> (get-in trie (butlast words))
+                   (remove #(= (first %) :count))
+                   (filter (fn [[_ v]] (<= (:count v) k)))
+                   (map (fn [[kw _]]
+                          (katz-smoothing
+                           trie r*s (concat (rest (butlast words)) [kw]) k)))
+                   (apply +))]
+    ;; NOTE(review): `denom` is 0 when no continuation has a count <= k, which
+    ;; makes the division below fail -- confirm the intended denominator.
+    (/ beta-mass denom)))
+
+(defn katz-smoothing
+  "Smoothed estimate for the n-gram `words`.
+
+  When the n-gram's count r exceeds `k`, returns d * r / r-1 with d = r*/r and
+  r* read from `r*s` at [n :r*s r]. Otherwise backs off to (rest words),
+  weighted by `alpha` of (butlast words)."
+  [trie r*s words k]
+  (let [n (count words)
+        r (get-in trie (concat words [:count]) 0)]
+    (if (> r k)
+      ;; Only seen n-grams are discounted: r is 0 (and r-1 may be missing) for
+      ;; unseen ones, so computing d eagerly would throw before backing off.
+      (let [r-1 (get-in trie (concat (butlast words) [:count]))
+            r* (get-in r*s [n :r*s r])
+            d (/ r* r)]
+        (* d (/ r r-1)))
+      (* (alpha trie r*s (butlast words) k)
+         (katz-smoothing trie r*s (rest words) k)))))
+
 (defn katz-backoff
   [trie probs r*s words]
   (let [k 0