手書き数字のデータを扱う!Pythonでmnistを使う方法【初心者向け】
初心者向けにPythonでmnistを使う方法について解説しています。これは機械学習の入門として使われるデータセットのひとつで、手書き数字の画像データを集めたものです。導入の方法と基本の使い方についてサンプルプログラムを見ながら学びましょう。
テックアカデミーマガジンは受講者数No.1のプログラミングスクール「テックアカデミー」が運営。初心者向けにプロが解説した記事を公開中。現役エンジニアの方はこちらをご覧ください。 ※ アンケートモニター提供元:GMOリサーチ株式会社 調査期間:2021年8月12日~8月16日 調査対象:2020年8月以降にプログラミングスクールを受講した18~80歳の男女1,000名 調査手法:インターネット調査
Pythonでmnistを使う方法について解説します。
そもそもPythonについてよく分からないという方は、Pythonとは何なのか解説した記事を読むとさらに理解が深まります。
なお本記事は、TechAcademyのオンラインブートキャンプPython講座の内容をもとに紹介しています。
今回は、Pythonに関する内容だね!
どういう内容でしょうか?
mnistの使い方について詳しく説明していくね!
お願いします!
mnistとは
mnistとは、手書き数字の画像のデータのセットです。機械学習やディープラーニングを学ぶ際のデータセットとして良く用いられます。画像は全部で7万枚あり、トレーニング用データ6万枚とテスト用データ1万枚で構成されています。
データは、画像データとラベルで構成されています。ラベルとは画像データが表す数字です。
1つ1つの画像はグレースケールで、大きさが縦28ピクセル・横28ピクセルです。各ピクセルには0〜255の値が格納されています。
ちなみにmnistとは Mixed National Institute of Standards and Technology database の略です。
mnistの使い方
mnistを使うには、以下の方法があります。
THE MNIST DATABASE of handwritten digits からダウンロードする
こちらが本家です。Yann LeCun さんのサイトからダウンロードできます。
http://yann.lecun.com/exdb/mnist/
scikit-learn を使い mldata.org からダウンロードする
mldata.orgは機械学習用データを集めたサイトです。以下のように記述することで、 mnist をダウンロードできます。初回ダウンロードには時間がかかりますが、次回以降はダウンロード済のデータを読み込んで利用できます。
ただし、 mldata.org は、しばしばサーバがダウンしており、ダウンロードできない場合があります。なお、scikit-learnには、 load_digits というメソッドで手書き数字のデータセットを取得できます。これは mnist を加工して作成した、 縦8ピクセル・横8ピクセル、1800枚の小さなデータセットです。 mnist とは大きさも枚数も異なりますので注意してください。
from sklearn.datasets import fetch_mldata mnist = fetch_mldata('MNIST original', data_home=".")
各種機械学習のライブラリを使う
最もおすすめの方法です。 TensorFlow や Keras などの機械学習のライブラリには、あらかじめ mnist をダウンロードするメソッドが用意されています。
実際に書いてみよう
今回のサンプルプログラムでは、機械学習ライブラリの Keras を使い、 mnist のダウンロードと表示を行います。なお事前に必要なライブラリのインストールが必要です。
pip install keras pip install matplotlib
サンプルプログラムは以下となります。
# 必要なライブラリのインポート from keras import backend as K from keras.datasets import mnist import matplotlib.pyplot as plt # mnist データをダウンロード (train_images, train_labels), (test_images, test_labels) = mnist.load_data() # 画像データとラベルの要素数を表示 print("画像データの要素数", train_images.shape) print("ラベルデータの要素数", train_labels.shape) # ラベルと画像データを表示 for i in range(0,10): print("ラベル", train_labels[i]) plt.imshow(train_images[i].reshape(28, 28), cmap='Greys') plt.show()
実行結果は以下のようになります。
この記事を監修してくれた方
太田和樹(おおたかずき) 普段は主に、Web系アプリケーション開発のプロジェクトマネージャーとプログラミング講師を行っている。守備範囲はフロントエンド、モバイル、サーバサイド、データサイエンティストと幅広い。その幅広い知見を生かして、複数の領域を組み合わせた新しい提案をするのが得意。 開発実績:画像認識技術を活用した駐車場混雑状況把握(実証実験)、音声認識を活用したヘルプデスク支援システム、Pepperを遠隔操作するアプリの開発、大規模基幹系システムの開発・導入マネジメント 地方在住。仕事のほとんどをリモートオフィスで行う。通勤で消耗する代わりに趣味のDIYや家庭菜園、家族との時間を楽しんでいる。 |
内容分かりやすくて良かったです!
ゆかりちゃんも分からないことがあったら質問してね!
分かりました。ありがとうございます!
TechAcademyでは、初心者でもPythonを使った人工知能(AI)や機械学習の基礎を習得できるオンラインブートキャンプPython講座を開催しています。
挫折しない学習方法を知れる説明動画や、現役エンジニアとのビデオ通話とチャットサポート、学習用カリキュラムを体験できる無料体験も実施しているので、ぜひ参加してみてください。