import math
from collections import Counter

class naivebayes:
    """Naive Bayes text classifier built from raw counts.

    ``train`` only accumulates counts, so it may be called repeatedly to add
    more data. Additive smoothing is applied at query time. ``classify``
    supports multinomial scoring (default), binarized scoring
    (``bernoulli=True``) and complement scoring (``complement=True``).
    """

    def __init__(self):
        # (class, word) -> occurrences of word in the documents of class
        self.cntwordinclass = {}
        # class -> total number of word tokens seen in that class
        self.cntclass = {}
        # every distinct word seen during training
        self.vocabulary = set()
        # every class label seen during training
        self.classes = set()
        # total number of training documents over all classes
        self.classtotal = 0
        # total number of word tokens over all classes
        self.totalwordcount = 0
        # class -> number of training documents in that class
        self.cntdocinclass = {}
        # word -> total occurrences of word over all classes
        self.wordcount = {}

    def topn_print(self, n=10, reverse=False):
        """Print the top n most indicative words of every class.

        With ``reverse=True`` the n least indicative words are printed
        instead. Each output line is ``<class> <list of words>``.
        """
        for c in self.classes:
            indwords = self.get_indicativeness(c)
            if reverse:
                indwords.reverse()
            # A single pre-formatted string prints identically under Python 2's
            # print statement and Python 3's print function.
            print("%s %s" % (c, indwords[:n]))

    def get_indicativeness(self, cl, wordprior=0.5):
        """Return all words seen in class ``cl``, most indicative first.

        A word's score is log p(w|cl) minus the largest log p(w|c) over every
        other class c. That is its log-ratio against the strongest competing
        class, not information gain. When ``cl`` is the only class, the score
        is just log p(w|cl).

        :param cl: class label whose words are ranked.
        :param wordprior: additive-smoothing pseudo-count passed to _getprob.
        :returns: list of words, sorted by descending score.
        """
        others = [c for c in self.classes if c != cl]
        scores = {}
        for (c, w) in self.cntwordinclass:
            if c != cl:
                continue
            score = self._getprob(cl, w, wordprior)
            # max() of an empty sequence raises, so skip it for a single class.
            if others:
                score -= max(self._getprob(o, w, wordprior) for o in others)
            scores[w] = score
        return sorted(scores, key=scores.get, reverse=True)

    def _getprob(self, c, w, wordprior=0.5, complement=False):
        """Return log p(w|c) with additive smoothing.

        Every vocabulary word receives the pseudo-count ``wordprior``. With
        ``complement=True`` it returns log p(w|not c) instead, estimated from
        the counts of every class except ``c``. Classes or words that were
        never seen count as 0.
        """
        inclass = self.cntwordinclass.get((c, w), 0)
        smoothing = wordprior * len(self.vocabulary)
        if complement:
            classcount = self.wordcount.get(w, 0) - inclass + wordprior
            totalcount = self.totalwordcount - self.cntclass.get(c, 0) + smoothing
        else:
            classcount = inclass + wordprior
            totalcount = self.cntclass.get(c, 0) + smoothing
        return math.log(classcount / float(totalcount))

    def train(self, docs):
        """Train the classifier, i.e. accumulate counts.

        ``docs`` is a list of lists. The first entry of each inner list is the
        class name, and the remaining entries are documents (lists of words), e.g.
            docs = [['class1',doc1,doc2],['class2',doc3,doc4],['class3',doc5]]
        Counts are added to those from earlier calls, and previously seen
        classes are kept, so training can be done incrementally.
        """
        for classdoc in docs:
            classsig = classdoc[0]
            # Register the label even if no documents follow it.
            self.classes.add(classsig)
            for doc in classdoc[1:]:
                self.cntdocinclass[classsig] = self.cntdocinclass.get(classsig, 0) + 1
                self.classtotal += 1
                for word in doc:
                    self.vocabulary.add(word)
                    self.totalwordcount += 1
                    key = (classsig, word)
                    self.cntwordinclass[key] = self.cntwordinclass.get(key, 0) + 1
                    self.cntclass[classsig] = self.cntclass.get(classsig, 0) + 1
                    self.wordcount[word] = self.wordcount.get(word, 0) + 1

    def classify(self, message, complement=False, bernoulli=False,
                 classprior=0.5, wordprior=0.5):
        """Score a document against every trained class.

        :param message: list of words.
        :param complement: use complement estimates p(w|not c); the word
            evidence is then subtracted from the class prior.
        :param bernoulli: count each distinct word once (binarized) instead
            of once per occurrence.
        :param classprior: additive-smoothing pseudo-count for class priors.
        :param wordprior: additive-smoothing pseudo-count for word likelihoods.
        :returns: dict {class: score}. Scores are unnormalized log values (log
            prior plus or minus summed log likelihoods), not probabilities;
            the highest score is the predicted class.
        """
        # Words never seen in training carry no information; drop them once.
        counts = Counter(w for w in message if w in self.vocabulary)
        denom = float(self.classtotal + len(self.classes) * classprior)
        result = {}
        for c in self.classes:
            # .get: a label may have been given with no documents.
            prior = math.log((self.cntdocinclass.get(c, 0) + classprior) / denom)
            if bernoulli:
                ptot = sum(self._getprob(c, w, wordprior, complement)
                           for w in counts)
            else:
                ptot = sum(self._getprob(c, w, wordprior, complement) * k
                           for w, k in counts.items())
            # In complement mode, high likelihood under "not c" is evidence
            # against c, hence the subtraction.
            result[c] = prior - ptot if complement else prior + ptot
        return result
