Module fontai.prediction.input_processing

This module contains the input preprocessing logic applied immediately before data is ingested by the model being trained.

Expand source code
"""
This module contains the input preprocessing logic applied immediately before data is ingested by the model being trained.
"""
from __future__ import absolute_import
from collections.abc import Iterable
import os
import logging
import zipfile
import io
import typing as t
import types
from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np
import imageio
import tensorflow as tf
from tensorflow.python.data.ops.dataset_ops import MapDataset

from tensorflow.data import TFRecordDataset

import fontai.prediction.models as custom_models

logger = logging.getLogger(__name__)

class RecordPreprocessor(object):
  """
  Fetches and processes a list of Tensorflow record files to be consumed by an ML model
  
  Attributes:
      charset (char): string with every character that needs to be extracted
      charset_tensor (tf.Tensor): charset in tensor form
      custom_filters (t.List[types.Callable]): List of custom_filters to apply to training data
      num_classes (int): number of classes in charset
  """

  def __init__(
    self, 
    input_record_class: type,
    charset_tensor: tf.Tensor,
    custom_filters: t.List[t.Callable] = [],
    custom_mappers: t.List[t.Callable] = []):
    """
    Args:
        input_record_class (type): Subclass of `fontai.io.records.TfrWritable` that corresponds to the schema of the input records
        charset_tensor (tf.Tensor): Tensor with one entry per character in the charset under consideration
        custom_filters (t.List[t.Callable], optional): Filtering functions for sets of image tensors and one-hot-encoded labels
        custom_mappers (t.List[t.Callable], optional): Mapping functions for sets of image tensors and one-hot-encoded labels
    
    """

    self.input_record_class = input_record_class

    self.custom_filters = custom_filters

    self.custom_mappers = custom_mappers

    self.charset_tensor = tf.convert_to_tensor(charset_tensor)


  def fetch(self, dataset: TFRecordDataset, batch_size = 32, training_format=True, buffered_batches = 512, cyclic=True):
    """
    Prepares a dataset of serialized Tensorflow records for training or scoring
    
    Args:
        dataset (TFRecordDataset): input data
        batch_size (int): training batch size
        training_format (bool, optional): If True, returns features and a one hot encoded label; otherwise, returns a dict of parsed bytestreams with labels as bytes
        buffered_batches (int, optional): Size of in-memory buffer from which batches are taken
        cyclic (bool, optional): Whether to cycle over the data indefinitely
    
    Returns:
        TFRecordDataset: Dataset ready for model consumption
    """

    # bytes -> dict -> tuple of objs
    dataset = dataset\
      .map(self.input_record_class.from_tf_example)\
      .map(self.input_record_class.parse_bytes_dict)
        
    # if for training, take only features and formatted labels, and batch together
    if training_format:

      # apply custom filters to formatted tuples
      for example_filter in self.custom_filters:
          dataset = dataset.filter(example_filter)
          
      # apply custom map to formatted tuples
      for example_mapper in self.custom_mappers:
          dataset = dataset.map(example_mapper)

      dataset = dataset\
        .map(self.input_record_class.get_training_parser(charset_tensor = self.charset_tensor))\
        .filter(self.label_is_nonempty) # empty labels signal something went wrong while parsing

      dataset = self.scramble(dataset, batch_size, buffered_batches, cyclic)

      if batch_size is not None:
        dataset = dataset.batch(batch_size) 

      dataset = self.add_batch_shape_signature(dataset)

    else:
      dataset = self.input_record_class.filter_charset_for_scoring(dataset, self.charset_tensor)
      # split record dictionary for batching and filter out empty examples
      dataset = dataset\
        .map(self.split_parsed_dict)\
        .filter(self.label_is_nonempty)

      unbatchable = self.input_record_class._nonbatched_scoring

      if unbatchable:
        logger.warning(f"records of class {self.input_record_class.__name__} aren't batchable at scoring time; setting batch size to None.")

      if batch_size is not None and not unbatchable:
        dataset = dataset.batch(batch_size) 
    
    return dataset

  def scramble(self, dataset, batch_size, buffered_batches = 512, cyclic=True):
    """
    Scrambles a data set randomly and makes it unbounded in order to process an arbitrary number of batches
    
    Args:
        dataset (TFRecordDataset): Input dataset
        batch_size (int): training batch size
        buffered_batches (int, optional): Number of batches to fetch in memory buffer
    
    Returns:
        TFRecordDataset
    """

    buffer_size = buffered_batches*batch_size if batch_size is not None else 2048
    dataset = dataset.shuffle(buffer_size=buffer_size)

    if cyclic:
      dataset = dataset.repeat()

    return dataset

  def label_is_nonempty(self, features, label, *args):
    """
    Filters out training examples without rows or incorrectly formatted labels
    
    Args:
        features (tf.Tensor)
        labels (tf.Tensor)
        args: other arguments
    
    Returns:
        Tensor
    """
    return tf.math.logical_not(tf.equal(tf.size(label), 0))

  def add_batch_shape_signature(self, data: TFRecordDataset) -> TFRecordDataset:
    """Intermediate method required to make training data shapes known at graph compile time. Returns the passed data wrapped in a callable object with explicit output shape signatures
    
    Args:
        data (TFRecordDataset): Input training data
    
    Returns:
        TFRecordDataset
    
    Raises:
        ValueError
    """
    def callable_data():
      return data

    features, labels = next(iter(data))
    # drop batch size from shape tuples
    ftr_shape = features.shape[1:]
    lbl_shape = labels.shape[1:]

    training_data = tf.data.Dataset.from_generator(
      callable_data, 
      output_types = (
        features.dtype, 
        labels.dtype
      ),
      output_shapes=(
        tf.TensorShape((None,) + ftr_shape),
        tf.TensorShape((None,) + lbl_shape)
      )
    )

    return training_data

  def split_parsed_dict(self, parsed_dict: t.Dict):
    """Split a parsed record dictionary into features, labels and fontname
    
    Args:
        parsed_dict (t.Dict): parsed record dictionary
    
    Returns:
        t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
    """
    return parsed_dict["features"], parsed_dict["label"], parsed_dict["fontname"]

Classes

class RecordPreprocessor (input_record_class: type, charset_tensor: tensorflow.python.framework.ops.Tensor, custom_filters: List[Callable] = [], custom_mappers: List[Callable] = [])

Processes a dataset of serialized Tensorflow records into a form consumable by an ML model

Attributes

input_record_class : type
Subclass of TfrWritable that matches the schema of the input records
charset_tensor : tf.Tensor
Tensor with one entry per character in the charset
custom_filters : t.List[t.Callable]
Filtering functions applied to training examples
custom_mappers : t.List[t.Callable]
Mapping functions applied to training examples

Args

input_record_class : type
Subclass of TfrWritable that corresponds to the schema of the input records
charset_tensor : tf.Tensor
Tensor with one entry per character in the charset under consideration
custom_filters : t.List[t.Callable], optional
Filtering functions for sets of image tensors and one-hot-encoded labels
custom_mappers : t.List[t.Callable], optional
Mapping functions for sets of image tensors and one-hot-encoded labels
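
A hedged construction sketch: `LabeledChar` below is a hypothetical placeholder for whichever concrete `fontai.io.records.TfrWritable` subclass your records were written with.

import string

import tensorflow as tf

from fontai.prediction.input_processing import RecordPreprocessor
from fontai.io.records import LabeledChar  # hypothetical record class; substitute your own

# one entry per character the model should recognize
charset = list(string.ascii_lowercase)

preprocessor = RecordPreprocessor(
  input_record_class=LabeledChar,
  charset_tensor=tf.convert_to_tensor(charset))
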
Expand source code
class RecordPreprocessor(object):
  """
  Fetches and processes a list of Tensorflow record files to be consumed by an ML model
  
  Attributes:
      charset (char): string with every character that needs to be extracted
      charset_tensor (tf.Tensor): charset in tensor form
      custom_filters (t.List[types.Callable]): List of custom_filters to apply to training data
      num_classes (int): number of classes in charset
  """

  def __init__(
    self, 
    input_record_class: type,
    charset_tensor: tf.Tensor,
    custom_filters: t.List[t.Callable] = [],
    custom_mappers: t.List[t.Callable] = []):
    """
    Args:
        input_record_class (type): Subclass of `fontai.io.records.TfrWritable` that corresponds to the schema of the input records
        charset_tensor (tf.Tensor): Tensor with one entry per character in the charset under consideration
        custom_filters (t.List[t.Callable], optional): Filtering functions for sets of image tensors and one-hot-encoded labels
        custom_mappers (t.List[t.Callable], optional): Mapping functions for sets of image tensors and one-hot-encoded labels
    
    """

    self.input_record_class = input_record_class

    self.custom_filters = custom_filters

    self.custom_mappers = custom_mappers

    self.charset_tensor = tf.convert_to_tensor(charset_tensor)


  def fetch(self, dataset: TFRecordDataset, batch_size = 32, training_format=True, buffered_batches = 512, cyclic=True):
    """
    Prepares a dataset of serialized Tensorflow records for training or scoring
    
    Args:
        dataset (TFRecordDataset): input data
        batch_size (int): training batch size
        training_format (bool, optional): If True, returns features and a one hot encoded label; otherwise, returns a dict of parsed bytestreams with labels as bytes
        buffered_batches (int, optional): Size of in-memory buffer from which batches are taken
        cyclic (bool, optional): Whether to cycle over the data indefinitely
    
    Returns:
        TFRecordDataset: Dataset ready for model consumption
    """

    # bytes -> dict -> tuple of objs
    dataset = dataset\
      .map(self.input_record_class.from_tf_example)\
      .map(self.input_record_class.parse_bytes_dict)
        
    # if for training, take only features and formatted labels, and batch together
    if training_format:

      # apply custom filters to formatted tuples
      for example_filter in self.custom_filters:
          dataset = dataset.filter(example_filter)
          
      # apply custom map to formatted tuples
      for example_mapper in self.custom_mappers:
          dataset = dataset.map(example_mapper)

      dataset = dataset\
        .map(self.input_record_class.get_training_parser(charset_tensor = self.charset_tensor))\
        .filter(self.label_is_nonempty) # empty labels signal something went wrong while parsing

      dataset = self.scramble(dataset, batch_size, buffered_batches, cyclic)

      if batch_size is not None:
        dataset = dataset.batch(batch_size) 

      dataset = self.add_batch_shape_signature(dataset)

    else:
      dataset = self.input_record_class.filter_charset_for_scoring(dataset, self.charset_tensor)
      # split record dictionary for batching and filter out empty examples
      dataset = dataset\
        .map(self.split_parsed_dict)\
        .filter(self.label_is_nonempty)

      unbatchable = self.input_record_class._nonbatched_scoring

      if unbatchable:
        logger.warning(f"records of class {self.input_record_class.__name__} aren't batchable at scoring time; setting batch size to None.")

      if batch_size is not None and not unbatchable:
        dataset = dataset.batch(batch_size) 
    
    return dataset

  def scramble(self, dataset, batch_size, buffered_batches = 512, cyclic=True):
    """
    Scrambles a data set randomly and makes it unbounded in order to process an arbitrary number of batches
    
    Args:
        dataset (TFRecordDataset): Input dataset
        batch_size (int): training batch size
        buffered_batches (int, optional): Number of batches to fetch in memory buffer
    
    Returns:
        TFRecordDataset
    """

    buffer_size = buffered_batches*batch_size if batch_size is not None else 2048
    dataset = dataset.shuffle(buffer_size=buffer_size)

    if cyclic:
      dataset = dataset.repeat()

    return dataset

  def label_is_nonempty(self, features, label, *args):
    """
    Filters out training examples without rows or incorrectly formatted labels
    
    Args:
        features (tf.Tensor)
        labels (tf.Tensor)
        args: other arguments
    
    Returns:
        Tensor
    """
    return tf.math.logical_not(tf.equal(tf.size(label), 0))

  def add_batch_shape_signature(self, data: TFRecordDataset) -> TFRecordDataset:
    """Intermediate method required to make training data shapes known at graph compile time. Returns the passed data wrapped in a callable object with explicit output shape signatures
    
    Args:
        data (TFRecordDataset): Input training data
    
    Returns:
        TFRecordDataset
    
    Raises:
        ValueError
    """
    def callable_data():
      return data

    features, labels = next(iter(data))
    # drop batch size from shape tuples
    ftr_shape = features.shape[1:]
    lbl_shape = labels.shape[1:]

    training_data = tf.data.Dataset.from_generator(
      callable_data, 
      output_types = (
        features.dtype, 
        labels.dtype
      ),
      output_shapes=(
        tf.TensorShape((None,) + ftr_shape),
        tf.TensorShape((None,) + lbl_shape)
      )
    )

    return training_data

  def split_parsed_dict(self, parsed_dict: t.Dict):
    """Split a parsed record dictionary into features, labels and fontname
    
    Args:
        parsed_dict (t.Dict): parsed record dictionary
    
    Returns:
        t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
    """
    return parsed_dict["features"], parsed_dict["label"], parsed_dict["fontname"]

Methods

def add_batch_shape_signature(self, data: tensorflow.python.data.ops.readers.TFRecordDatasetV2) ‑> tensorflow.python.data.ops.readers.TFRecordDatasetV2

Makes training data shapes known at graph compile time by returning the passed dataset re-wrapped as a generator-backed dataset with explicit output shape signatures

Args

data : TFRecordDataset
Input training data

Returns

TFRecordDataset


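A minimal standalone sketch of the underlying technique, using a toy in-memory dataset rather than this module's record classes:

import tensorflow as tf

# toy batched dataset: 8 grayscale 64x64 images with 26-class one-hot labels
batched = tf.data.Dataset.from_tensor_slices(
  (tf.zeros((8, 64, 64, 1)), tf.zeros((8, 26)))).batch(4)

features, labels = next(iter(batched))

# re-wrap with explicit shapes; the batch dimension is left as None
wrapped = tf.data.Dataset.from_generator(
  lambda: batched,
  output_types=(features.dtype, labels.dtype),
  output_shapes=(
    tf.TensorShape((None,) + tuple(features.shape[1:])),
    tf.TensorShape((None,) + tuple(labels.shape[1:]))))

print(wrapped.element_spec)  # shapes are now known at graph compile time
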
Expand source code
def add_batch_shape_signature(self, data: TFRecordDataset) -> TFRecordDataset:
  """Intermediate method required to make training data shapes known at graph compile time. Returns the passed data wrapped in a callable object with explicit output shape signatures
  
  Args:
      data (TFRecordDataset): Input training data
  
  Returns:
      TFRecordDataset
  
  Raises:
      ValueError
  """
  def callable_data():
    return data

  features, labels = next(iter(data))
  # drop batch size form shape tuples
  ftr_shape = features.shape[1::]
  lbl_shape = labels.shape[1::]

  # if len(ftr_shape) != 3 or len(lbl_shape) != 1:
  #   raise ValueError(f"Input shapes don't match expected: got shapes {features.shape} and {labels.shape}")

  training_data = tf.data.Dataset.from_generator(
    callable_data, 
    output_types = (
      features.dtype, 
      labels.dtype
    ),
    output_shapes=(
      tf.TensorShape((None,) + ftr_shape),
      tf.TensorShape((None,) + lbl_shape)
    )
  )

  return training_data
def fetch(self, dataset: tensorflow.python.data.ops.readers.TFRecordDatasetV2, batch_size=32, training_format=True, buffered_batches=512, cyclic=True)

Prepares a dataset of serialized Tensorflow records for training or scoring

Args

dataset : TFRecordDataset
input data
batch_size : int
training batch size
training_format : bool, optional
If True, returns features and a one hot encoded label; otherwise, returns a dict of parsed bytestreams with labels as bytes
buffered_batches : int, optional
Size of in-memory buffer from which batches are taken
cyclic : bool, optional
Whether to cycle over the data indefinitely

Returns

TFRecordDataset
Dataset ready for model consumption
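
A hedged usage sketch, assuming a preprocessor built as in the constructor sketch above; the TFRecord file names are placeholders.

import tensorflow as tf

raw = tf.data.TFRecordDataset(["train-00000.tfr", "train-00001.tfr"])  # placeholder paths

training_data = preprocessor.fetch(
  raw,
  batch_size=32,
  training_format=True,  # yields (features, one-hot label) batches
  cyclic=True)           # repeats forever; bound training with steps_per_epoch

# model.fit(training_data, steps_per_epoch=1000, epochs=10)
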
Expand source code
def fetch(self, dataset: TFRecordDataset, batch_size = 32, training_format=True, buffered_batches = 512, cyclic=True):
  """
  Prepares a dataset of serialized Tensorflow records for training or scoring
  
  Args:
      dataset (TFRecordDataset): input data
      batch_size (int): training batch size
      training_format (bool, optional): If True, returns features and a one hot encoded label; otherwise, returns a dict of parsed bytestreams with labels as bytes
      buffered_batches (int, optional): Size of in-memory buffer from which batches are taken
      cyclic (bool, optional): Whether to cycle over the data indefinitely
  
  Returns:
      TFRecordDataset: Dataset ready for model consumption
  """

  # bytes -> dict -> tuple of objs
  dataset = dataset\
    .map(self.input_record_class.from_tf_example)\
    .map(self.input_record_class.parse_bytes_dict)
      
  # if for training, take only features and formatted labels, and batch together
  if training_format:

    # apply custom filters to formatted tuples
    for example_filter in self.custom_filters:
        dataset = dataset.filter(example_filter)
        
    # apply custom map to formatted tuples
    for example_mapper in self.custom_mappers:
        dataset = dataset.map(example_mapper)

    dataset = dataset\
      .map(self.input_record_class.get_training_parser(charset_tensor = self.charset_tensor))\
      .filter(self.label_is_nonempty) # empty labels signal something went wrong while parsing

    dataset = self.scramble(dataset, batch_size, buffered_batches, cyclic)

    if batch_size is not None:
      dataset = dataset.batch(batch_size) 

    dataset = self.add_batch_shape_signature(dataset)

  else:
    dataset = self.input_record_class.filter_charset_for_scoring(dataset, self.charset_tensor)
    # split record dictionary for batching and filter out empty examples
    dataset = dataset\
      .map(self.split_parsed_dict)\
      .filter(self.label_is_nonempty)

    unbatchable = self.input_record_class._nonbatched_scoring

    if unbatchable:
      logger.warning(f"records of class {self.input_record_class.__name__} aren't batchable at scoring time; setting batch size to None.")

    if batch_size is not None and not unbatchable:
      dataset = dataset.batch(batch_size) 
  
  return dataset
def label_is_nonempty(self, features, label, *args)

Predicate that is False for examples with an empty label tensor, which signals a parsing failure

Args

features : tf.Tensor
label : tf.Tensor
args
other arguments

Returns

tf.Tensor
scalar boolean tensor, True when the label is non-empty

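The predicate reduces to a single size check; an equivalent standalone sketch:

import tensorflow as tf

def label_is_nonempty(features, label, *args):
  # an empty label tensor means parsing failed for this example
  return tf.math.logical_not(tf.equal(tf.size(label), 0))

print(label_is_nonempty(None, tf.constant([0.0, 1.0])))            # True
print(label_is_nonempty(None, tf.constant([], dtype=tf.float32)))  # False
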
Expand source code
def label_is_nonempty(self, features, label, *args):
  """
  Filters out training examples without rows or incorrectly formatted labels
  
  Args:
      features (tf.Tensor)
      labels (tf.Tensor)
      args: other arguments
  
  Returns:
      Tensor
  """
  return tf.math.logical_not(tf.equal(tf.size(label), 0))
def scramble(self, dataset, batch_size, buffered_batches=512, cyclic=True)

Shuffles a dataset and optionally repeats it indefinitely so that an arbitrary number of batches can be drawn

Args

dataset : TFRecordDataset
Input dataset
batch_size : int
training batch size
buffered_batches : int, optional
Number of batches to hold in the in-memory shuffle buffer
cyclic : bool, optional
Whether to repeat the dataset indefinitely

Returns

TFRecordDataset

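A minimal sketch of the shuffle-and-repeat behavior on a toy dataset, with the same buffer arithmetic as the method body:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds = ds.shuffle(buffer_size=512 * 4)  # buffered_batches * batch_size
ds = ds.repeat()                      # cyclic=True makes iteration unbounded

for x in ds.take(5):
  print(x.numpy())
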
Expand source code
def scramble(self, dataset, batch_size, buffered_batches = 512, cyclic=True):
  """
  Scrambles a data set randomly and makes it unbounded in order to process an arbitrary number of batches
  
  Args:
      dataset (TFRecordDataset): Input dataset
      batch_size (int): training batch size
      buffered_batches (int, optional): Number of batches to fetch in memory buffer
  
  Returns:
      TFRecordDataset
  """

  buffer_size = buffered_batches*batch_size if batch_size is not None else 2048
  dataset = dataset.shuffle(buffer_size=buffer_size)

  if cyclic:
    dataset = dataset.repeat()

  return dataset
def split_parsed_dict(self, parsed_dict: Dict[~KT, ~VT])

Split a parsed record dictionary into features, labels and fontname

Args

parsed_dict : t.Dict
parsed record dictionary

Returns

t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

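The method is a plain dict-to-tuple projection; a toy example with made-up tensor values:

import tensorflow as tf

parsed = {
  "features": tf.zeros((64, 64, 1)),
  "label": tf.constant("a"),
  "fontname": tf.constant("ExampleFont")}  # hypothetical values

features, label, fontname = (
  parsed["features"], parsed["label"], parsed["fontname"])
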
Expand source code
def split_parsed_dict(self, parsed_dict: t.Dict):
  """Split a parsed record dictionary into features, labels and fontname
  
  Args:
      parsed_dict (t.Dict): parsed record dictionary
  
  Returns:
      t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
  """
  return parsed_dict["features"], parsed_dict["label"], parsed_dict["fontname"]