import abc
import hashlib
import six
from smqtk.algorithms import SmqtkAlgorithm, ImageReader
from smqtk.utils import ContentTypeValidator
from smqtk.utils.configuration import (
make_default_config,
from_config_dict,
to_config_dict,
)
from ._defaults import DFLT_CLASSIFIER_FACTORY, DFLT_DETECTION_FACTORY
[docs]@six.add_metaclass(abc.ABCMeta)
class ObjectDetector (SmqtkAlgorithm, ContentTypeValidator):
"""
Abstract interface to an object detection algorithm.
An object detection algorithm is one that can take in data and output zero
or more detection elements, where each detection represents a spatial
region in the data.
This high level interface only requires detection element returns (spatial
bounding-boxes with associated classification elements).
"""
__slots__ = ()
@staticmethod
def _gen_detection_uuid(data_uuid, bbox, labels):
"""
Local standard for producing the UUID of a DetectionElement based on
parent data, component bounding box and classification labels.
:param str data_uuid:
UUID of parent data element (checksum hash string) this detection
is derived from.
:param smqtk.representation.AxisAlignedBoundingBox bbox:
Detection bounding box instance.
:param labels:
Sequence of string classification labels.
:return: Detection UUID string that is the SHA1 checksum of the
component data.
"""
# noinspection PyStringFormat
# - uses variadic expansion which fools current linter.
hashable = data_uuid + \
'{}{}{}{}'.format(*(bbox.min_vertex.tolist() +
bbox.max_vertex.tolist())) + \
''.join(sorted(map(str, labels)))
return hashlib.sha1(six.b(hashable)).hexdigest()
[docs] def detect_objects(self, data_element,
de_factory=DFLT_DETECTION_FACTORY,
ce_factory=DFLT_CLASSIFIER_FACTORY):
"""
Detect objects in the given data.
UUIDs of detections are based on the hash produced from the combination
of:
- Detection bounding-box bounding coordinates
- Classification label set predicted for a bounding box.
:param smqtk.representation.DataElement data_element:
Source data from which to detect objects within.
:param smqtk.representation.DetectionElementFactory de_factory:
Factory for generating DetectionElement instances. The default
factory yields MemoryClassificationElement instances.
:param smqtk.representation.ClassificationElementFactory ce_factory:
Factory for generating ClassificationElement instances for
detections. The default factory yields MemoryClassificationElement
instances.
:raises ValueError: Given data element content was not of a valid
content type that this class reports as valid for object detection.
:return: Iterator over result DetectionElement instances as generated
by the given DetectionElementFactory, containing classification
elements as generated by the given ClassificationElementFactory.
:rtype: collections.Iterable[smqtk.representation.DetectionElement]
"""
self.raise_valid_element(data_element)
# We know that the UUID of a DataElement should be a checksum of sorts,
# so we can generally assume a string-cast is unique preserving.
de_uuid = str(data_element.uuid())
type_str = 'object detection classification'
for bbox, c_map in self._detect_objects(data_element):
# Determine UUID of detection from bbox and classification labels
det_uuid = self._gen_detection_uuid(de_uuid, bbox, c_map.keys())
ce = ce_factory.new_classification(type_str, det_uuid)
ce.set_classification(c_map)
de = de_factory.new_detection(det_uuid).set_detection(bbox, ce)
yield de
@abc.abstractmethod
def _detect_objects(self, data):
"""
Internal method that defines the generation of paired bounding boxes
and classification maps for detected objects in the given data.
:param smqtk.representation.DataElement data:
Source data (DataElement) from which to detect objects within.
:return: Iterable over paired ``AxisAlignedBoundingBox`` and
classification map for detected objects. The returned
"classification map" should follow the format described by
``smqtk.representation.ClassificationElement``: dictionary where
keys are classification labels and values are classification
probabilities.
:rtype:
collections.Iterator[(smqtk.representation.AxisAlignedBoundingBox,
dict[collections.Hashable, float])]
"""
@six.add_metaclass(abc.ABCMeta)
class ImageMatrixObjectDetector (ObjectDetector):
"""
Class of object detectors that operate over the pixel matrix of an image.
This sub abstract class standardizes the use of an
:class:`smqtk.algorithms.ImageReader` algorithm to read an image file's
pixels as well as determine which image formats are valid input elements.
There is a special exception of :class:`.MatrixDataElement` types as they
directly provide a matrix.
We define an alternate abstract method for implementing classes to define:
``_detect_objects_matrix``. This method is given a numpy ndarray instance
for the implementing class to utilize. The return requirements are the same
as the ``_detect_objects`` method.
"""
@classmethod
def get_default_config(cls):
"""
Generate and return a default configuration dictionary for this class.
This will be primarily used for generating what the configuration
dictionary would look like for this class without instantiating it.
By default, we observe what this class's constructor takes as arguments,
turning those argument names into configuration dictionary keys. If any
of those arguments have defaults, we will add those values into the
configuration dictionary appropriately. The dictionary returned should
only contain JSON compliant value types.
It is not be guaranteed that the configuration dictionary returned
from this method is valid for construction of an instance of this class.
:return: Default configuration dictionary for the class.
:rtype: dict
"""
default = super(ImageMatrixObjectDetector, cls).get_default_config()
default['image_reader'] = make_default_config(ImageReader.get_impls())
return default
@classmethod
def from_config(cls, config_dict, merge_default=True):
"""
Instantiate a new instance of this class given the configuration
JSON-compliant dictionary encapsulating initialization arguments.
This method should not be called via super unless an instance of the
class is desired.
:param config_dict: JSON compliant dictionary encapsulating
a configuration.
:type config_dict: dict
:param merge_default: Merge the given configuration on top of the
default provided by ``get_default_config``.
:type merge_default: bool
:return: Constructed instance from the provided config.
:rtype: ImageMatrixObjectDetector
"""
# Shallow copy
config_dict = dict(config_dict) # type: dict
config_dict['image_reader'] = from_config_dict(
config_dict.get('image_reader', {}), ImageReader.get_impls()
)
return super(ImageMatrixObjectDetector, cls).from_config(
config_dict, merge_default=merge_default
)
def __init__(self, image_reader):
"""
An image matrix object detector must have a method of converting a
DataElement into an image pixel matrix so this interface broadly
requires an :class:`smqtk.algorithms.ImageReader` instance.
:param smqtk.algorithms.ImageReader image_reader:
ImageReader algorithm instance for reading image matrices from
DataElements.
"""
super(ImageMatrixObjectDetector, self).__init__()
self._image_reader = image_reader
@abc.abstractmethod
def get_config(self):
"""
Return a JSON-compliant dictionary that could be passed to this class's
``from_config`` method to produce an instance with identical
configuration.
In the most cases, this involves naming the keys of the dictionary
based on the initialization argument names as if it were to be passed
to the constructor via dictionary expansion. In some cases, where it
doesn't make sense to store some object constructor parameters are
expected to be supplied at as configuration values (i.e. must be
supplied at runtime), this method's returned dictionary may leave those
parameters out. In such cases, the object's ``from_config``
class-method would also take additional positional arguments to fill in
for the parameters that this returned configuration lacks.
:return: JSON type compliant configuration dictionary.
:rtype: dict
"""
return {
'image_reader': to_config_dict(self._image_reader),
}
def valid_content_types(self):
"""
:return: A set valid MIME types that are "valid" within the implementing
class' context.
:rtype: set[str]
"""
return self._image_reader.valid_content_types()
def is_valid_element(self, data_element):
"""
Check if the given DataElement instance reports a content type that
matches one of the MIME types reported by ``valid_content_types``.
This override uses our stored :class:`ImageReader` algorithm instance to
define what :class:`DataElement` instances are valid.
:param smqtk.representation.DataElement data_element:
Data element instance to check.
:return: True if the given element has a valid content type as reported
by ``valid_content_types``, and False if not.
:rtype: bool
"""
return self._image_reader.is_valid_element(data_element)
def _detect_objects(self, data):
"""
Internal method that defines the generation of paired bounding boxes
and classification maps for detected objects in the given data.
This ``ImageMatrixObjectDetector`` implementation ensures that the data
element is converted to a :class:`numpy.ndarray` before passing the result
matrix along to the :func:`_detect_objects_matrix` method for the
implementing class to define.
:param smqtk.representation.DataElement data:
Source data (DataElement) from which to detect objects within.
:return: Iterable over paired ``AxisAlignedBoundingBox`` and
classification map for detected objects. The returned
"classification map" should follow the format described by
``smqtk.representation.ClassificationElement``: dictionary where
keys are classification labels and values are classification
probabilities.
:rtype:
collections.Iterator[(smqtk.representation.AxisAlignedBoundingBox,
dict[collections.Hashable, float])]
"""
return self._detect_objects_matrix(
self._image_reader.load_as_matrix(data)
)
@abc.abstractmethod
def _detect_objects_matrix(self, mat):
"""
Internal method to be implemented that defines the generation of paired
bounding boxes and classification maps for detected objects in the given
image matrix data.
:param numpy.ndarray mat:
Image pixel matrix to detect objects within.
:return: Iterable over paired ``AxisAlignedBoundingBox`` and
classification map for detected objects. The returned
"classification map" should follow the format described by
``smqtk.representation.ClassificationElement``: dictionary where
keys are classification labels and values are classification
probabilities.
:rtype:
collections.Iterator[(smqtk.representation.AxisAlignedBoundingBox,
dict[collections.Hashable, float])]
"""