Module fontai.prediction.custom_mappers
This module contains mapper functions that are applied to Tensorflow examples right after they are deserialised, for use at training time. At the moment, the available functions filter scored font records based on the accuracy and the values of their scores, and binarise pixel values.
Expand source code
"""
This module contains mapper functions that are applied to Tensorflow examples right after they are deserialised, for use at training time. At the moment, the available functions filter scored font records based on the accuracy and the values of their scores, and binarise pixel values.
"""
import tensorflow as tf
import typing as t
__all__ = ["drop_misclassified_in_font",
"keep_high_scores_in_font",
"map_to_binary_pixels"]
def drop_misclassified_in_font():
    """Builds a mapper for Tensorflow datasets that removes the misclassified images of a font batch.

    Parsed examples are expected to follow the schema in ScoredLabeledChars._tfr_schema.

    Returns:
        t.Callable: Mapping function for Tensorflow datasets
    """
    def f(kwargs: t.Dict):
        """Keeps only the rows whose highest-scoring class maps to the row's own label.

        Args:
            kwargs (t.Dict): a dictionary with every object parsed from a serialised
                Tensorflow example; the "score", "charset_tensor", "label" and
                "features" entries are read.

        Returns:
            t.Dict: the same dictionary. When at least one prediction matches, "label",
                "features" and "score" are restricted to the matching rows; otherwise
                only "label" is replaced, by an empty string tensor.
        """
        # NOTE: there is deliberately no early exit for an already-empty label here;
        # that guard was disabled in this mapper and stays disabled.
        best_class = tf.argmax(kwargs["score"], axis=-1)
        # translate class positions into labels through the charset lookup tensor
        best_label = tf.gather(kwargs["charset_tensor"], best_class, axis=-1)
        # flattened boolean mask, True where the predicted label equals the given label
        is_correct = tf.reshape(best_label == kwargs["label"], (-1,))
        n_correct = tf.reduce_sum(tf.cast(is_correct, dtype=tf.int32))

        if tf.equal(n_correct, 0):
            # no accurate predictions left: pass an empty label for downstream deletion
            kwargs["label"] = tf.zeros((0,), dtype=tf.string)
        else:
            kwargs["label"] = kwargs["label"][is_correct]
            kwargs["features"] = kwargs["features"][is_correct]
            kwargs["score"] = kwargs["score"][is_correct]
        return kwargs

    return f
def keep_high_scores_in_font(threshold: float):
    """Returns a function for Tensorflow datasets that drops images with low classification score in a font batch

    Args:
        threshold (float): Score threshold; must lie in (0,1]

    Returns:
        t.Callable: Mapping function for Tensorflow datasets

    Raises:
        ValueError: if threshold is not in (0,1]; this includes NaN
    """
    # Written as a chained comparison so that NaN, for which every comparison
    # is False, is rejected too instead of producing a mapper that empties every example.
    if not 0 < threshold <= 1:
        raise ValueError("Threshold value must be in (0,1]")

    def f(kwargs: t.Dict):
        """Keeps only the rows whose maximum score reaches the threshold.

        Args:
            kwargs (t.Dict): a dictionary with every object parsed from a serialised
                Tensorflow example; the "label", "score" and "features" entries are read.

        Returns:
            t.Dict: the same dictionary. When at least one row reaches the threshold,
                "label", "features" and "score" are restricted to those rows; otherwise
                only "label" is replaced, by an empty string tensor.
        """
        # if label is empty, do nothing
        if tf.equal(tf.size(kwargs["label"]), 0):
            return kwargs

        # flattened boolean mask, True where the highest score of a row is >= threshold
        index = tf.reshape(tf.reduce_max(kwargs["score"], axis=-1) >= threshold, (-1,))
        if tf.equal(tf.reduce_sum(tf.cast(index, dtype=tf.int32)), 0):
            # if no high scores are left, pass empty label for downstream deletion
            kwargs["label"] = tf.zeros((0,), dtype=tf.string)
        else:
            kwargs["label"] = kwargs["label"][index]
            kwargs["features"] = kwargs["features"][index]
            kwargs["score"] = kwargs["score"][index]
        return kwargs

    return f
def map_to_binary_pixels():
    """Builds a mapping function that normalises pixels in [0,1] to either 0 or 1

    Returns:
        t.Callable: Mapping function for Tensorflow datasets
    """
    def f(kwargs):
        """Rounds the "features" entry of a parsed example.

        Args:
            kwargs (t.Dict): a dictionary with every object parsed from a serialised
                Tensorflow example; only the "features" entry is read and overwritten.

        Returns:
            t.Dict: the same dictionary, with "features" passed through tf.math.round
        """
        binarised = tf.math.round(kwargs["features"])
        kwargs["features"] = binarised
        return kwargs

    return f
