{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 20 Newsgroups\n",
    "\n",
    "Using [this tutorial!](https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html) | 10-22-19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['alt.atheism',\n",
      " 'comp.graphics',\n",
      " 'comp.os.ms-windows.misc',\n",
      " 'comp.sys.ibm.pc.hardware',\n",
      " 'comp.sys.mac.hardware',\n",
      " 'comp.windows.x',\n",
      " 'misc.forsale',\n",
      " 'rec.autos',\n",
      " 'rec.motorcycles',\n",
      " 'rec.sport.baseball',\n",
      " 'rec.sport.hockey',\n",
      " 'sci.crypt',\n",
      " 'sci.electronics',\n",
      " 'sci.med',\n",
      " 'sci.space',\n",
      " 'soc.religion.christian',\n",
      " 'talk.politics.guns',\n",
      " 'talk.politics.mideast',\n",
      " 'talk.politics.misc',\n",
      " 'talk.religion.misc']\n"
     ]
    }
   ],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "from sklearn.datasets import fetch_20newsgroups, fetch_20newsgroups_vectorized\n",
    "\n",
    "# Fetch the training split with no category filter and list the group names.\n",
    "newsgroups_train = fetch_20newsgroups(subset='train')\n",
    "pprint(list(newsgroups_train.target_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data is in the `filenames` and `target` attributes (`target` is the integer index of each post's category)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# categories = ['alt.atheism', 'sci.space']\n",
    "# atheism = ['alt.atheism']\n",
    "# ng_atheism = fetch_20newsgroups(subset='train', categories=atheism)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2034, 34118)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "# Restrict the corpus to four newsgroups.\n",
    "categories = [\n",
    "    'alt.atheism',\n",
    "    'talk.religion.misc',\n",
    "    'comp.graphics',\n",
    "    'sci.space',\n",
    "]\n",
    "newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)\n",
    "\n",
    "# Fit TF-IDF on the training posts; the final expression displays the\n",
    "# shape of the resulting document-term matrix.\n",
    "vectorizer = TfidfVectorizer()\n",
    "vectors = vectorizer.fit_transform(newsgroups_train.data)\n",
    "vectors.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vectorizer = TfidfVectorizer()\n",
    "# vectors = vectorizer.fit_transform(ng_atheism.data)\n",
    "# vectors.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8821359240272957"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "# Vectorize the test split with the vectorizer already fitted on the training data.\n",
    "newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)\n",
    "vectors_test = vectorizer.transform(newsgroups_test.data)\n",
    "\n",
    "# Train multinomial naive Bayes on the training matrix (fit returns the estimator).\n",
    "clf = MultinomialNB(alpha=0.01).fit(vectors, newsgroups_train.target)\n",
    "\n",
    "# Macro-averaged F1 of the predictions on the test split.\n",
    "pred = clf.predict(vectors_test)\n",
    "metrics.f1_score(newsgroups_test.target, pred, average='macro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_top10(classifier, vectorizer, categories):\n",
    "    \"\"\"Print the ten highest-weighted feature names for each category.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    classifier : fitted estimator\n",
    "        Must expose a per-class weight matrix as ``feature_log_prob_``\n",
    "        (naive Bayes) or ``coef_`` (linear models), one row per category.\n",
    "    vectorizer : fitted vectorizer\n",
    "        Provides the feature names via ``get_feature_names_out()`` or the\n",
    "        older ``get_feature_names()``.\n",
    "    categories : sequence of str\n",
    "        Category names; index ``i`` is matched with weight row ``i``.\n",
    "\n",
    "    Prints one line per category: ``<category>: <name> <name> ...`` with\n",
    "    names in ascending weight order.\n",
    "    \"\"\"\n",
    "    # get_feature_names() was removed in scikit-learn 1.2 in favour of\n",
    "    # get_feature_names_out(); fall back so older installs keep working.\n",
    "    get_names = getattr(vectorizer, \"get_feature_names_out\", None)\n",
    "    if get_names is None:\n",
    "        get_names = vectorizer.get_feature_names\n",
    "    feature_names = np.asarray(get_names())\n",
    "\n",
    "    # Naive Bayes estimators lost coef_ in scikit-learn 1.1; their per-class\n",
    "    # weights are in feature_log_prob_. Other estimators still use coef_.\n",
    "    weights = getattr(classifier, \"feature_log_prob_\", None)\n",
    "    if weights is None:\n",
    "        weights = classifier.coef_\n",
    "\n",
    "    for i, category in enumerate(categories):\n",
    "        top10 = np.argsort(weights[i])[-10:]\n",
    "        print(\"%s: %s\" % (category, \" \".join(feature_names[top10])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the top-10 features per category with show_top10 from the cell above.\n",
    "# (The debug redefinition that printed raw argsort indices was removed, so the\n",
    "# stored output was cleared; re-run this cell to regenerate it.)\n",
    "show_top10(clf, vectorizer, newsgroups_train.target_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
