Module fontai.prediction.custom_filters
This module contains filtering functions for Tensorflow dataset operations that are applied to model inputs right after being deserialised to be used at training time
Expand source code
"""
This module contains filtering functions for Tensorflow dataset operations that are applied to model inputs right after being deserialised to be used at training time
"""
import tensorflow as tf
import typing as t
__all__ = [
"filter_misclassified_chars",
"filter_chars_by_score",
"filter_fonts_by_size",
"filter_irregular_fonts",
"filter_by_name"
]
def filter_misclassified_chars():
"""Returns a filtering function for Tensorflow datasets that filter out misclassified examples; examples must have the schema as in ScoredLabeledChars._tfr_schema
Returns:
t.Callable: Filtering function for Tensorflow datasets
"""
def f(kwargs):
"""
Args:
kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries.
Returns:
boolean
"""
classification_index = tf.argmax(kwargs["score"], axis=-1)
return kwargs["label"] == kwargs["charset_tensor"][classification_index]
return f
def filter_chars_by_score(threshold: float):
"""Returns a Filtering function for Tensorflow datasets that filter out scores lower than a given threshold.
Returns:
t.Callable: Filtering function for Tensorflow datasets
"""
if threshold > 1 or threshold <= 0:
raise ValueError("Threshold value must be in (0,1]")
def f(kwargs):
"""
Args:
kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries.
Returns:
boolean
"""
return tf.reduce_max(kwargs["score"],axis=-1) >= threshold
return f
def filter_fonts_by_size(n_chars: int):
"""Returns a Filtering function for Tensorflow datasets that filter out fonts with too few remaining characters.
Returns:
t.Callable: Filtering function for Tensorflow datasets
"""
if n_chars < 0:
raise ValueError("n_chars must be non-negative")
def f(kwargs):
"""
Args:
kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries.
Returns:
boolean
"""
return tf.size(kwargs["label"]) >= n_chars
#return tf.logical_and(tf.size(kwargs["features"]) > 0, kwargs["features"].shape[0] >= n)
return f
def filter_irregular_fonts(min_score: int):
"""Returns a Filtering function for Tensorflow datasets that filter out fonts with misclassified characters or low-confidence scores.
Args:
min_score (int): minimum score threshold
Returns:
t.Callable: Filtering function for Tensorflow datasets
"""
def f(kwargs):
"""
Args:
kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries.
Returns:
boolean
"""
predicted_label_idx = tf.argmax(kwargs["score"], axis=-1)
predicted_labels = tf.gather(kwargs["charset_tensor"], predicted_label_idx, axis=-1)
predicted_scores = tf.reduce_max(kwargs["score"], axis=-1)
return tf.math.logical_and(tf.math.reduce_all(predicted_scores >= min_score), tf.math.reduce_all(predicted_labels == kwargs["label"]))
return f
def filter_by_name(substring: str):
"""Returns a Filtering function for Tensorflow datasets that filter out fonts whose name does not contain the provided substring, e.g. italic, 3d, etc.
Args:
substring (str): substring that will be searched for in the lowercased font names
Returns:
t.Callable: Filtering function for Tensorflow datasets
"""
lower_substring = substring.lower()
def f(kwargs):
"""
Args:
kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries.
Returns:
boolean
"""
return tf.strings.regex_full_match(tf.strings.lower(kwargs["fontname"]), f".*{lower_substring}.*")
return f
Functions
def filter_by_name(substring: str)
-
Returns a Filtering function for Tensorflow datasets that filter out fonts whose name does not contain the provided substring, e.g. italic, 3d, etc.
Args
substring
:str
- substring that will be searched for in the lowercased font names
Returns
t.Callable
- Filtering function for Tensorflow datasets
Expand source code
def filter_by_name(substring: str): """Returns a Filtering function for Tensorflow datasets that filter out fonts whose name does not contain the provided substring, e.g. italic, 3d, etc. Args: substring (str): substring that will be searched for in the lowercased font names Returns: t.Callable: Filtering function for Tensorflow datasets """ lower_substring = substring.lower() def f(kwargs): """ Args: kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries. Returns: boolean """ return tf.strings.regex_full_match(tf.strings.lower(kwargs["fontname"]), f".*{lower_substring}.*") return f
def filter_chars_by_score(threshold: float)
-
Returns a Filtering function for Tensorflow datasets that filter out scores lower than a given threshold.
Returns
t.Callable
- Filtering function for Tensorflow datasets
Expand source code
def filter_chars_by_score(threshold: float): """Returns a Filtering function for Tensorflow datasets that filter out scores lower than a given threshold. Returns: t.Callable: Filtering function for Tensorflow datasets """ if threshold > 1 or threshold <= 0: raise ValueError("Threshold value must be in (0,1]") def f(kwargs): """ Args: kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries. Returns: boolean """ return tf.reduce_max(kwargs["score"],axis=-1) >= threshold return f
def filter_fonts_by_size(n_chars: int)
-
Returns a Filtering function for Tensorflow datasets that filter out fonts with too few remaining characters.
Returns
t.Callable
- Filtering function for Tensorflow datasets
Expand source code
def filter_fonts_by_size(n_chars: int): """Returns a Filtering function for Tensorflow datasets that filter out fonts with too few remaining characters. Returns: t.Callable: Filtering function for Tensorflow datasets """ if n_chars < 0: raise ValueError("n_chars must be non-negative") def f(kwargs): """ Args: kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries. Returns: boolean """ return tf.size(kwargs["label"]) >= n_chars #return tf.logical_and(tf.size(kwargs["features"]) > 0, kwargs["features"].shape[0] >= n) return f
def filter_irregular_fonts(min_score: int)
-
Returns a Filtering function for Tensorflow datasets that filter out fonts with misclassified characters or low-confidence scores.
Args
min_score
:int
- minimum score threshold
Returns
t.Callable
- Filtering function for Tensorflow datasets
Expand source code
def filter_irregular_fonts(min_score: int): """Returns a Filtering function for Tensorflow datasets that filter out fonts with misclassified characters or low-confidence scores. Args: min_score (int): minimum score threshold Returns: t.Callable: Filtering function for Tensorflow datasets """ def f(kwargs): """ Args: kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries. Returns: boolean """ predicted_label_idx = tf.argmax(kwargs["score"], axis=-1) predicted_labels = tf.gather(kwargs["charset_tensor"], predicted_label_idx, axis=-1) predicted_scores = tf.reduce_max(kwargs["score"], axis=-1) return tf.math.logical_and(tf.math.reduce_all(predicted_scores >= min_score), tf.math.reduce_all(predicted_labels == kwargs["label"])) return f
def filter_misclassified_chars()
-
Returns a filtering function for Tensorflow datasets that filter out misclassified examples; examples must have the schema as in ScoredLabeledChars._tfr_schema
Returns
t.Callable
- Filtering function for Tensorflow datasets
Expand source code
def filter_misclassified_chars(): """Returns a filtering function for Tensorflow datasets that filter out misclassified examples; examples must have the schema as in ScoredLabeledChars._tfr_schema Returns: t.Callable: Filtering function for Tensorflow datasets """ def f(kwargs): """ Args: kwargs (t.Dict): a dictionary with every object parsed from a serialised Tensorflow example, including "features" and "label" entries. Returns: boolean """ classification_index = tf.argmax(kwargs["score"], axis=-1) return kwargs["label"] == kwargs["charset_tensor"][classification_index] return f