Add Python example of Katz backoff
parent
f66210d7dd
commit
de194913d9
@@ -0,0 +1,219 @@
import collections
import math
import sys

from common import SimpleTrie
import common.tokenize
import lm.probability


class BackoffModel:
    """Estimate Katz's backoff model from counts stored in an
    `lm.probability.ngramCounter` object.

    1) Estimate a conditional probability for each ngram (MLE or discounted).
    2) Calculate back-off weights so that the probabilities sum to one.

    Based on the SRI LM implementation.
    """
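
    # Minimal usage sketch. How an `lm.probability.ngramCounter` is built is
    # an assumption (not shown in this module); this class only relies on its
    # len(), items() and ngram-tuple lookup:
    #
    #     counter = lm.probability.ngramCounter(...)  # ngram counts of a corpus
    #     model = BackoffModel(counter, unk=True, order=2)
    #     with open("model.arpa", "w") as f:
    #         model.dumpToARPA(f)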

    precision = 6  # to compensate for the limited precision of the ARPA file

    def __init__(self, counter, unk=True, order=2, discounts=None):
        """
        counter    `lm.probability.ngramCounter` with counts data
        unk        create an '<unk>' token for the probability mass freed
                   by discounting
        order      maximum ngram length of the model
        discounts  dictionary of discount callables, one per ngram length;
                   a callable gets a count of occurrences and should return
                   the discounted number
        """
        self.trie = SimpleTrie()
        self.BOWs = {}
        self.order = order

        if not discounts:
            discounts = {n: lm.probability.NoDiscount() for n in range(1, order + 1)}
        self._initProbabilites(counter, unk, discounts)
        self._calculateBOWs()

    def _initProbabilites(self, counter, unk, discounts):
        N = len(counter)
        self.probs = {}
        self.ngrams = set()  # non-discounted ngrams

        for ngram, count in counter.items():
            n = len(ngram)
            if not n:
                continue

            if n == 1:
                prevCount = N
            elif n > 1:
                prevCount = counter[ngram[:-1]]

            # hack assumed from SRI LM,
            # though here it is always applied with the GT discount
            if isinstance(
                discounts[n], lm.probability.GoodTuringDiscount
            ):  # and discounts[n].validCoefficients():
                prevCount += 1

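            # Discounted conditional estimate computed below: with c(h, w) the
            # ngram count and c(h) = prevCount the history count,
            #     p(w | h) = d(c(h, w)) / c(h)
            # where d is the (possibly identity) discount callable.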
            # discount only the queried ngram count;
            # prevCount is unchanged as we discount the conditional frequency
            # (probability) only
            prob = discounts[n](count) / prevCount

            if prob > 0 and ngram not in [(common.tokenize.TOKEN_BEG_SENTENCE,)]:
                logProb = round(math.log10(prob), self.precision)
                self.probs[ngram] = logProb
                self.ngrams.add(ngram)
            else:
                self.probs[ngram] = lm.probability.logNegInf
                # not added to self.ngrams, which eliminates discounted ngrams

            # add the ngram to the trie for search
            node = self.trie
            for token in ngram:
                node = node[token]

        self.ngrams.add((common.tokenize.TOKEN_BEG_SENTENCE,))
        self.probs[(common.tokenize.TOKEN_BEG_SENTENCE,)] = lm.probability.logNegInf
        if unk:
            self.ngrams.add((common.tokenize.TOKEN_UNK,))
            self.probs[(common.tokenize.TOKEN_UNK,)] = lm.probability.logNegInf
            _ = self.trie[common.tokenize.TOKEN_UNK]

    def _calculateBOWs(self):
        # see Ngram::computeBOW in SRI LM
        # calculate from low order to high (FIFO trie walk)
        queue = collections.deque()
        queue.append(())

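        # Shortest contexts must be finished first: _BOWforContext reads
        # lower-order probabilities, and those may still be adjusted here
        # (rescaled, or topped up by _distributeProbability at the root).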
        while len(queue):
            context = queue.popleft()
            result = self._BOWforContext(context)
            if result:
                numerator, denominator = result
                if len(context) == 0:
                    self._distributeProbability(numerator, context)
                elif numerator == 0 and denominator == 0:
                    self.BOWs[context] = 0  # log 1
                else:
                    self.BOWs[context] = (
                        math.log10(numerator) - math.log10(denominator)
                        if numerator > 0
                        else lm.probability.logNegInf
                    )
            else:
                self.BOWs[context] = lm.probability.logNegInf

            if len(context) < self.order:
                node = self.trie
                for token in context:
                    node = node[token]
                for k in node:
                    queue.append(context + (k,))

    def _BOWforContext(self, context):
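        # Katz back-off weight for this context h (computed below from the
        # already-estimated probabilities):
        #     bow(h) = (1 - sum_w p(w | h)) / (1 - sum_w p(w | h'))
        # where w ranges over tokens observed after h, and h' is h with its
        # first token dropped (the lower-order history).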
        node = self.trie
        for token in context:
            node = node[token]

        numerator = 1
        denominator = 1
        for w in node:
            ngram = context + (w,)
            numerator -= 10 ** self.probs[ngram]
            if len(ngram) > 1:
                denominator -= 10 ** self.probs[ngram[1:]]  # lower-order estimate

        # rounding errors
        # if numerator < 0 and numerator > -lm.probability.epsilon:
        #     numerator = 0
        # if denominator < 0 and denominator > -lm.probability.epsilon:
        #     denominator = 0
        if abs(numerator) < lm.probability.epsilon:
            numerator = 0
        if abs(denominator) < lm.probability.epsilon:
            denominator = 0

        if denominator == 0 and numerator > lm.probability.epsilon:
            if numerator == 1:  # shouldn't happen
                return None
            # log factor to scale the context's probabilities to sum to one
            scale = -math.log10(1 - numerator)
            for w in node:
                ngram = context + (w,)
                self.probs[ngram] += scale
            numerator = 0
        elif numerator < 0:
            # erroneous state
            return None
        elif denominator <= 0:
            if numerator > lm.probability.epsilon:
                # erroneous state
                return None
            else:
                numerator = 0
                denominator = 0

        return numerator, denominator

    def _distributeProbability(self, probability, context):
        if probability == 0:
            return
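        # Leftover probability mass: assigned wholly to '<unk>' when it is
        # present in this context, otherwise spread uniformly over the
        # context's continuations.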
        node = self.trie
        for token in context:
            node = node[token]
        if common.tokenize.TOKEN_UNK in node:
            print(
                "Assigning {:.3g} probability to {} in context {}.".format(
                    probability, common.tokenize.TOKEN_UNK, context
                ),
                file=sys.stderr,
            )
            ngram = context + (common.tokenize.TOKEN_UNK,)
            self.probs[ngram] = math.log10(probability)
        else:
            print(
                "Distributing {:.3g} probability over {} tokens in context {}.".format(
                    probability, len(node), context
                ),
                file=sys.stderr,
            )
            amount = probability / len(node)
            for k in node:
                ngram = context + (k,)
                self.probs[ngram] = math.log10(10 ** self.probs[ngram] + amount)

    def dumpToARPA(self, file):
        """Write the estimated model to an ARPA text file."""
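        # Schematic of the ARPA layout written below:
        #     \data\
        #     ngram 1=<count of unigrams>
        #     ...
        #     \1-grams:
        #     <log10 prob>\t<ngram>[\t<log10 backoff weight>]
        #     ...
        #     \end\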
print("\\data\\", file=file)
|
||||||
|
for i in range(self.order):
|
||||||
|
print(
|
||||||
|
"ngram {}={}".format(
|
||||||
|
i + 1, len([1 for ngram in self.ngrams if len(ngram) == i + 1])
|
||||||
|
),
|
||||||
|
file=file,
|
||||||
|
)
|
||||||
|
print(file=file)
|
||||||
|
|
||||||
|
        for n in range(1, self.order + 1):
            print("\\{}-grams:".format(n), file=file)
            for k in sorted(self.ngrams):
                if len(k) != n:
                    continue
                if k in self.BOWs and self.BOWs[k] != 0:
                    print(
                        "{:.7g}\t{}\t{:.7g}".format(
                            self.probs[k], " ".join(k), self.BOWs[k]
                        ),
                        file=file,
                    )
                else:
                    print("{:.7g}\t{}".format(self.probs[k], " ".join(k)), file=file)
            print(file=file)

        print("\\end\\", file=file)