Source code for smqtk.algorithms.classifier._interface_classifier

import abc

from smqtk.algorithms import SmqtkAlgorithm
from smqtk.utils import (
    bin_utils,
    parallel,
)

from ._defaults import DFLT_CLASSIFIER_FACTORY


[docs]class Classifier (SmqtkAlgorithm): """ Interface for algorithms that classify input descriptors into discrete labels and/or label confidences. """
[docs] def classify(self, d, factory=DFLT_CLASSIFIER_FACTORY, overwrite=False): """ Classify the input descriptor against one or more discrete labels, outputting a ClassificationElement containing the classification result. We return confidence values for each label the configured model contains. Implementations may act in a discrete manner whereby only one label is marked with a ``1`` value (others being ``0``), or in a continuous manner whereby each label is given a confidence-like value in the [0, 1] range. The returned ``ClassificationElement`` will have the same UUID as the input ``DescriptorElement``. :param d: Input descriptor to classify :type d: smqtk.representation.DescriptorElement :param factory: Classification element factory. The default factory yields MemoryClassificationElement instances. :type factory: smqtk.representation.ClassificationElementFactory :param overwrite: Recompute classification of the input descriptor and set the results to the ClassificationElement produced by the factory. :type overwrite: bool :raises ValueError: The given descriptor element did not have a vector to operate on. :raises RuntimeError: Could not perform classification for some reason (see message in raised exception). :return: Classification result element :rtype: smqtk.representation.ClassificationElement """ if not d.has_vector(): raise ValueError("Given DescriptorElement does not contain a " "vector to classify.") c_elem = factory.new_classification(self.name, d.uuid()) if overwrite or not c_elem.has_classifications(): c = self._classify(d) c_elem.set_classification(c) return c_elem
[docs] def classify_async(self, d_iter, factory=DFLT_CLASSIFIER_FACTORY, overwrite=False, procs=None, use_multiprocessing=False, ri=None): """ Asynchronously classify the DescriptorElements in the given iterable. :param d_iter: Iterable of DescriptorElements :type d_iter: collections.Iterable[smqtk.representation.DescriptorElement] :param factory: Classifier element factory to use for element generation. The default factory yields MemoryClassificationElement instances. :type factory: smqtk.representation.ClassificationElementFactory :param overwrite: Recompute classification of the input descriptor and set the results to the ClassificationElement produced by the factory. :type overwrite: bool :param procs: Explicit number of cores/thread/processes to use. :type procs: None | int :param use_multiprocessing: Use multiprocessing instead of threading. :type use_multiprocessing: bool :param ri: Progress reporting interval in seconds. Set to a value > 0 to enable. Disabled by default. :type ri: float | None :return: Mapping of input DescriptorElement instances to the computed ClassificationElement. ClassificationElement UUID's are congruent with the UUID of the DescriptorElement :rtype: dict[smqtk.representation.DescriptorElement, smqtk.representation.ClassificationElement] """ self._log.debug("Async classifying descriptors") ri = ri and ri > 0 and ri def work(d_elem): return d_elem, self.classify(d_elem, factory, overwrite) classifications = parallel.parallel_map( work, d_iter, cores=procs, ordered=False, use_multiprocessing=use_multiprocessing, ) r_state = [0] * 7 if ri: r_progress = bin_utils.report_progress else: def r_progress(*_): return d2c_map = {} for d, c in classifications: d2c_map[d] = c r_progress(self._log.debug, r_state, ri) return d2c_map
# # Abstract methods #
[docs] @abc.abstractmethod def get_labels(self): """ Get the sequence of class labels that this classifier can classify descriptors into. This includes the negative label. :return: Sequence of possible classifier labels. :rtype: collections.Sequence[collections.Hashable] :raises RuntimeError: No model loaded. """
@abc.abstractmethod def _classify(self, d): """ Internal method that constructs the label-to-confidence map (dict) for a given DescriptorElement. The passed descriptor element is guaranteed to have a vector to extract. It is not extracted yet due to the philosophy of waiting until the vector is immediately needed. This moment is thus determined by the implementing algorithm. :param d: DescriptorElement containing the vector to classify. :type d: smqtk.representation.DescriptorElement :raises RuntimeError: Could not perform classification for some reason (see message in raised exception). :return: Dictionary mapping trained labels to classification confidence values :rtype: dict[collections.Hashable, float] """