# def font_chars_to_channels():
# """Reshapes and flips the feature tensor so that font characters become image channels.
# """
# def f(kwargs):
# """
# Args:
# kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries.
# Returns:
# t.Dict: dictionary with mapped features and scores
# """
# features_shape = tf.shape(kwargs["features"])
# n_chars = features_shape[0]
# height = features_shape[1]
# width = features_shape[2]
# kwargs["features"] = tf.transpose(kwargs["features"])
# #print((n_chars, height, width))
# #kwargs["features"] = tf.reshape(kwargs["features"], (n_chars, height, width))
# return kwargs
# return f
Functions
def drop_misclassified_in_font()
-
Returns a mapper function for Tensorflow datasets that drops misclassified images in a font batch; examples must have the schema as in ScoredLabeledChars._tfr_schema
Returns
t.Callable
- Mapping function for Tensorflow datasets
Expand source code
def drop_misclassified_in_font():
    """Builds a mapper for Tensorflow datasets that removes the misclassified images of a font batch.

    Parsed examples are expected to follow the schema in ScoredLabeledChars._tfr_schema.

    Returns:
        t.Callable: Mapping function for Tensorflow datasets
    """
    def f(kwargs: t.Dict):
        """Keeps only the rows whose highest-scoring class maps to the row's own label.

        Args:
            kwargs (t.Dict): a dictionary with every object parsed from a serialised
                Tensorflow example; the "score", "charset_tensor", "label" and
                "features" entries are read.

        Returns:
            t.Dict: the same dictionary. When at least one prediction matches, "label",
                "features" and "score" are restricted to the matching rows; otherwise
                only "label" is replaced, by an empty string tensor.
        """
        # NOTE: there is deliberately no early exit for an already-empty label here;
        # that guard was disabled in this mapper and stays disabled.
        best_class = tf.argmax(kwargs["score"], axis=-1)
        # translate class positions into labels through the charset lookup tensor
        best_label = tf.gather(kwargs["charset_tensor"], best_class, axis=-1)
        # flattened boolean mask, True where the predicted label equals the given label
        is_correct = tf.reshape(best_label == kwargs["label"], (-1,))
        n_correct = tf.reduce_sum(tf.cast(is_correct, dtype=tf.int32))

        if tf.equal(n_correct, 0):
            # no accurate predictions left: pass an empty label for downstream deletion
            kwargs["label"] = tf.zeros((0,), dtype=tf.string)
        else:
            kwargs["label"] = kwargs["label"][is_correct]
            kwargs["features"] = kwargs["features"][is_correct]
            kwargs["score"] = kwargs["score"][is_correct]
        return kwargs

    return f
def keep_high_scores_in_font(threshold: float)
-
Returns a function for Tensorflow datasets that drops images with low classification score in a font batch
Args
threshold
:float
- Score threshold
Returns
t.Callable
- Mapping function for Tensorflow datasets
Expand source code
def keep_high_scores_in_font(threshold: float):
    """Returns a function for Tensorflow datasets that drops images with low classification score in a font batch

    Args:
        threshold (float): Score threshold; must lie in (0,1]

    Returns:
        t.Callable: Mapping function for Tensorflow datasets

    Raises:
        ValueError: if threshold is not in (0,1]; this includes NaN
    """
    # Written as a chained comparison so that NaN, for which every comparison
    # is False, is rejected too instead of producing a mapper that empties every example.
    if not 0 < threshold <= 1:
        raise ValueError("Threshold value must be in (0,1]")

    def f(kwargs: t.Dict):
        """Keeps only the rows whose maximum score reaches the threshold.

        Args:
            kwargs (t.Dict): a dictionary with every object parsed from a serialised
                Tensorflow example; the "label", "score" and "features" entries are read.

        Returns:
            t.Dict: the same dictionary. When at least one row reaches the threshold,
                "label", "features" and "score" are restricted to those rows; otherwise
                only "label" is replaced, by an empty string tensor.
        """
        # if label is empty, do nothing
        if tf.equal(tf.size(kwargs["label"]), 0):
            return kwargs

        # flattened boolean mask, True where the highest score of a row is >= threshold
        index = tf.reshape(tf.reduce_max(kwargs["score"], axis=-1) >= threshold, (-1,))
        if tf.equal(tf.reduce_sum(tf.cast(index, dtype=tf.int32)), 0):
            # if no high scores are left, pass empty label for downstream deletion
            kwargs["label"] = tf.zeros((0,), dtype=tf.string)
        else:
            kwargs["label"] = kwargs["label"][index]
            kwargs["features"] = kwargs["features"][index]
            kwargs["score"] = kwargs["score"][index]
        return kwargs

    return f
def map_to_binary_pixels()
-
Returns a mapping function to normalise pixels in [0,1] to either 0 or 1
Expand source code
def map_to_binary_pixels():
    """Builds a mapping function that normalises pixels in [0,1] to either 0 or 1

    Returns:
        t.Callable: Mapping function for Tensorflow datasets
    """
    def f(kwargs):
        """Rounds the "features" entry of a parsed example.

        Args:
            kwargs (t.Dict): a dictionary with every object parsed from a serialised
                Tensorflow example; only the "features" entry is read and overwritten.

        Returns:
            t.Dict: the same dictionary, with "features" passed through tf.math.round
        """
        binarised = tf.math.round(kwargs["features"])
        kwargs["features"] = binarised
        return kwargs

    return f