Nearest Neighbor Computation with Caffe
The following is a concrete example of performing a nearest neighbor computation using a set of ten butterfly images. This example has been tested using Caffe version rc2 and may work with the master version of Caffe from GitHub.
To generate the required model files image_mean_filepath and network_model_filepath, run the following scripts:
caffe_src/data/ilsvrc12/get_ilsvrc_aux.sh
caffe_src/scripts/download_model_binary.py ./models/bvlc_reference_caffenet/
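Before constructing the descriptor generator, it can help to confirm that the downloaded files are where the example expects them. The following is a minimal sketch, not part of SMQTK: the helper name find_missing and the REQUIRED_FILES mapping are hypothetical, and the paths are simply the ones passed to the generator later in this example, so they assume a Caffe checkout in a directory named caffe under the working directory.

```python
import os

# Paths copied from the generator arguments used later in this example.
# They are assumptions about your layout; adjust them to your Caffe checkout.
REQUIRED_FILES = {
    "network_prototxt_filepath":
        "caffe/models/bvlc_reference_caffenet/deploy.prototxt",
    "network_model_filepath":
        "caffe/models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel",
    "image_mean_filepath":
        "caffe/data/ilsvrc12/imagenet_mean.binaryproto",
}


def find_missing(files):
    """Return the {name: path} entries whose path is not an existing file."""
    return {name: path for name, path in files.items()
            if not os.path.isfile(path)}


if __name__ == "__main__":
    # Report every expected file that is not present yet.
    for name, path in find_missing(REQUIRED_FILES).items():
        print("missing {}: {}".format(name, path))
```

If the script prints nothing, all three files were found at the assumed locations.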
Once this is done, the nearest neighbor index for the butterfly images can be built with the following code:
from smqtk.algorithms.nn_index.flann import FlannNearestNeighborsIndex
# Import some butterfly data
urls = ["http://www.comp.leeds.ac.uk/scs6jwks/dataset/leedsbutterfly/examples/{:03d}.jpg".format(i) for i in range(1,11)]
from smqtk.representation.data_element.url_element import DataUrlElement
el = [DataUrlElement(d) for d in urls]
# Create a model. This assumes that a working Caffe environment has been set up for SMQTK
from smqtk.algorithms.descriptor_generator import get_descriptor_generator_impls
cd = get_descriptor_generator_impls()['CaffeDescriptorGenerator'](
    network_prototxt_filepath="caffe/models/bvlc_reference_caffenet/deploy.prototxt",
    network_model_filepath="caffe/models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel",
    image_mean_filepath="caffe/data/ilsvrc12/imagenet_mean.binaryproto",
    return_layer="fc7",
    batch_size=1,
    use_gpu=False,
    gpu_device_id=0,
    network_is_bgr=True,
    data_layer="data",
    load_truncated_images=True)
# Set up a factory for the descriptor vectors (stored in memory here)
from smqtk.representation.descriptor_element_factory import DescriptorElementFactory
from smqtk.representation.descriptor_element.local_elements import DescriptorMemoryElement
factory = DescriptorElementFactory(DescriptorMemoryElement, {})
# Compute descriptors for all ten images
descriptor_iter = cd.generate_elements(el, descr_factory=factory)
# Build a FLANN index over the descriptors using Euclidean distance
index = FlannNearestNeighborsIndex(distance_method="euclidean",
                                   random_seed=42, index_filepath="nn.index",
                                   parameters_filepath="nn.params",
                                   descriptor_cache_filepath="nn.cache")
index.build_index(descriptor_iter)
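To make clear what the index computes, here is a library-independent sketch of a nearest neighbor query under the Euclidean distance, the distance_method used above. It is an exhaustive search over made-up three-dimensional vectors, not the approximate search that FLANN performs, and it does not call SMQTK. The function name nearest_neighbors and the toy vectors are assumptions for illustration only; real fc7 descriptors are much higher-dimensional.

```python
import math


def nearest_neighbors(query, vectors, n=1):
    """Return the n (index, distance) pairs closest to query, nearest first."""
    # Euclidean distance from the query to every stored vector.
    distances = [(i, math.dist(query, v)) for i, v in enumerate(vectors)]
    # Sort by distance and keep the n closest entries.
    distances.sort(key=lambda pair: pair[1])
    return distances[:n]


# Toy vectors standing in for image descriptors (hypothetical data).
vectors = [
    (0.0, 0.0, 0.0),
    (1.0, 0.0, 0.0),
    (0.0, 3.0, 4.0),
]

result = nearest_neighbors((0.9, 0.0, 0.0), vectors, n=2)
for i, distance in result:
    print("vector {} at distance {:.3f}".format(i, distance))
```

The query lies closest to the second vector (index 1), followed by the first (index 0). An index structure such as FLANN trades some of this exactness for speed on large descriptor sets.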