handwritten

Autor: victor
Data: 18/08/2011

This is an example page showing some performance and accuracy results of the Optimum-Path Forest (OPF) classifier. The test uses the unmodified MNIST handwritten digits dataset, which contains 60000 images for training and 10000 images for testing.

In this test we gradually increase the number of training images in order to compare OPF's accuracy and execution time with those of an RBF-SVM in a situation where little data about the problem is available.

We can see from the plots below that OPF has a lower processing time (especially in the prediction step) than an RBF-SVM for this problem. Also, its f1-score is slightly better.

  1 from struct import *
  2 from itertools import *
  3 import time
  4 import operator
  5 
  6 import numpy
  7 import libopf_py
  8 
  9 from scikits.learn import datasets, svm, metrics
 10 
 11 #http://yann.lecun.com/exdb/mnist/
 12 
 13 #read idx
 14 def read_idx(f):
 15   idx = open(f, 'rb')
 16 
 17   (_,_,t,ndim) = unpack_from(">BBBB", idx.read(4))
 18 
 19   sizeof_t = {
 20                 0x08: (1, numpy.uint8  , 'B'),
 21                 0x09: (1, numpy.int8   , 'b'),
 22                 0x0B: (2, numpy.int16  , 'h'),
 23                 0x0C: (4, numpy.int32  , 'i'),
 24                 0x0D: (4, numpy.float32, 'f'),
 25                 0x0E: (8, numpy.float64, 'd')
 26               }
 27 
 28   dims = unpack_from('>'+'i'*ndim, idx.read(4*ndim))
 29   nelem = reduce(operator.mul, dims)
 30 
 31   data = numpy.zeros(nelem, dtype=sizeof_t[t][1])
 32   data[:] = unpack_from('>'+('%s' % sizeof_t[t][2])*nelem, idx.read(sizeof_t[t][0]*nelem))
 33 
 34   if len(dims) > 1:
 35     data_r  = data.reshape (dims[0], reduce(operator.mul, dims[1:]))
 36   else:
 37     data_r = data
 38 
 39   idx.close()
 40 
 41   return data_r
 42 
 43 
 44 ###
 45 
 46 
 47 train_label = read_idx(find_attachment_file('classification_datasets/handwritten/train-labels-idx1-ubyte')).astype(numpy.int32)
 48 train_image = read_idx(find_attachment_file('classification_datasets/handwritten/train-images-idx3-ubyte')).astype(numpy.float32) / 255.0
 49 
 50 test_label  = read_idx(find_attachment_file('classification_datasets/handwritten/t10k-labels-idx1-ubyte')).astype(numpy.int32)
 51 test_image  = read_idx(find_attachment_file('classification_datasets/handwritten/t10k-images-idx3-ubyte')).astype(numpy.float32) / 255.0
 52 
 53 
 54 ###
 55 
 56 opf_training_time = []
 57 opf_predicting_time = []
 58 opf_f1 = []
 59 
 60 svm_training_time = []
 61 svm_predicting_time = []
 62 svm_f1 = []
 63 
 64 def run(SIZE):
 65 
 66   rand = numpy.random.permutation(train_label.shape[0]) [:SIZE]
 67 
 68   # OPF
 69   def opf():
 70     image, label = train_image[rand], train_label[rand]
 71 
 72     O = libopf_py.OPF()
 73 
 74     t = time.time()
 75     O.fit(image, label)
 76     opf_training_time.append(time.time()-t)
 77 
 78     t = time.time()
 79     label = O.predict(test_image)
 80     opf_predicting_time.append(time.time()-t)
 81 
 82     opf_f1.append(metrics.f1_score(test_label, label))
 83 
 84   opf()
 85 
 86   # SVM
 87   def _svm():
 88     image, label =  train_image[rand].astype(numpy.float64), train_label[rand].astype(numpy.int64)
 89     clf = svm.SVC()
 90 
 91     t = time.time()
 92     clf.fit(image, label)
 93     svm_training_time.append(time.time()-t)
 94 
 95     t_image = test_image.astype(numpy.float64)
 96 
 97     t = time.time()
 98     label = clf.predict(t_image)
 99     svm_predicting_time.append(time.time()-t)
100 
101     svm_f1.append(metrics.f1_score(test_label, label))
102 
103   _svm()
104 
105 # run
# Training-set sizes to benchmark; run() is called once per size, in order.
S = [10, 50, 100, 1000, 2000, 4000]
for s in S:
  run(s)

# plot data: one chart per metric, each comparing OPF with the SVM over S

mmplot([
  [S, opf_training_time, 'OPF training time'],
  [S, svm_training_time, 'SVM training time'],
])

mmplot([
  [S, opf_predicting_time, 'OPF predicting time'],
  [S, svm_predicting_time, 'SVM predicting time'],
])

mmplot([
  [S, opf_f1, 'OPF f1 score'],
  [S, svm_f1, 'SVM f1 score'],
])