{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNB Feature Weighting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['alt.atheism',\n",
       " 'comp.graphics',\n",
       " 'comp.os.ms-windows.misc',\n",
       " 'comp.sys.ibm.pc.hardware',\n",
       " 'comp.sys.mac.hardware',\n",
       " 'comp.windows.x',\n",
       " 'misc.forsale',\n",
       " 'rec.autos',\n",
       " 'rec.motorcycles',\n",
       " 'rec.sport.baseball',\n",
       " 'rec.sport.hockey',\n",
       " 'sci.crypt',\n",
       " 'sci.electronics',\n",
       " 'sci.med',\n",
       " 'sci.space',\n",
       " 'soc.religion.christian',\n",
       " 'talk.politics.guns',\n",
       " 'talk.politics.mideast',\n",
       " 'talk.politics.misc',\n",
       " 'talk.religion.misc']"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the training split of the 20 Newsgroups corpus via scikit-learn.\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "newsgroups_train = fetch_20newsgroups(subset='train')\n",
    "# Display the category names (20 newsgroups) as the cell output.\n",
    "newsgroups_train.target_names\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "arrays must all be same length",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-19-333e9965d118>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpandas\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mng_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnewsgroups_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.7/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, index, columns, dtype, copy)\u001b[0m\n\u001b[1;32m    409\u001b[0m             )\n\u001b[1;32m    410\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 411\u001b[0;31m             \u001b[0mmgr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minit_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    412\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMaskedArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    413\u001b[0m             \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmrecords\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmrecords\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.7/site-packages/pandas/core/internals/construction.py\u001b[0m in \u001b[0;36minit_dict\u001b[0;34m(data, index, columns, dtype)\u001b[0m\n\u001b[1;32m    255\u001b[0m             \u001b[0marr\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_datetime64tz_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    256\u001b[0m         ]\n\u001b[0;32m--> 257\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0marrays_to_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marrays\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    258\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    259\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.7/site-packages/pandas/core/internals/construction.py\u001b[0m in \u001b[0;36marrays_to_mgr\u001b[0;34m(arrays, arr_names, index, columns, dtype)\u001b[0m\n\u001b[1;32m     75\u001b[0m     \u001b[0;31m# figure out the index, if necessary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextract_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mensure_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.7/site-packages/pandas/core/internals/construction.py\u001b[0m in \u001b[0;36mextract_index\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m    366\u001b[0m             \u001b[0mlengths\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mraw_lengths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    367\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlengths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 368\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"arrays must all be same length\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    370\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mhave_dicts\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: arrays must all be same length"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "ng_df = pd.DataFrame(newsgroups_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7.51788182e-04, 4.22433359e-04, 4.80428448e-03, ...,\n",
       "       7.87587619e-05, 3.29354823e-04, 2.79235610e-04])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the per-token conditional probabilities; the TSV is a hand-edited\n",
    "# copy of Weka's MultinomialNB output.\n",
    "import pandas as pd\n",
    "weka_output = pd.read_csv('MNB-feature-weight.tsv', sep='\\t')\n",
    "\n",
    "# Pull the token column and the two per-class probability columns out as arrays.\n",
    "features, neg_cond_prob, pos_cond_prob = (\n",
    "    weka_output[column].values for column in ('token', 'neg', 'pos')\n",
    ")\n",
    "neg_cond_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "&\n",
      "0.0007517881818897805\n",
      "0.0006325763932817808\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the first token followed by its negative and positive\n",
    "# conditional probabilities, one value per line.\n",
    "print(features[0], neg_cond_prob[0], pos_cond_prob[0], sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.17265361805312018\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "# Per-token log ratio of the positive to the negative conditional probability:\n",
    "# values above 0 lean positive, values below 0 lean negative.\n",
    "log_ratios = [\n",
    "    math.log(pos_prob) - math.log(neg_prob)\n",
    "    for pos_prob, neg_prob in zip(pos_cond_prob, neg_cond_prob)\n",
    "]\n",
    "print(log_ratios[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1.3471721356912916, 'fantastic'), (1.3527011771704371, 'cameron'), (1.4004166502101043, 'hanks'), (1.4367842943809803, 'tarantino'), (1.4946038652698048, 'toy'), (1.5919548621089836, 'titanic'), (1.6368054282743358, 'ripley'), (1.7681414303354224, 'era'), (1.829177320921791, 'wonderfully'), (1.8422494024891432, 'scorsese'), (1.842249402489144, 'jedi'), (1.960032438145527, 'derek'), (2.3530750262551345, 'outstanding'), (2.441870525038266, 'truman'), (2.5210078455969906, 'damon'), (2.875264408671441, 'bulworth'), (3.228543763609035, 'lebowski'), (3.5852187075477664, 'flynt'), (3.7675402643417213, 'mulan'), (4.055222336793502, 'shrek')]\n"
     ]
    }
   ],
   "source": [
    "# Pair every token with its log ratio and order the pairs ascending by ratio\n",
    "# (equal ratios fall back to comparing the tokens themselves).\n",
    "feature_ranks = list(zip(log_ratios, features))\n",
    "feature_ranks.sort()\n",
    "\n",
    "# The tail of the ascending ranking holds the highest pos/neg ratios,\n",
    "# i.e. the 20 most positive words.\n",
    "top_pos_features = feature_ranks[-20:]\n",
    "print(top_pos_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(-4.18119819047189, '&nbsp'), (-3.3091135511022296, 'seagal'), (-2.706350432010554, 'schumacher'), (-2.263144995919541, 'wrestling'), (-2.191401091060701, 'godzilla'), (-1.982431592783855, 'spawn'), (-1.8954202157942248, 'wasted'), (-1.8419315308432394, 'lame'), (-1.8013912661459477, 'poorly'), (-1.772522079471071, 'worst'), (-1.7721875753702774, 'waste'), (-1.7388511551026857, 'ridiculous'), (-1.7130986590002708, 'awful'), (-1.7130986590002708, 'eve'), (-1.6570091923492267, 'stupid'), (-1.6495852532779436, 'snake'), (-1.635137117530559, 'unfunny'), (-1.5975857718784265, 'uninteresting'), (-1.5852652874903859, 'dull'), (-1.5536709220721674, 'arnold')]\n"
     ]
    }
   ],
   "source": [
    "# `feature_ranks` is sorted ascending by log ratio, so its first 20 entries\n",
    "# are the words with the lowest pos/neg conditional prob ratio (most negative words).\n",
    "top_neg_features = feature_ranks[:20]\n",
    "print(top_neg_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0.0024851215450355678, '*'), (0.0024980312673474397, 'funny'), (0.002517395850815249, 'comedy'), (0.002556125017750869, 'star'), (0.002562579878906806, 'takes'), (0.002569034740062741, 'year'), (0.002594854184686488, 'played'), (0.002601309045842423, 'cast'), (0.0026271284904661693, 'fact'), (0.002711041685493345, 'find'), (0.0027691354358967747, 'family'), (0.002788500019364582, 'big'), (0.002827229186300204, 'young'), (0.0028465937697680125, 'audience'), (0.0028465937697680125, 'john'), (0.00288532293670363, 'real'), (0.00288532293670363, 'things'), (0.0030337847432901728, 'action'), (0.003130607660629221, 'years'), (0.003137062521785156, 'role'), (0.003156427105252967, 'made'), (0.003317798634151378, 'work'), (0.003337163217619189, 'director'), (0.0033629826622429354, 'end'), (0.0035501736357650936, 'performance'), (0.0036082673861685243, 'back'), (0.0036211771084803954, 'makes'), (0.0036792708588838244, '--'), (0.0036921805811956977, 'don'), (0.0037502743315991267, 'plot'), (0.003879371554717858, 'doesn'), (0.004060107667084079, 'movies'), (0.004085927111707826, 'scenes'), (0.004156930584423131, 'world'), (0.004182750029046874, 'love'), (0.0046345903099624325, 'scene'), (0.004770142394237098, 'man'), (0.00480887156117272, 'great'), (0.004996062534694879, 'make'), (0.005086430590877992, 'people'), (0.005260711842088278, '-'), (0.005667368094912279, 'films'), (0.006351583377441551, 'characters'), (0.0064032222666890425, 'life'), (0.006893791714540223, 'character'), (0.007565097274757618, 'time'), (0.007694194497876354, 'good'), (0.00794593408295787, 'story'), (0.01609196886174978, 'movie'), (0.033436180787751256, 'film')]\n",
      "\n",
      "[(0.002606199030551239, 'guy'), (0.00262051880544438, 'actors'), (0.002649158355230655, 'acting'), (0.0026634781301237962, 'point'), (0.002670638017570365, 'plays'), (0.0026777979050169338, 'long'), (0.00271359734224978, 'role'), (0.0027422368920360586, 'minutes'), (0.002778036329268904, 'played'), (0.0027923561041620425, 'fact'), (0.0028353154288414586, 'great'), (0.0028711148660743053, 'things'), (0.002899754415860584, 'real'), (0.002949873627986569, 'comedy'), (0.003093071376917956, 'makes'), (0.0031789900262767863, 'funny'), (0.0032076295760630627, 'thing'), (0.0032147894635096343, 'love'), (0.0033651470998875893, 'audience'), (0.003408106424567007, 'back'), (0.003422426199460145, 'script'), (0.003515504736265545, 'isn'), (0.003522664623712114, 'life'), (0.0035369843986052547, 'work'), (0.003780420571788611, 'end'), (0.0038520194462543047, 'big'), (0.0038520194462543047, 'made'), (0.003988057307739123, 'movies'), (0.004066816069651386, '--'), (0.0042601330307087595, 'director'), (0.004288772580495035, '-'), (0.0043102522428347416, 'man'), (0.004338891792621018, 'action'), (0.0045178889787852545, 'scenes'), (0.004610967515590657, 'films'), (0.004718365827289194, 'people'), (0.004718365827289194, 'scene'), (0.004804284476648028, '*'), (0.004811444364094596, 'doesn'), (0.005104999749403939, 'don'), (0.0058639478187402895, 'make'), (0.006250581740855032, 'characters'), (0.0063007009529810155, 'plot'), (0.0065369772387178105, 'story'), (0.0066372156629697756, 'character'), (0.007295925308054158, 'bad'), (0.007961794840585104, 'time'), (0.008069193152283648, 'good'), (0.02270400309307137, 'movie'), (0.0304510013102594, 'film')]\n"
     ]
    }
   ],
   "source": [
    "# If the model classifies more than two categories, the same log ratio can be\n",
    "# computed between the conditional probabilities of any two categories.\n",
    "\n",
    "# Simply printing the words with the highest conditional probability in each\n",
    "# category may or may not yield informative features, because words that are\n",
    "# popular in one category may also be popular in the other categories.\n",
    "\n",
    "# The following code prints the 50 words with the highest positive conditional\n",
    "# probabilities and the 50 words with the highest negative conditional probabilities;\n",
    "# both lists are dominated by shared common words like \"film\", \"movie\", etc.\n",
    "\n",
    "pos_features = list(zip(pos_cond_prob, features))\n",
    "pos_features.sort()\n",
    "print(pos_features[-50:])\n",
    "print()\n",
    "neg_features = list(zip(neg_cond_prob, features))\n",
    "neg_features.sort()\n",
    "print(neg_features[-50:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
