Module fontai.prediction.models

Source code
import sys
import re
import io
import tensorflow as tf
import json
import logging
import copy
import matplotlib.pyplot as plt
import numpy as np

from fontai.io.storage import BytestreamPath

logger = logging.getLogger(__name__)

class CharStyleSAAE(tf.keras.Model):

  """This class fits a supervised adversarial autoencoder and its inspired in the architecture from "Adversarial Autoencoders" by Ian Goodfellow et al. The only difference is that label (i.e. character) information is injected between the encoder and the style embedding, in the hope that labels not only help the decoding but also the encoding process, e.g. curvyness shouldn't be as important in the input if its a C than if its an H.
  
  Attributes:
      cross_entropy (tf.keras.losses.BinaryCrossentropy): Binary cross entropy loss for the prior discriminator
      cross_entropy_metric (tf.keras.metrics.BinaryCrossentropy): Prior discriminator cross entropy metric
      decoder (tf.keras.Model): Decoder model that maps style and characters to images
      prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
      full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
      image_encoder (tf.keras.Model): Encoder for image features
      mse_metric (tf.keras.metrics.MeanSquaredError): Reconstruction error metric
      prior_accuracy_metric (tf.keras.metrics.Accuracy): Prior discriminator accuracy metric
      prior_batch_size (int): Batch size from prior distribution at training time
      rec_loss_weight (float): Weight of reconstruction loss at training time. Should be between 0 and 1.
  """

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) # assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5,
    prior_batch_size:int=32):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        image_encoder (tf.keras.Model): Encoder for image features
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
        prior_batch_size (int): Batch size from prior distribution at training time
    """
    super(CharStyleSAAE, self).__init__()

    #encoder.build(input_shape=input_dim)
    #decoder.build(input_shape=(None,n_classes+code_dim))
    #prior_discriminator.build(input_shape=(None,code_dim))

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)
    self.prior_batch_size = prior_batch_size

    self.prior_sampler = tf.random.normal
    
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator"]

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer))
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer))
    self.decoder.compile(optimizer = copy.deepcopy(optimizer))
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer))

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def prior_discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
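    # note: BinaryCrossentropy with its default reduction already averages over
    # the batch, so the division below mainly rescales this term relative to
    # the reconstruction loss inside mixed_loss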
    return (real_loss + fake_loss)/(2*self.prior_batch_size)

  def __call__(self, x, training=True, mask=None):
    return self.image_encoder(x, training=training)

  def train_step(self, inputs):
    x, labels = inputs

    #self.prior_batch_size = x.shape[0]
    #logger.info("prior_batch_size is deprecated; setting it equal to batch size.")

    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(x, training=True)
      full_precode = tf.concat([image_precode, labels], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,labels],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      #print((self.prior_batch_size,code.shape[1]))
      prior_samples = self.prior_sampler(shape=[self.prior_batch_size,code.shape[1]])
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(code,training=True)

      # compute losses for the models
      reconstruction_loss = tf.keras.losses.MSE(x,decoded)
      classification_loss = self.prior_discriminator_loss(real,fake)
      mixed_loss = -(1-self.rec_loss_weight)*classification_loss + self.rec_loss_weight*reconstruction_loss

    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)


    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    #self.cross_entropy_metric.update_state(discr_true, discr_predicted)

    return {"reconstruction MSE": self.mse_metric.result(), "discriminator accuracy": self.prior_accuracy_metric.result()}

  @property
  def metrics(self):
    """Performance metrics to report at training time
    
    Returns: A list of metric objects

    """
    return [self.mse_metric, self.prior_accuracy_metric, self.cross_entropy_metric]

  def predict(self, *args, **kwargs):
    return self.image_encoder.predict(*args,**kwargs)


  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """
    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight,
      "prior_batch_size": self.prior_batch_size
    }

    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

    # with open(str(BytestreamPath(output_dir) / "aae-params.json"),"w") as f:
    #   json.dump(d,f)

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        CharStyleSAAE: Loaded model
    """
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    # with open(str(BytestreamPath(input_dir) / "aae-params.json"),"r") as f:
    #   d = json.loads(f.read())

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      **d)
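
# Hypothetical usage sketch (not part of the original module): wiring a
# CharStyleSAAE from toy component networks. All shapes and architectures
# below (64x64 glyphs, 62 classes, 16-dim code) are illustrative assumptions,
# not the networks shipped with fontai.
def _example_char_style_saae():
  img, n_classes, code_dim = 64, 62, 16
  image_encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(img, img, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu")])
  full_encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(128 + n_classes,)),
    tf.keras.layers.Dense(code_dim)])
  decoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(code_dim + n_classes,)),
    tf.keras.layers.Dense(img * img, activation="sigmoid"),
    tf.keras.layers.Reshape((img, img, 1))])
  prior_discriminator = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(code_dim,)),
    tf.keras.layers.Dense(1, activation="sigmoid")])  # probabilities, as the loss assumes
  model = CharStyleSAAE(
    full_encoder=full_encoder,
    image_encoder=image_encoder,
    decoder=decoder,
    prior_discriminator=prior_discriminator,
    reconstruction_loss_weight=0.5,
    prior_batch_size=32)
  model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
  return model  # model.fit(dataset) expects (image, one-hot label) batches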



class PureCharStyleSAAE(tf.keras.Model):

  """This model is trained as a regular SAAE but an additional discriminator model is added to ensure the embedding does not retain information about the character class; i.e. it only retains style information
  """
  mean_metric = tf.keras.metrics.Mean(name="Mean code variance")
  char_accuracy_metric = tf.keras.metrics.Accuracy(name="style-char accuracy")
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)


  style_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) # assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow apparently hates good coding practices
  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    char_discriminator: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5,
    prior_batch_size:int=32):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        char_discriminator (tf.keras.Model): Discriminator to remove any character information from embeddings
        image_encoder (tf.keras.Model): Encoder for image features
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
        prior_batch_size (int): Batch size from prior distribution at training time
    """
    super(PureCharStyleSAAE, self).__init__()

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)
    self.prior_batch_size = prior_batch_size
    self.char_discriminator = char_discriminator

    self.prior_sampler = tf.random.normal
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator", "char_discriminator"]

  def prior_discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)/(2*self.prior_batch_size)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer))
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer))
    self.decoder.compile(optimizer = copy.deepcopy(optimizer))
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer))
    self.char_discriminator.compile(optimizer=copy.deepcopy(optimizer))

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    return self.image_encoder(x, training=training)

  def train_step(self, inputs):
    #prior_sampler = tf.random.normal
    x, labels = inputs

    #self.prior_batch_size = x.shape[0]
    #logger.info("prior_batch_size is deprecated; setting it equal to batch size.")

    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(x, training=True)
      full_precode = tf.concat([image_precode, labels], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,labels],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(self.prior_batch_size,code.shape[1]))
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(code,training=True)

      # apply char_discriminator model
      char_guess = self.char_discriminator(code,training=True)

      # compute losses for the models
      char_loss = self.style_loss(labels, char_guess)/self.prior_batch_size
      reconstruction_loss = tf.keras.losses.MSE(x,decoded)
      classification_loss = self.prior_discriminator_loss(real,fake)

      mixed_loss = -(1-self.rec_loss_weight)*(classification_loss + char_loss) + self.rec_loss_weight*(reconstruction_loss)

    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)
    style_discr_gradients = tape.gradient(char_loss, self.char_discriminator.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))
    self.char_discriminator.optimizer.apply_gradients(
      zip(style_discr_gradients, self.char_discriminator.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    #self.cross_entropy_metric.update_state(discr_true, discr_predicted)


    self.char_accuracy_metric.update_state(tf.argmax(labels, axis=-1), tf.argmax(char_guess, axis=-1))

    return {
    "reconstruction MSE": self.mse_metric.result(), 
    "discriminator accuracy": self.prior_accuracy_metric.result(), 
    "sstyle discriminator accuracy": self.char_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.char_discriminator.save(str(BytestreamPath(output_dir) / "char_discriminator"))
    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight,
      "prior_batch_size": self.prior_batch_size
    }

    # with open(str(BytestreamPath(output_dir) / "aae-params.json"),"w") as f:
    #   json.dump(d,f)
    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        PureCharStyleSAAE: Loaded model
    """
    char_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "char_discriminator"))
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    # with open(str(BytestreamPath(input_dir) / "aae-params.json"),"r") as f:
    #   d = json.loads(f.read())

    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      char_discriminator = char_discriminator,
      **d)
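
# Hypothetical sketch (an assumption, not part of the original module): a
# char_discriminator for PureCharStyleSAAE is just a softmax classifier that
# tries to recover the character class from the style code alone; the sign of
# char_loss inside mixed_loss pushes the encoders to erase that information.
def _example_char_discriminator(code_dim=16, n_classes=62):
  return tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(code_dim,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(n_classes, activation="softmax")])  # probabilities, as style_loss assumes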



class PureFontStyleSAAE(tf.keras.Model):

  """This model is trained on all characters from one or more font files at a time; the aim is to encode the font's style as opposed to single characters' styles, which can happen when training with scrambled characters from different fonts and results in sometimes having different-looking image styles for a given style in latent space. This model works with characters from a single typeface at a time, and use the style from a given character to decode a different randomly chosen character in the same font, using the encoded style and target label information. 
  """
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) # assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow apparently hates good coding practices
  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5,
    prior_batch_size:int=32,
    code_regularisation_weight=0):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        image_encoder (tf.keras.Model): Encoder for image features
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
        prior_batch_size (int): Batch size from prior distribution at training time
        code_regularisation_weight (float, optional): Weight of a moment regularisation term that keeps the embedded code closer to a standard normal
    """
    super(PureFontStyleSAAE, self).__init__()

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)
    self.prior_batch_size = prior_batch_size

    self.prior_sampler = tf.random.normal
    self.code_regularisation_weight = code_regularisation_weight
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator"]


  def prior_discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)/(2*self.prior_batch_size)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    return self.image_encoder(x, training=training)

  def scramble_font_batches(self, x: tf.Tensor, labels: tf.Tensor):
    """Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.
    
    Args:
        x (tf.Tensor): Feature tensor
        labels (tf.Tensor): Label tensor
    
    Returns:
        t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: return scrambled and original feature-label pairs, in that order.
    """
    #
    x_shape = tf.shape(x)
    n_fonts = x_shape[0]
    #
    style_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    style_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    outcome_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    outcome_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    #
    for k in range(n_fonts):
      x_k, labels_k = x[k], labels[k]
      x_k_shape = tf.shape(x_k)
      shuffling_idx = tf.range(start=0, limit=tf.shape(x_k)[0], dtype=tf.int32)
      scrambled = tf.random.shuffle(shuffling_idx)
      #
      style_x = style_x.write(k, tf.gather(x_k, scrambled, axis=0))
      style_y = style_y.write(k, tf.gather(labels_k, scrambled, axis=0))

      outcome_x = outcome_x.write(k, x_k)
      outcome_y = outcome_y.write(k, labels_k)
      #
    #
    #
    return style_x.concat(), style_y.concat(), outcome_x.concat(), outcome_y.concat()

  def train_step(self, inputs):
    #prior_sampler = tf.random.normal
    x, labels = inputs

    #self.prior_batch_size = x.shape[0]
    style_x, style_y, outcome_x, outcome_y = self.scramble_font_batches(x ,labels)

    n_examples = tf.shape(style_x)[0]
    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(style_x, training=True)
      full_precode = tf.concat([image_precode, style_y], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,outcome_y],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(n_examples,code.shape[1]))
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(code,training=True)


      # Moment regularization for embedded representation to keep it closer to standard normal
      reg = self.code_regularisation_weight*(tf.reduce_mean(code)**2 + (tf.reduce_mean(code**2) - 1.0)**2)/2

      # compute losses for the models
      reconstruction_loss = tf.keras.losses.MSE(outcome_x,decoded) 
      classification_loss = self.prior_discriminator_loss(real,fake)

      mixed_loss = -(1-self.rec_loss_weight)*(classification_loss) + self.rec_loss_weight*(reconstruction_loss + reg)


    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(outcome_x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    return {
    "reconstruction MSE": self.mse_metric.result(), 
    "discriminator accuracy": self.prior_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight,
      "prior_batch_size": self.prior_batch_size,
      "code_regularisation_weight": self.code_regularisation_weight
    }

    # with open(str(BytestreamPath(output_dir) / "aae-params.json"),"w") as f:
    #   json.dump(d,f)

    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        PureFontStyleSAAE: Loaded model
    """
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    # with open(str(BytestreamPath(input_dir) / "aae-params.json"),"r") as f:
    #   d = json.loads(f.read())
    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      **d)
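
# Hypothetical data-shape sketch (an assumption, not part of the original
# module): PureFontStyleSAAE expects batches grouped by font, i.e. a feature
# tensor of shape (n_fonts, n_chars, height, width, 1) with matching one-hot
# labels of shape (n_fonts, n_chars, n_classes), so that scramble_font_batches
# can shuffle characters within each font.
def _example_font_batch(n_fonts=8, n_chars=26, h=64, w=64, n_classes=26):
  x = tf.random.uniform((n_fonts, n_chars, h, w, 1))
  labels = tf.one_hot(tf.tile(tf.range(n_chars)[None, :], [n_fonts, 1]), n_classes)
  return x, labels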



class PureFontStyleSA2AE(tf.keras.Model):

  """This model is works like PureFontStyleSA2AE but instead of minimising the MSE reconstruction error, an additional image discriminator classify between real and reconstructed images, so the encoder and decoder now maximise misclassification error.

  """
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="adversarial prior accuracy")
  image_accuracy_metric = tf.keras.metrics.Accuracy(name="adversarial image accuracy")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) # assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow apparently hates good coding practices
  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    image_discriminator: tf.keras.Model,
    code_regularisation_weight=0):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        image_encoder (tf.keras.Model): Encoder for image features
        image_discriminator (tf.keras.Model): image discriminator
    """
    super(PureFontStyleSA2AE, self).__init__()

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.image_discriminator = image_discriminator
    self.code_regularisation_weight = code_regularisation_weight # kept for API parity with PureFontStyleSAAE; unused by train_step

    self.prior_sampler = tf.random.normal
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator", "image_discriminator"]


  def discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.image_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    return self.image_encoder(x, training=training)

  def scramble_font_batches(self, x: tf.Tensor, labels: tf.Tensor):
    """Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.
    
    Args:
        x (tf.Tensor): Feature tensor
        labels (tf.Tensor): Label tensor
    
    Returns:
        t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: return scrambled and original feature-label pairs, in that order.
    """
    #
    x_shape = tf.shape(x)
    n_fonts = x_shape[0]
    #
    style_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    style_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    outcome_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    outcome_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    #
    for k in range(n_fonts):
      x_k, labels_k = x[k], labels[k]
      x_k_shape = tf.shape(x_k)
      shuffling_idx = tf.range(start=0, limit=tf.shape(x_k)[0], dtype=tf.int32)
      scrambled = tf.random.shuffle(shuffling_idx)
      #
      style_x = style_x.write(k, tf.gather(x_k, scrambled, axis=0))
      style_y = style_y.write(k, tf.gather(labels_k, scrambled, axis=0))

      outcome_x = outcome_x.write(k, x_k)
      outcome_y = outcome_y.write(k, labels_k)
      #
    #
    #
    return style_x.concat(), style_y.concat(), outcome_x.concat(), outcome_y.concat()

  def train_step(self, inputs):
    #prior_sampler = tf.random.normal
    x, labels = inputs

    #self.prior_batch_size = x.shape[0]
    style_x, style_y, outcome_x, outcome_y = self.scramble_font_batches(x ,labels)

    n_examples = tf.shape(style_x)[0]
    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(style_x, training=True)
      full_precode = tf.concat([image_precode, style_y], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,outcome_y],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(n_examples,code.shape[1]))
      real_prior = self.prior_discriminator(prior_samples,training=True)
      fake_prior = self.prior_discriminator(code,training=True)
      prior_classification_loss = self.discriminator_loss(real_prior,fake_prior)

      #apply image_discriminator model
      real_image = self.image_discriminator(outcome_x)
      fake_image = self.image_discriminator(decoded)
      image_classification_loss = self.discriminator_loss(real_image, fake_image)

      # compute losses for the models
      #reconstruction_loss = tf.keras.losses.MSE(outcome_x,decoded) 
      
      mixed_loss = -(prior_classification_loss + image_classification_loss)

      #mixed_loss = -(1-self.rec_loss_weight)*(prior_classification_loss) + self.rec_loss_weight*(reconstruction_loss)


    # Compute gradients
    prior_discriminator_gradients = tape.gradient(prior_classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(image_classification_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)
    image_discriminator_gradients = tape.gradient(image_classification_loss, self.image_discriminator.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(prior_discriminator_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))
    self.image_discriminator.optimizer.apply_gradients(zip(image_discriminator_gradients, self.image_discriminator.trainable_variables))

    # compute metrics
    #self.mse_metric.update_state(outcome_x,decoded)

    discr_true = tf.concat([tf.ones_like(real_prior),tf.zeros_like(fake_prior)],axis=0)
    discr_predicted = tf.concat([real_prior,fake_prior],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    discr_true = tf.concat([tf.ones_like(real_image),tf.zeros_like(fake_image)],axis=0)
    discr_predicted = tf.concat([real_image,fake_image],axis=0)
    self.image_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    return {
    "image discriminator accuracy": self.image_accuracy_metric.result(), 
    "prior discriminator accuracy": self.prior_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))
    self.image_discriminator.save(str(BytestreamPath(output_dir) / "image_discriminator"))


    # with open(str(BytestreamPath(output_dir) / "aae-params.json"),"w") as f:
    #   json.dump(d,f)

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        PureFontStyleSA2AE: Loaded model
    """
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))
    image_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_discriminator"))

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      image_discriminator = image_discriminator)
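
# Hypothetical sketch (an assumption, not part of the original module): an
# image_discriminator for PureFontStyleSA2AE outputs the probability that its
# input is a real glyph rather than a reconstruction, matching the
# from_logits=False cross entropy used above.
def _example_image_discriminator(img=64):
  return tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(img, img, 1)),
    tf.keras.layers.Conv2D(16, 3, strides=2, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation="sigmoid")])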



class TensorFontStyleSAAE(tf.keras.Model):

  """This model treats characters in a font as channels in an image. The encoder takes the whole font as an image with (n_char) channels.
  """
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) # assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow apparently hates good coding practices
  def __init__(
    self,
    encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        encoder (tf.keras.Model): Encoder for image features
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
        n_classes (int): number of labeled classes
        prior_batch_size (int): Batch size from prior distribution at training time
    """
    super(TensorFontStyleSAAE, self).__init__()

    self.encoder = encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)

    self.prior_sampler = tf.random.normal
    # list of embedded models as instance attributes 
    self.model_list = ["encoder", "decoder", "prior_discriminator"]


  def prior_discriminator_loss(self,real,fake, n_fonts):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)/tf.cast(n_fonts, tf.float32)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    return self.encoder(x, training=training)

  def train_step(self, inputs):
    #prior_sampler = tf.random.normal
    x, _ = inputs

    x_shape = tf.shape(x)
    n_fonts = x_shape[0]
    n_chars = x_shape[1]
    height = x_shape[2]
    width = x_shape[3]

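    # fonts arrive as (n_fonts, n_chars, height, width); the reshape makes the
    # shape explicit and the transpose moves characters to the channel axis, so
    # the encoder sees one (height, width, n_chars) image per font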
    x = tf.reshape(x, (n_fonts, n_chars, height, width))
    x = tf.transpose(x, [0,2,3,1])

    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      embedding = self.encoder(x, training=True)
      decoded = self.decoder(embedding,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(n_fonts,embedding.shape[1]))
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(embedding,training=True)

      # compute losses for the models
      reconstruction_loss = tf.keras.losses.MSE(x,decoded) 
      classification_loss = self.prior_discriminator_loss(real,fake, n_fonts)

      mixed_loss = -(1-self.rec_loss_weight)*classification_loss + self.rec_loss_weight*reconstruction_loss

    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    encoder_gradients = tape.gradient(mixed_loss, self.encoder.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.encoder.optimizer.apply_gradients(zip(encoder_gradients,self.encoder.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    return {
    "reconstruction MSE": self.mse_metric.result(), 
    "discriminator accuracy": self.prior_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.encoder.save(str(BytestreamPath(output_dir) / "encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight
      }

    # with open(str(BytestreamPath(output_dir) / "aae-params.json"),"w") as f:
    #   json.dump(d,f)

    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        TensorFontStyleSAAE: Loaded model
    """
    encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    # with open(str(BytestreamPath(input_dir) / "aae-params.json"),"r") as f:
    #   d = json.loads(f.read())
    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    return cls(
      encoder = encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      **d)
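
# Hypothetical usage sketch (an assumption, not part of the original module):
# TensorFontStyleSAAE ignores labels and encodes a whole font at once, so its
# encoder should accept (height, width, n_chars) images and the decoder should
# map the embedding back to the same shape. Shapes below are illustrative.
def _example_tensor_font_saae(h=64, w=64, n_chars=26, code_dim=32):
  encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(h, w, n_chars)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(code_dim)])
  decoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(code_dim,)),
    tf.keras.layers.Dense(h * w * n_chars, activation="sigmoid"),
    tf.keras.layers.Reshape((h, w, n_chars))])
  prior_discriminator = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(code_dim,)),
    tf.keras.layers.Dense(1, activation="sigmoid")])
  model = TensorFontStyleSAAE(
    encoder=encoder,
    decoder=decoder,
    prior_discriminator=prior_discriminator,
    reconstruction_loss_weight=0.5)
  model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
  return model  # fit expects (font_tensor, labels) batches; labels are unused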

Classes

class CharStyleSAAE (full_encoder: tensorflow.python.keras.engine.training.Model, image_encoder: tensorflow.python.keras.engine.training.Model, decoder: tensorflow.python.keras.engine.training.Model, prior_discriminator: tensorflow.python.keras.engine.training.Model, reconstruction_loss_weight: float = 0.5, prior_batch_size: int = 32)

This class fits a supervised adversarial autoencoder inspired by the architecture in "Adversarial Autoencoders" by Makhzani et al. The only difference is that label (i.e. character) information is injected between the encoder and the style embedding, in the hope that labels help not only the decoding but also the encoding process; e.g. curviness shouldn't be as important in the input for a C as for an H.

Attributes

cross_entropy : tf.keras.losses.BinaryCrossentropy
Binary cross entropy loss for the prior discriminator
cross_entropy_metric : tf.keras.metrics.BinaryCrossentropy
Prior discriminator cross entropy metric
decoder : tf.keras.Model
Decoder model that maps style and characters to images
prior_discriminator : tf.keras.Model
Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
full_encoder : tf.keras.Model
Encoder that takes high-level image features and labels to produce embedded representations
image_encoder : tf.keras.Model
Encoder for image features
mse_metric : tf.keras.metrics.MeanSquaredError
Reconstruction error metric
prior_accuracy_metric : tf.keras.metrics.Accuracy
Prior discriminator accuracy metric
prior_batch_size : int
Batch size from prior distribution at training time
rec_loss_weight : float
Weight of reconstruction loss at training time. Should be between 0 and 1.

Initialises the model from its component networks.

Args

decoder : tf.keras.Model
Decoder model that maps style and characters to images
prior_discriminator : tf.keras.Model
Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
full_encoder : tf.keras.Model
Encoder that takes high-level image features and labels to produce embedded representations
image_encoder : tf.keras.Model
Encoder for image features
reconstruction_loss_weight : float, optional
Weight of reconstruction loss at training time. Should be between 0 and 1.
prior_batch_size : int
Batch size from prior distribution at training time
Ancestors

  • tensorflow.python.keras.engine.training.Model
  • tensorflow.python.keras.engine.network.Network
  • tensorflow.python.keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.training.tracking.tracking.AutoTrackable
  • tensorflow.python.training.tracking.base.Trackable
  • tensorflow.python.keras.utils.version_utils.LayerVersionSelector
  • tensorflow.python.keras.utils.version_utils.ModelVersionSelector

Class variables

var cross_entropy
var cross_entropy_metric
var mse_metric
var prior_accuracy_metric

Static methods

def load(input_dir: str)

Loads a saved instance of this class

Args

input_dir : str
Target input folder

Returns

CharStyleSAAE
Loaded model
Expand source code
@classmethod
def load(cls, input_dir: str):
  """Loads a saved instance of this class
  
  Args:
      input_dir (str): Target input folder
  
  Returns:
      SAAE: Loaded model
  """
  full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
  image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
  decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
  prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

  d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
  d = json.loads(d_string)

  return cls(
    image_encoder = image_encoder, 
    full_encoder = full_encoder, 
    decoder = decoder, 
    prior_discriminator = prior_discriminator, 
    **d)

Instance variables

var metrics

Performance metrics to report at training time

Returns: A list of metric objects

Expand source code
@property
def metrics(self):
  """Performance metrics to report at training time
  
  Returns: A list of metric objects

  """
  return [self.mse_metric, self.prior_accuracy_metric, self.cross_entropy_metric]

Methods

def compile(self, optimizer='rmsprop', loss=None, metrics=None, loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs)

Configures the model for training.

Arguments

optimizer: String (name of optimizer) or optimizer instance. See tf.keras.optimizers.

loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature loss = fn(y_true, y_pred), where y_true = ground truth values with shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]. y_pred = predicted values with shape = [batch_size, d0, .. dN]. It returns a weighted loss float tensor. If a custom Loss instance is used and reduction is set to NONE, the return value has the shape [batch_size, d0, .. dN-1], i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.

metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), a function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}. You can also pass a list (len = len(outputs)) of lists of metrics such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings 'accuracy' or 'acc', we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the loss function used and the model output shape. We do a similar conversion for the strings 'crossentropy' and 'ce' as well.

loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model's outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

sample_weight_mode: If you need to do timestep-wise sample weighting (2D weights), set this to "temporal". None defaults to sample-wise weights (1D). If the model has multiple outputs, you can use a different sample_weight_mode on each output by passing a dictionary or a list of modes.

weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

**kwargs: Any additional arguments. For eager execution, pass run_eagerly=True.

Raises

ValueError
In case of invalid arguments for optimizer, loss, metrics or sample_weight_mode.
Expand source code
def compile(self,
  optimizer='rmsprop',
  loss=None,
  metrics=None,
  loss_weights=None,
  weighted_metrics=None,
  run_eagerly=None,
  **kwargs):

  self.full_encoder.compile(optimizer = copy.deepcopy(optimizer))
  self.image_encoder.compile(optimizer = copy.deepcopy(optimizer))
  self.decoder.compile(optimizer = copy.deepcopy(optimizer))
  self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer))

  super().compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    loss_weights=loss_weights,
    weighted_metrics=weighted_metrics,
    run_eagerly=run_eagerly,
    **kwargs)
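Because the override compiles each submodel with a deep copy of the optimizer, every component keeps its own independent optimizer state (e.g. Adam moment estimates). A minimal sketch, assuming an already constructed model:

import tensorflow as tf

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

# each submodel now holds a distinct optimizer instance
assert model.decoder.optimizer is not model.image_encoder.optimizer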
def predict(self, *args, **kwargs)

Generates output predictions for the input samples.

Computation is done in batches. This method is designed for performance on large-scale inputs. For small numbers of inputs that fit in one batch, directly using __call__ is recommended for faster execution, e.g. model(x), or model(x, training=False) if you have layers such as tf.keras.layers.BatchNormalization that behave differently during inference.

Arguments

x: Input samples. It could be: a Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs); a TensorFlow tensor, or a list of tensors (in case the model has multiple inputs); a tf.data dataset; a generator or keras.utils.Sequence instance. A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given in the "Unpacking behavior for iterator-like inputs" section of Model.fit.

batch_size: Integer or None. Number of samples per batch. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).

verbose: Verbosity mode, 0 or 1.

steps: Total number of steps (batches of samples) before declaring the prediction round finished. Ignored with the default value of None. If x is a tf.data dataset and steps is None, predict will run until the input dataset is exhausted.

callbacks: List of keras.callbacks.Callback instances. List of callbacks to apply during prediction. See callbacks.

max_queue_size: Integer. Used for generator or keras.utils.Sequence input only. Maximum size for the generator queue. If unspecified, max_queue_size will default to 10.

workers: Integer. Used for generator or keras.utils.Sequence input only. Maximum number of processes to spin up when using process-based threading. If unspecified, workers will default to 1. If 0, will execute the generator on the main thread.

use_multiprocessing: Boolean. Used for generator or keras.utils.Sequence input only. If True, use process-based threading. If unspecified, use_multiprocessing will default to False. Note that because this implementation relies on multiprocessing, you should not pass non-picklable arguments to the generator as they can't be passed easily to children processes.

See the discussion of Unpacking behavior for iterator-like inputs for Model.fit. Note that Model.predict uses the same interpretation rules as Model.fit and Model.evaluate, so inputs must be unambiguous for all three methods.

Returns

Numpy array(s) of predictions.

Raises

ValueError
In case of mismatch between the provided input data and the model's expectations, or in case a stateful model receives a number of samples that is not a multiple of the batch size.
Expand source code
def predict(self, *args, **kwargs):
  return self.image_encoder.predict(*args,**kwargs)
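Note that predict delegates to image_encoder.predict, so it returns the image encoder's feature output rather than reconstructed glyphs. A minimal sketch; the 64x64 single-channel input shape is an assumption about how the encoder was built:

import numpy as np

images = np.random.rand(8, 64, 64, 1).astype("float32")  # hypothetical glyph batch
features = model.predict(images)  # output of image_encoder only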
def prior_discriminator_loss(self, real, fake)
Expand source code
def prior_discriminator_loss(self,real,fake):
  real_loss = self.cross_entropy(tf.ones_like(real), real)
  fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
  return (real_loss + fake_loss)/(2*self.prior_batch_size)
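For intuition, the same quantity computed on dummy discriminator outputs (probabilities in [0, 1]). Note that BinaryCrossentropy already averages over its inputs, and the source additionally divides the sum by 2 * prior_batch_size (32 below is the default prior_batch_size):

import tensorflow as tf

real = tf.constant([[0.9], [0.8]])  # discriminator output on prior samples
fake = tf.constant([[0.2], [0.1]])  # discriminator output on encoder codes
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
loss = (bce(tf.ones_like(real), real) + bce(tf.zeros_like(fake), fake)) / (2 * 32)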
def save(self, output_dir: str)

Save the model to an output folder

Args

output_dir : str
Target output folder
Expand source code
def save(self,output_dir: str):
  """Save the model to an output folder
  
  Args:
      output_dir (str): Target output folder
  """
  self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
  self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
  self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
  self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

  d = {
    "reconstruction_loss_weight":self.rec_loss_weight,
    "prior_batch_size": self.prior_batch_size
  }

  (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

def train_step(self, inputs)

The logic for one training step.

This method can be overridden to support custom training logic. This method is called by Model.make_train_function.

This method should contain the mathematical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, and metric updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings), should be left to Model.make_train_function, which can also be overridden.

Arguments

data: A nested structure of Tensors.

Returns

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_train_batch_end. Typically, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

Expand source code
def train_step(self, inputs):
  x, labels = inputs

  with tf.GradientTape(persistent=True) as tape:

    # apply autoencoder
    image_precode = self.image_encoder(x, training=True)
    full_precode = tf.concat([image_precode, labels], axis=-1)
    code = self.full_encoder(full_precode, training=True)
    extended_code = tf.concat([code,labels],axis=-1)
    decoded = self.decoder(extended_code,training=True)  

    # apply prior_discriminator model
    prior_samples = self.prior_sampler(shape=[self.prior_batch_size,code.shape[1]])
    real = self.prior_discriminator(prior_samples,training=True)
    fake = self.prior_discriminator(code,training=True)

    # compute losses for the models
    reconstruction_loss = tf.keras.losses.MSE(x,decoded)
    classification_loss = self.prior_discriminator_loss(real,fake)
    mixed_loss = -(1-self.rec_loss_weight)*classification_loss + self.rec_loss_weight*reconstruction_loss

  # Compute gradients
  discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
  decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
  image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
  full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)


  #apply gradients
  self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
  self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
  self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
  self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))

  # compute metrics
  self.mse_metric.update_state(x,decoded)

  discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
  discr_predicted = tf.concat([real,fake],axis=0)
  self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

  return {"reconstruction MSE": self.mse_metric.result(), "discriminator accuracy": self.prior_accuracy_metric.result()}
class PureCharStyleSAAE (full_encoder: tensorflow.python.keras.engine.training.Model, image_encoder: tensorflow.python.keras.engine.training.Model, decoder: tensorflow.python.keras.engine.training.Model, char_discriminator: tensorflow.python.keras.engine.training.Model, prior_discriminator: tensorflow.python.keras.engine.training.Model, reconstruction_loss_weight: float = 0.5, prior_batch_size: int = 32)

This model is trained as a regular SAAE, but an additional discriminator is added to ensure the embedding does not retain information about the character class; i.e. it retains only style information.

Summary

Args

decoder : tf.keras.Model
Decoder model that maps style and characters to images
prior_discriminator : tf.keras.Model
Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
full_encoder : tf.keras.Model
Encoder that takes high-level image features and labels to produce embedded representations
char_discriminator : tf.keras.Model
Discriminator to remove any character information from embeddings
image_encoder : tf.keras.Model
Encoder for image features
reconstruction_loss_weight : float, optional
Weight of reconstruction loss at training time. Should be between 0 and 1.
prior_batch_size : int
Batch size from prior distribution at training time
Expand source code
class PureCharStyleSAAE(tf.keras.Model):

  """This model is trained as a regular SAAE but an additional discriminator model is added to ensure the embedding does not retain information about the character class; i.e. it only retains style information
  """
  mean_metric = tf.keras.metrics.Mean(name="Mean code variance")
  char_accuracy_metric = tf.keras.metrics.Accuracy(name="style-char accuracy")
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)


  style_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) #assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow hates good coding practices, apparently
  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    char_discriminator: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5,
    prior_batch_size:int=32):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        char_discriminator (tf.keras.Model): Discriminator to remove any character information from embeddings
        image_encoder (tf.keras.Model): Encoder for image features
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
        prior_batch_size (int): Batch size from prior distribution at training time
    """
    super(PureCharStyleSAAE, self).__init__()

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)
    self.prior_batch_size = prior_batch_size
    self.char_discriminator = char_discriminator

    self.prior_sampler = tf.random.normal
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator", "char_discriminator"]

  def prior_discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)/(2*self.prior_batch_size)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer))
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer))
    self.decoder.compile(optimizer = copy.deepcopy(optimizer))
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer))
    self.char_discriminator.compile(optimizer=copy.deepcopy(optimizer))

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    return self.image_encoder(x, training=training)

  def train_step(self, inputs):
    x, labels = inputs

    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(x, training=True)
      full_precode = tf.concat([image_precode, labels], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,labels],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(self.prior_batch_size,code.shape[1]))
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(code,training=True)

      # apply char_discriminator model
      char_guess = self.char_discriminator(code,training=True)

      # compute losses for the models
      char_loss = self.style_loss(labels, char_guess)/self.prior_batch_size
      reconstruction_loss = tf.keras.losses.MSE(x,decoded)
      classification_loss = self.prior_discriminator_loss(real,fake)

      mixed_loss = -(1-self.rec_loss_weight)*(classification_loss + char_loss) + self.rec_loss_weight*(reconstruction_loss)

    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)
    style_discr_gradients = tape.gradient(char_loss, self.char_discriminator.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))
    self.char_discriminator.optimizer.apply_gradients(
      zip(style_discr_gradients, self.char_discriminator.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    self.char_accuracy_metric.update_state(tf.argmax(labels, axis=-1), tf.argmax(char_guess, axis=-1))

    return {
      "reconstruction MSE": self.mse_metric.result(),
      "discriminator accuracy": self.prior_accuracy_metric.result(),
      "style discriminator accuracy": self.char_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.char_discriminator.save(str(BytestreamPath(output_dir) / "char_discriminator"))
    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight,
      "prior_batch_size": self.prior_batch_size
    }

    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        PureCharStyleSAAE: Loaded model
    """
    char_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "char_discriminator"))
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      char_discriminator = char_discriminator,
      **d)

Ancestors

  • tensorflow.python.keras.engine.training.Model
  • tensorflow.python.keras.engine.network.Network
  • tensorflow.python.keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.training.tracking.tracking.AutoTrackable
  • tensorflow.python.training.tracking.base.Trackable
  • tensorflow.python.keras.utils.version_utils.LayerVersionSelector
  • tensorflow.python.keras.utils.version_utils.ModelVersionSelector

Class variables

var char_accuracy_metric
var cross_entropy
var cross_entropy_metric
var mean_metric
var mse_metric
var prior_accuracy_metric
var style_loss

Static methods

def load(input_dir: str)

Loads a saved instance of this class

Args

input_dir : str
Target input folder

Returns

PureCharStyleSAAE
Loaded model
Expand source code
@classmethod
def load(cls, input_dir: str):
  """Loads a saved instance of this class
  
  Args:
      input_dir (str): Target input folder
  
  Returns:
      PureCharStyleSAAE: Loaded model
  """
  char_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "char_discriminator"))
  full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
  image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
  decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
  prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

  d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
  d = json.loads(d_string)

  return cls(
    image_encoder = image_encoder, 
    full_encoder = full_encoder, 
    decoder = decoder, 
    prior_discriminator = prior_discriminator, 
    char_discriminator = char_discriminator,
    **d)

Methods

def compile(self, optimizer='rmsprop', loss=None, metrics=None, loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs)

Configures the model for training.

Arguments

optimizer: String (name of optimizer) or optimizer instance. See tf.keras.optimizers.

loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature loss = fn(y_true, y_pred), where y_true = ground truth values with shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]. y_pred = predicted values with shape = [batch_size, d0, .. dN]. It returns a weighted loss float tensor. If a custom Loss instance is used and reduction is set to NONE, the return value has the shape [batch_size, d0, .. dN-1], i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.

metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), a function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}. You can also pass a list (len = len(outputs)) of lists of metrics such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings 'accuracy' or 'acc', we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the loss function used and the model output shape. We do a similar conversion for the strings 'crossentropy' and 'ce' as well.

loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model's outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

sample_weight_mode: If you need to do timestep-wise sample weighting (2D weights), set this to "temporal". None defaults to sample-wise weights (1D). If the model has multiple outputs, you can use a different sample_weight_mode on each output by passing a dictionary or a list of modes.

weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

**kwargs: Any additional arguments. For eager execution, pass run_eagerly=True.

Raises

ValueError
In case of invalid arguments for optimizer, loss, metrics or sample_weight_mode.
Expand source code
def compile(self,
  optimizer='rmsprop',
  loss=None,
  metrics=None,
  loss_weights=None,
  weighted_metrics=None,
  run_eagerly=None,
  **kwargs):

  self.full_encoder.compile(optimizer = copy.deepcopy(optimizer))
  self.image_encoder.compile(optimizer = copy.deepcopy(optimizer))
  self.decoder.compile(optimizer = copy.deepcopy(optimizer))
  self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer))
  self.char_discriminator.compile(optimizer=copy.deepcopy(optimizer))

  super().compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    loss_weights=loss_weights,
    weighted_metrics=weighted_metrics,
    run_eagerly=run_eagerly,
    **kwargs)
def prior_discriminator_loss(self, real, fake)
Expand source code
def prior_discriminator_loss(self,real,fake):
  real_loss = self.cross_entropy(tf.ones_like(real), real)
  fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
  return (real_loss + fake_loss)/(2*self.prior_batch_size)
def save(self, output_dir: str)

Save the model to an output folder

Args

output_dir : str
Target output folder
Expand source code
def save(self,output_dir: str):
  """Save the model to an output folder
  
  Args:
      output_dir (str): Target output folder
  """

  self.char_discriminator.save(str(BytestreamPath(output_dir) / "char_discriminator"))
  self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
  self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
  self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
  self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

  d = {
    "reconstruction_loss_weight":self.rec_loss_weight,
    "prior_batch_size": self.prior_batch_size
  }

  (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())
def train_step(self, inputs)

The logic for one training step.

This method can be overridden to support custom training logic. This method is called by Model.make_train_function.

This method should contain the mathematical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, and metric updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings), should be left to Model.make_train_function, which can also be overridden.

Arguments

data: A nested structure of Tensors.

Returns

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_train_batch_end. Typically, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

Expand source code
def train_step(self, inputs):
  x, labels = inputs

  with tf.GradientTape(persistent=True) as tape:

    # apply autoencoder
    image_precode = self.image_encoder(x, training=True)
    full_precode = tf.concat([image_precode, labels], axis=-1)
    code = self.full_encoder(full_precode, training=True)
    extended_code = tf.concat([code,labels],axis=-1)
    decoded = self.decoder(extended_code,training=True)  

    # apply prior_discriminator model
    prior_samples = self.prior_sampler(shape=(self.prior_batch_size,code.shape[1]))
    real = self.prior_discriminator(prior_samples,training=True)
    fake = self.prior_discriminator(code,training=True)

    # apply char_discriminator model
    char_guess = self.char_discriminator(code,training=True)

    # compute losses for the models
    char_loss = self.style_loss(labels, char_guess)/self.prior_batch_size
    reconstruction_loss = tf.keras.losses.MSE(x,decoded)
    classification_loss = self.prior_discriminator_loss(real,fake)

    mixed_loss = -(1-self.rec_loss_weight)*(classification_loss + char_loss) + self.rec_loss_weight*(reconstruction_loss)

  # Compute gradients
  discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
  decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
  image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
  full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)
  style_discr_gradients = tape.gradient(char_loss, self.char_discriminator.trainable_variables)

  #apply gradients
  self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
  self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
  self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
  self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))
  self.char_discriminator.optimizer.apply_gradients(
    zip(style_discr_gradients, self.char_discriminator.trainable_variables))

  # compute metrics
  self.mse_metric.update_state(x,decoded)

  discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
  discr_predicted = tf.concat([real,fake],axis=0)
  self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

  self.char_accuracy_metric.update_state(tf.argmax(labels, axis=-1), tf.argmax(char_guess, axis=-1))

  return {
    "reconstruction MSE": self.mse_metric.result(),
    "discriminator accuracy": self.prior_accuracy_metric.result(),
    "style discriminator accuracy": self.char_accuracy_metric.result()}
class PureFontStyleSA2AE (full_encoder: tensorflow.python.keras.engine.training.Model, image_encoder: tensorflow.python.keras.engine.training.Model, decoder: tensorflow.python.keras.engine.training.Model, prior_discriminator: tensorflow.python.keras.engine.training.Model, image_discriminator: tensorflow.python.keras.engine.training.Model, code_regularisation_weight=0)

This model works like PureFontStyleSAAE, but instead of minimising the MSE reconstruction error, an additional image discriminator classifies between real and reconstructed images, and the encoder and decoder are trained to maximise its misclassification error.

Summary

Args

decoder : tf.keras.Model
Decoder model that maps style and characters to images
prior_discriminator : tf.keras.Model
Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
full_encoder : tf.keras.Model
Encoder that takes high-level image features and labels to produce embedded representations
image_encoder : tf.keras.Model
Encoder for image features
image_discriminator : tf.keras.Model
Discriminator between real and reconstructed images
Expand source code
class PureFontStyleSA2AE(tf.keras.Model):

  """This model is works like PureFontStyleSA2AE but instead of minimising the MSE reconstruction error, an additional image discriminator classify between real and reconstructed images, so the encoder and decoder now maximise misclassification error.

  """
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="adversarial prior accuracy")
  image_accuracy_metric = tf.keras.metrics.Accuracy(name="adversarial image accuracy")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) #assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow hates good coding practices, apparently
  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    image_discriminator: tf.keras.Model,
    code_regularisation_weight=0):
    """Summary
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        image_encoder (tf.keras.Model): Encoder for image features
        image_discriminator (tf.keras.Model): Discriminator between real and reconstructed images
    """
    super(PureFontStyleSA2AE, self).__init__()

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.image_discriminator = image_discriminator

    self.prior_sampler = tf.random.normal
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator", "image_discriminator"]


  def discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.image_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    return self.image_encoder(x, training=training)

  def scramble_font_batches(self, x: tf.Tensor, labels: tf.Tensor):
    """Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.
    
    Args:
        x (tf.Tensor): Feature tensor
        labels (tf.Tensor): Label tensor
    
    Returns:
        t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: return scrambled and original feature-label pairs, in that order.
    """
    n_fonts = tf.shape(x)[0]

    style_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    style_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    outcome_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    outcome_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    for k in range(n_fonts):
      x_k, labels_k = x[k], labels[k]
      shuffling_idx = tf.range(start=0, limit=tf.shape(x_k)[0], dtype=tf.int32)
      scrambled = tf.random.shuffle(shuffling_idx)

      style_x = style_x.write(k, tf.gather(x_k, scrambled, axis=0))
      style_y = style_y.write(k, tf.gather(labels_k, scrambled, axis=0))

      outcome_x = outcome_x.write(k, x_k)
      outcome_y = outcome_y.write(k, labels_k)

    return style_x.concat(), style_y.concat(), outcome_x.concat(), outcome_y.concat()

  def train_step(self, inputs):
    x, labels = inputs

    style_x, style_y, outcome_x, outcome_y = self.scramble_font_batches(x, labels)

    n_examples = tf.shape(style_x)[0]
    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(style_x, training=True)
      full_precode = tf.concat([image_precode, style_y], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,outcome_y],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(n_examples,code.shape[1]))
      real_prior = self.prior_discriminator(prior_samples,training=True)
      fake_prior = self.prior_discriminator(code,training=True)
      prior_classification_loss = self.discriminator_loss(real_prior,fake_prior)

      #apply image_discriminator model
      real_image = self.image_discriminator(outcome_x)
      fake_image = self.image_discriminator(decoded)
      image_classification_loss = self.discriminator_loss(real_image, fake_image)

      # compute losses for the models
      mixed_loss = -(prior_classification_loss + image_classification_loss)

    # Compute gradients
    prior_discriminator_gradients = tape.gradient(prior_classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(image_classification_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)
    image_discriminator_gradients = tape.gradient(image_classification_loss, self.image_discriminator.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(prior_discriminator_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))
    self.image_discriminator.optimizer.apply_gradients(zip(image_discriminator_gradients, self.image_discriminator.trainable_variables))

    # compute metrics
    discr_true = tf.concat([tf.ones_like(real_prior),tf.zeros_like(fake_prior)],axis=0)
    discr_predicted = tf.concat([real_prior,fake_prior],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    discr_true = tf.concat([tf.ones_like(real_image),tf.zeros_like(fake_image)],axis=0)
    discr_predicted = tf.concat([real_image,fake_image],axis=0)
    self.image_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    return {
      "image discriminator accuracy": self.image_accuracy_metric.result(),
      "prior discriminator accuracy": self.prior_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))
    self.image_discriminator.save(str(BytestreamPath(output_dir) / "image_discriminator"))

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        PureFontStyleSA2AE: Loaded model
    """
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))
    image_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_discriminator"))

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      image_discriminator = image_discriminator)

Ancestors

  • tensorflow.python.keras.engine.training.Model
  • tensorflow.python.keras.engine.network.Network
  • tensorflow.python.keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.training.tracking.tracking.AutoTrackable
  • tensorflow.python.training.tracking.base.Trackable
  • tensorflow.python.keras.utils.version_utils.LayerVersionSelector
  • tensorflow.python.keras.utils.version_utils.ModelVersionSelector

Class variables

var cross_entropy
var cross_entropy_metric
var image_accuracy_metric
var prior_accuracy_metric

Static methods

def load(input_dir: str)

Loads a saved instance of this class

Args

input_dir : str
Target input folder

Returns

PureFontStyleSA2AE
Loaded model
Expand source code
@classmethod
def load(cls, input_dir: str):
  """Loads a saved instance of this class
  
  Args:
      input_dir (str): Target input folder
  
  Returns:
      PureFontStyleSA2AE: Loaded model
  """
  full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
  image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
  decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
  prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))
  image_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_discriminator"))

  return cls(
    image_encoder = image_encoder, 
    full_encoder = full_encoder, 
    decoder = decoder, 
    prior_discriminator = prior_discriminator, 
    image_discriminator = image_discriminator)

Methods

def compile(self, optimizer='rmsprop', loss=None, metrics=None, loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs)

Configures the model for training.

Arguments

optimizer: String (name of optimizer) or optimizer instance. See tf.keras.optimizers.

loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature loss = fn(y_true, y_pred), where y_true = ground truth values with shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]. y_pred = predicted values with shape = [batch_size, d0, .. dN]. It returns a weighted loss float tensor. If a custom Loss instance is used and reduction is set to NONE, the return value has the shape [batch_size, d0, .. dN-1], i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.

metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), a function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}. You can also pass a list (len = len(outputs)) of lists of metrics such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings 'accuracy' or 'acc', we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the loss function used and the model output shape. We do a similar conversion for the strings 'crossentropy' and 'ce' as well.

loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model's outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

sample_weight_mode: If you need to do timestep-wise sample weighting (2D weights), set this to "temporal". None defaults to sample-wise weights (1D). If the model has multiple outputs, you can use a different sample_weight_mode on each output by passing a dictionary or a list of modes.

weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

**kwargs: Any additional arguments. For eager execution, pass run_eagerly=True.

Raises

ValueError
In case of invalid arguments for optimizer, loss, metrics or sample_weight_mode.
Expand source code
def compile(self,
  optimizer='rmsprop',
  loss=None,
  metrics=None,
  loss_weights=None,
  weighted_metrics=None,
  run_eagerly=None,
  **kwargs):

  self.full_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.image_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.image_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

  super().compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    loss_weights=loss_weights,
    weighted_metrics=weighted_metrics,
    run_eagerly=run_eagerly,
    **kwargs)
def discriminator_loss(self, real, fake)
Expand source code
def discriminator_loss(self,real,fake):
  real_loss = self.cross_entropy(tf.ones_like(real), real)
  fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
  return (real_loss + fake_loss)
def save(self, output_dir: str)

Save the model to an output folder

Args

output_dir : str
Target output folder
Expand source code
def save(self,output_dir: str):
  """Save the model to an output folder
  
  Args:
      output_dir (str): Target output folder
  """

  self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
  self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
  self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
  self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))
  self.image_discriminator.save(str(BytestreamPath(output_dir) / "image_discriminator"))
def scramble_font_batches(self, x: tensorflow.python.framework.ops.Tensor, labels: tensorflow.python.framework.ops.Tensor)

Creates a scrambled copy of a minibatch in which the characters within each font are randomly shuffled. Returns the original minibatch in addition to the shuffled version.

Args

x : tf.Tensor
Feature tensor
labels : tf.Tensor
Label tensor

Returns

t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
return scrambled and original feature-label pairs, in that order.
Expand source code
def scramble_font_batches(self, x: tf.Tensor, labels: tf.Tensor):
  """Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.
  
  Args:
      x (tf.Tensor): Feature tensor
      labels (tf.Tensor): Label tensor
  
  Returns:
      t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: return scrambled and original feature-label pairs, in that order.
  """
  n_fonts = tf.shape(x)[0]

  style_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  style_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  outcome_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  outcome_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  for k in range(n_fonts):
    x_k, labels_k = x[k], labels[k]
    shuffling_idx = tf.range(start=0, limit=tf.shape(x_k)[0], dtype=tf.int32)
    scrambled = tf.random.shuffle(shuffling_idx)

    style_x = style_x.write(k, tf.gather(x_k, scrambled, axis=0))
    style_y = style_y.write(k, tf.gather(labels_k, scrambled, axis=0))

    outcome_x = outcome_x.write(k, x_k)
    outcome_y = outcome_y.write(k, labels_k)

  return style_x.concat(), style_y.concat(), outcome_x.concat(), outcome_y.concat()
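A toy run to make the shuffling concrete; the shapes are illustrative only, with each x[k] holding one font's stack of character images, and model assumed to be a PureFontStyleSA2AE instance:

import tensorflow as tf

x = tf.random.uniform((2, 4, 8, 8, 1))  # 2 hypothetical fonts, 4 characters each
labels = tf.one_hot(tf.tile(tf.range(4)[None, :], [2, 1]), depth=4)

style_x, style_y, outcome_x, outcome_y = model.scramble_font_batches(x, labels)
# style_* hold each font's characters in shuffled order; outcome_* keep the
# original order; all four are concatenated over fonts, e.g. outcome_x is (8, 8, 8, 1)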
def train_step(self, inputs)

The logic for one training step.

This method can be overridden to support custom training logic. This method is called by Model.make_train_function.

This method should contain the mathematical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, and metric updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings), should be left to Model.make_train_function, which can also be overridden.

Arguments

data: A nested structure of Tensors.

Returns

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_train_batch_end. Typically, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

Expand source code
def train_step(self, inputs):
  x, labels = inputs

  style_x, style_y, outcome_x, outcome_y = self.scramble_font_batches(x, labels)

  n_examples = tf.shape(style_x)[0]
  with tf.GradientTape(persistent=True) as tape:

    # apply autoencoder
    image_precode = self.image_encoder(style_x, training=True)
    full_precode = tf.concat([image_precode, style_y], axis=-1)
    code = self.full_encoder(full_precode, training=True)
    extended_code = tf.concat([code,outcome_y],axis=-1)
    decoded = self.decoder(extended_code,training=True)  

    # apply prior_discriminator model
    prior_samples = self.prior_sampler(shape=(n_examples,code.shape[1]))
    real_prior = self.prior_discriminator(prior_samples,training=True)
    fake_prior = self.prior_discriminator(code,training=True)
    prior_classification_loss = self.discriminator_loss(real_prior,fake_prior)

    #apply image_discriminator model
    real_image = self.image_discriminator(outcome_x)
    fake_image = self.image_discriminator(decoded)
    image_classification_loss = self.discriminator_loss(real_image, fake_image)

    # compute losses for the models
    mixed_loss = -(prior_classification_loss + image_classification_loss)

  # Compute gradients
  prior_discriminator_gradients = tape.gradient(prior_classification_loss,self.prior_discriminator.trainable_variables)
  decoder_gradients = tape.gradient(image_classification_loss, self.decoder.trainable_variables)
  image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
  full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)
  image_discriminator_gradients = tape.gradient(image_classification_loss, self.image_discriminator.trainable_variables)

  #apply gradients
  self.prior_discriminator.optimizer.apply_gradients(zip(prior_discriminator_gradients,self.prior_discriminator.trainable_variables))
  self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
  self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
  self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))
  self.image_discriminator.optimizer.apply_gradients(zip(image_discriminator_gradients, self.image_discriminator.trainable_variables))

  # compute metrics
  discr_true = tf.concat([tf.ones_like(real_prior),tf.zeros_like(fake_prior)],axis=0)
  discr_predicted = tf.concat([real_prior,fake_prior],axis=0)
  self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

  discr_true = tf.concat([tf.ones_like(real_image),tf.zeros_like(fake_image)],axis=0)
  discr_predicted = tf.concat([real_image,fake_image],axis=0)
  self.image_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

  return {
    "image discriminator accuracy": self.image_accuracy_metric.result(),
    "prior discriminator accuracy": self.prior_accuracy_metric.result()}
class PureFontStyleSAAE (full_encoder: tensorflow.python.keras.engine.training.Model, image_encoder: tensorflow.python.keras.engine.training.Model, decoder: tensorflow.python.keras.engine.training.Model, prior_discriminator: tensorflow.python.keras.engine.training.Model, reconstruction_loss_weight: float = 0.5, prior_batch_size: int = 32, code_regularisation_weight=0)

This model is trained on all characters from one or more font files at a time; the aim is to encode the font's style rather than individual characters' styles. The latter can happen when training on scrambled characters from different fonts, and sometimes results in different-looking image styles for a single point in latent space. Accordingly, this model works with characters from a single typeface at a time, and uses the style encoded from one character to decode a different, randomly chosen character of the same font, given the encoded style and the target label.

Initialises encoder, decoder and discriminator sub-models and training hyper-parameters.

Args

decoder : tf.keras.Model
Decoder model that maps style and characters to images
prior_discriminator : tf.keras.Model
Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
full_encoder : tf.keras.Model
Encoder that takes high-level image features and labels to produce embedded representations
image_encoder : tf.keras.Model
Encoder for image features
reconstruction_loss_weight : float, optional
Weight of reconstruction loss at training time. Should be between 0 and 1.
prior_batch_size : int
Batch size from prior distribution at training time
code_regularisation_weight : float, optional
Weight of the moment-regularisation penalty that pushes the code's first two moments towards those of a standard normal
Expand source code
class PureFontStyleSAAE(tf.keras.Model):

  """This model is trained on all characters from one or more font files at a time; the aim is to encode the font's style rather than individual characters' styles. The latter can happen when training on scrambled characters from different fonts, and sometimes yields different-looking image styles for a single point in latent space. This model works with characters from a single typeface at a time, and uses the style encoded from a given character, together with the target label, to decode a different, randomly chosen character of the same font.
  """
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) #assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow hates good coding practices, apparently
  def __init__(
    self,
    full_encoder: tf.keras.Model,
    image_encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5,
    prior_batch_size:int=32,
    code_regularisation_weight=0):
    """Initialises encoder, decoder and discriminator sub-models and training hyper-parameters.
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps style and characters to images
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        full_encoder (tf.keras.Model): Encoder that takes high-level image features and labels to produce embedded representations
        image_encoder (tf.keras.Model): Encoder for image features
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
        prior_batch_size (int): Batch size from prior distribution at training time
        code_regularisation_weight (float, optional): Weight of the moment-regularisation penalty on the code's first two moments
    """
    super(PureFontStyleSAAE, self).__init__()

    self.full_encoder = full_encoder
    self.image_encoder = image_encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)
    self.prior_batch_size = prior_batch_size

    self.prior_sampler = tf.random.normal
    self.code_regularisation_weight = code_regularisation_weight
    # list of embedded models as instance attributes 
    self.model_list = ["full_encoder", "image_encoder", "decoder", "prior_discriminator"]


  def prior_discriminator_loss(self,real,fake):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)/(2*self.prior_batch_size)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    # each sub-model gets a deep copy of the optimizer, so each keeps its own
    # optimizer state; train_step applies gradients through these separately
    self.full_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.image_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    # inference: embed images with the image encoder only
    return self.image_encoder(x, training=training)

  def scramble_font_batches(self, x: tf.Tensor, labels: tf.Tensor):
    """Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.
    
    Args:
        x (tf.Tensor): Feature tensor
        labels (tf.Tensor): Label tensor
    
    Returns:
        t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: return scrambled and original feature-label pairs, in that order.
    """
    x_shape = tf.shape(x)
    n_fonts = x_shape[0]

    style_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    style_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    outcome_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    outcome_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    for k in range(n_fonts):
      x_k, labels_k = x[k], labels[k]
      # draw a random permutation of this font's characters
      shuffling_idx = tf.range(start=0, limit=tf.shape(x_k)[0], dtype=tf.int32)
      scrambled = tf.random.shuffle(shuffling_idx)

      # style inputs are the shuffled characters; outcome targets keep the original order
      style_x = style_x.write(k, tf.gather(x_k, scrambled, axis=0))
      style_y = style_y.write(k, tf.gather(labels_k, scrambled, axis=0))

      outcome_x = outcome_x.write(k, x_k)
      outcome_y = outcome_y.write(k, labels_k)

    return style_x.concat(), style_y.concat(), outcome_x.concat(), outcome_y.concat()

  def train_step(self, inputs):
    x, labels = inputs

    style_x, style_y, outcome_x, outcome_y = self.scramble_font_batches(x, labels)

    n_examples = tf.shape(style_x)[0]
    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      image_precode = self.image_encoder(style_x, training=True)
      full_precode = tf.concat([image_precode, style_y], axis=-1)
      code = self.full_encoder(full_precode, training=True)
      extended_code = tf.concat([code,outcome_y],axis=-1)
      decoded = self.decoder(extended_code,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(n_examples,code.shape[1]))
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(code,training=True)


      # Moment regularization for embedded representation to keep it closer to standard normal
      reg = self.code_regularisation_weight*(tf.reduce_mean(code)**2 + (tf.reduce_mean(code**2) - 1.0)**2)/2

      # compute losses for the models
      reconstruction_loss = tf.keras.losses.MSE(outcome_x,decoded) 
      classification_loss = self.prior_discriminator_loss(real,fake)

      mixed_loss = -(1-self.rec_loss_weight)*(classification_loss) + self.rec_loss_weight*(reconstruction_loss + reg)


    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
    full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
    self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(outcome_x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    return {
    "reconstruction MSE": self.mse_metric.result(), 
    "discriminator accuracy": self.prior_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
    self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight,
      "prior_batch_size": self.prior_batch_size,
      "code_regularisation_weight": self.code_regularisation_weight
    }

    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        PureFontStyleSAAE: Loaded model
    """
    full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
    image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    return cls(
      image_encoder = image_encoder, 
      full_encoder = full_encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      **d)

Ancestors

  • tensorflow.python.keras.engine.training.Model
  • tensorflow.python.keras.engine.network.Network
  • tensorflow.python.keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.training.tracking.tracking.AutoTrackable
  • tensorflow.python.training.tracking.base.Trackable
  • tensorflow.python.keras.utils.version_utils.LayerVersionSelector
  • tensorflow.python.keras.utils.version_utils.ModelVersionSelector

Class variables

var cross_entropy
var cross_entropy_metric
var mse_metric
var prior_accuracy_metric

Static methods

def load(input_dir: str)

Loads a saved instance of this class

Args

input_dir : str
Target input folder

Returns

PureFontStyleSAAE
Loaded model
Expand source code
@classmethod
def load(cls, input_dir: str):
  """Loads a saved instance of this class
  
  Args:
      input_dir (str): Target input folder
  
  Returns:
      PureFontStyleSAAE: Loaded model
  """
  full_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "full_encoder"))
  image_encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "image_encoder"))
  decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
  prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

  d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
  d = json.loads(d_string)

  return cls(
    image_encoder = image_encoder, 
    full_encoder = full_encoder, 
    decoder = decoder, 
    prior_discriminator = prior_discriminator, 
    **d)

Methods

def compile(self, optimizer='rmsprop', loss=None, metrics=None, loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs)

Configures the model for training.

Arguments

optimizer: String (name of optimizer) or optimizer instance. See tf.keras.optimizers.

loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature loss = fn(y_true, y_pred), where y_true = ground truth values with shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]. y_pred = predicted values with shape = [batch_size, d0, .. dN]. It returns a weighted loss float tensor. If a custom Loss instance is used and reduction is set to NONE, return value has the shape [batch_size, d0, .. dN-1] i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.

metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}. You can also pass a list (len = len(outputs)) of lists of metrics such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings 'accuracy' or 'acc', we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the loss function used and the model output shape. We do a similar conversion for the strings 'crossentropy' and 'ce' as well.

loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model's outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

sample_weight_mode: If you need to do timestep-wise sample weighting (2D weights), set this to "temporal". None defaults to sample-wise weights (1D). If the model has multiple outputs, you can use a different sample_weight_mode on each output by passing a dictionary or a list of modes.

weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

**kwargs: Any additional arguments. For eager execution, pass run_eagerly=True.

Raises

ValueError
In case of invalid arguments for optimizer, loss, metrics or sample_weight_mode.
Expand source code
def compile(self,
  optimizer='rmsprop',
  loss=None,
  metrics=None,
  loss_weights=None,
  weighted_metrics=None,
  run_eagerly=None,
  **kwargs):

  # each sub-model gets a deep copy of the optimizer, so each keeps its own
  # optimizer state; train_step applies gradients through these separately
  self.full_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.image_encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

  super().compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    loss_weights=loss_weights,
    weighted_metrics=weighted_metrics,
    run_eagerly=run_eagerly,
    **kwargs)
def prior_discriminator_loss(self, real, fake)
Expand source code
def prior_discriminator_loss(self,real,fake):
  real_loss = self.cross_entropy(tf.ones_like(real), real)
  fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
  return (real_loss + fake_loss)/(2*self.prior_batch_size)
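
In symbols, the loss above is the binary cross entropy of the discriminator D on prior samples (labelled 1) and codes (labelled 0), scaled by B = prior_batch_size:

$$\mathcal{L}_{D} = \frac{\mathrm{BCE}\big(\mathbf{1},\, D(z_{\text{prior}})\big) + \mathrm{BCE}\big(\mathbf{0},\, D(z_{\text{code}})\big)}{2B}$$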
def save(self, output_dir: str)

Save the model to an output folder

Args

output_dir : str
Target output folder
Expand source code
def save(self,output_dir: str):
  """Save the model to an output folder
  
  Args:
      output_dir (str): Target output folder
  """

  self.full_encoder.save(str(BytestreamPath(output_dir) / "full_encoder"))
  self.image_encoder.save(str(BytestreamPath(output_dir) / "image_encoder"))
  self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
  self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

  d = {
    "reconstruction_loss_weight":self.rec_loss_weight,
    "prior_batch_size": self.prior_batch_size,
    "code_regularisation_weight": self.code_regularisation_weight
  }

  (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())
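
A hypothetical roundtrip, assuming a trained instance `model` and a writable directory (the path is illustrative):

model.save("output/style_saae")                          # writes the four sub-models plus aae-params.json
restored = PureFontStyleSAAE.load("output/style_saae")   # rebuilds the model with the saved hyper-parameters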
def scramble_font_batches(self, x: tensorflow.python.framework.ops.Tensor, labels: tensorflow.python.framework.ops.Tensor)

Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.

Args

x : tf.Tensor
Feature tensor
labels : tf.Tensor
Label tensor

Returns

t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
return scrambled and original feature-label pairs, in that order.
Expand source code
def scramble_font_batches(self, x: tf.Tensor, labels: tf.Tensor):
  """Creates a scrambled copy of a minibatch in which individual fonts are randomly shuffled. Returns the original minibatch in addition to the shuffled version.
  
  Args:
      x (tf.Tensor): Feature tensor
      labels (tf.Tensor): Label tensor
  
  Returns:
      t.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: return scrambled and original feature-label pairs, in that order.
  """
  x_shape = tf.shape(x)
  n_fonts = x_shape[0]

  style_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  style_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  outcome_x = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  outcome_y = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  for k in range(n_fonts):
    x_k, labels_k = x[k], labels[k]
    # draw a random permutation of this font's characters
    shuffling_idx = tf.range(start=0, limit=tf.shape(x_k)[0], dtype=tf.int32)
    scrambled = tf.random.shuffle(shuffling_idx)

    # style inputs are the shuffled characters; outcome targets keep the original order
    style_x = style_x.write(k, tf.gather(x_k, scrambled, axis=0))
    style_y = style_y.write(k, tf.gather(labels_k, scrambled, axis=0))

    outcome_x = outcome_x.write(k, x_k)
    outcome_y = outcome_y.write(k, labels_k)

  return style_x.concat(), style_y.concat(), outcome_x.concat(), outcome_y.concat()
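
The per-font shuffle can be checked in isolation; a toy sketch (values are made up) mirroring the permutation logic above:

import tensorflow as tf

x_k = tf.constant([[1.], [2.], [3.]])                # one font's three "characters"
idx = tf.random.shuffle(tf.range(tf.shape(x_k)[0]))  # random permutation, as in the method
print(tf.gather(x_k, idx, axis=0))                   # same characters, new order

Because each font is shuffled independently and written back to its own slot, a style source and its decode target always come from the same typeface.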
def train_step(self, inputs)

The logic for one training step.

This method can be overridden to support custom training logic. This method is called by Model.make_train_function.

This method should contain the mathematical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, and metric updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings), should be left to Model.make_train_function, which can also be overridden.

Arguments

data: A nested structure of Tensors.

Returns

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_train_batch_end. Typically, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.
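
For this class specifically, the objective minimised by the two encoders in the source below, with $w$ = reconstruction_loss_weight, $\lambda$ = code_regularisation_weight and $z$ = code, is

$$\mathcal{L}_{\text{enc}} = -(1-w)\,\mathcal{L}_{\text{adv}} + w\left(\mathcal{L}_{\text{rec}} + \frac{\lambda}{2}\Big[\big(\mathbb{E}[z]\big)^{2} + \big(\mathbb{E}[z^{2}]-1\big)^{2}\Big]\right)$$

The negated adversarial term makes the encoders maximise the prior discriminator's loss, pushing the code distribution towards the prior, while the moment penalty nudges the code's mean towards 0 and its second moment towards 1.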

Expand source code
def train_step(self, inputs):
  x, labels = inputs

  style_x, style_y, outcome_x, outcome_y = self.scramble_font_batches(x, labels)

  n_examples = tf.shape(style_x)[0]
  with tf.GradientTape(persistent=True) as tape:

    # apply autoencoder
    image_precode = self.image_encoder(style_x, training=True)
    full_precode = tf.concat([image_precode, style_y], axis=-1)
    code = self.full_encoder(full_precode, training=True)
    extended_code = tf.concat([code,outcome_y],axis=-1)
    decoded = self.decoder(extended_code,training=True)  

    # apply prior_discriminator model
    prior_samples = self.prior_sampler(shape=(n_examples,code.shape[1]))
    real = self.prior_discriminator(prior_samples,training=True)
    fake = self.prior_discriminator(code,training=True)


    # Moment regularization for embedded representation to keep it closer to standard normal
    reg = self.code_regularisation_weight*(tf.reduce_mean(code)**2 + (tf.reduce_mean(code**2) - 1.0)**2)/2

    # compute losses for the models
    reconstruction_loss = tf.keras.losses.MSE(outcome_x,decoded) 
    classification_loss = self.prior_discriminator_loss(real,fake)

    mixed_loss = -(1-self.rec_loss_weight)*(classification_loss) + self.rec_loss_weight*(reconstruction_loss + reg)


  # Compute gradients
  discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
  decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
  image_encoder_gradients = tape.gradient(mixed_loss, self.image_encoder.trainable_variables)
  full_encoder_gradients = tape.gradient(mixed_loss, self.full_encoder.trainable_variables)

  #apply gradients
  self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
  self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
  self.image_encoder.optimizer.apply_gradients(zip(image_encoder_gradients,self.image_encoder.trainable_variables))
  self.full_encoder.optimizer.apply_gradients(zip(full_encoder_gradients,self.full_encoder.trainable_variables))

  # compute metrics
  self.mse_metric.update_state(outcome_x,decoded)

  discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
  discr_predicted = tf.concat([real,fake],axis=0)
  self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

  return {
  "reconstruction MSE": self.mse_metric.result(), 
  "discriminator accuracy": self.prior_accuracy_metric.result()}
class TensorFontStyleSAAE (encoder: tensorflow.python.keras.engine.training.Model, decoder: tensorflow.python.keras.engine.training.Model, prior_discriminator: tensorflow.python.keras.engine.training.Model, reconstruction_loss_weight: float = 0.5)

This model treats the characters in a font as channels of a single image: the encoder takes the whole font as an image with one channel per character.
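A minimal end-to-end sketch under stated assumptions (toy sizes, TF >= 2.2 for custom train_step support, eager execution for simplicity). Note that train_step transposes each batch from (fonts, chars, height, width) to (fonts, height, width, chars), so the encoder must accept one input channel per character and the decoder must produce images with the same channel layout:

import numpy as np
import tensorflow as tf
from fontai.prediction.models import TensorFontStyleSAAE

H = W = 8; N_CHARS = 3; CODE = 4  # illustrative sizes

encoder = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(H, W, N_CHARS)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(CODE)])

decoder = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(CODE,)),
  tf.keras.layers.Dense(H * W * N_CHARS, activation="sigmoid"),
  tf.keras.layers.Reshape((H, W, N_CHARS))])

prior_discriminator = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(CODE,)),
  tf.keras.layers.Dense(1, activation="sigmoid")])  # must output probabilities

model = TensorFontStyleSAAE(encoder=encoder, decoder=decoder,
                            prior_discriminator=prior_discriminator)

x = np.random.rand(4, N_CHARS, H, W).astype("float32")  # 4 toy fonts
dummy = np.zeros((4, 1), dtype="float32")               # labels are ignored by train_step
data = tf.data.Dataset.from_tensor_slices((x, dummy)).batch(2)

model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), run_eagerly=True)
model.fit(data, epochs=1)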

Initialises encoder, decoder and discriminator sub-models and training hyper-parameters.

Args

decoder : tf.keras.Model
Decoder model that maps the embedded font style back to a whole-font image
prior_discriminator : tf.keras.Model
Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
encoder : tf.keras.Model
Encoder that maps whole-font images to embedded representations
reconstruction_loss_weight : float, optional
Weight of reconstruction loss at training time. Should be between 0 and 1.
Expand source code
class TensorFontStyleSAAE(tf.keras.Model):

  """This model treats the characters in a font as channels of a single image: the encoder takes the whole font as an image with one channel per character.
  """
  prior_accuracy_metric = tf.keras.metrics.Accuracy(name="prior adversarial accuracy")
  mse_metric = tf.keras.metrics.MeanSquaredError(name="Reconstruction error")
  cross_entropy_metric = tf.keras.metrics.BinaryCrossentropy(name="Prior adversarial cross entropy", from_logits=False)

  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False) #assumes prior_discriminator outputs probabilities (i.e. in [0, 1])

  ## this __init__ has to be copy-pasted from the model above, because Tensorflow hates good coding practices, apparently
  def __init__(
    self,
    encoder: tf.keras.Model,
    decoder: tf.keras.Model,
    prior_discriminator: tf.keras.Model,
    reconstruction_loss_weight:float=0.5):
    """Initialises encoder, decoder and discriminator sub-models and training hyper-parameters.
    
    Args:
        decoder (tf.keras.Model): Decoder model that maps the embedded font style back to a whole-font image
        prior_discriminator (tf.keras.Model): Discriminator between the embeddings' distribution and the target distribution, e.g. multivariate standard normal.
        encoder (tf.keras.Model): Encoder that maps whole-font images to embedded representations
        reconstruction_loss_weight (float, optional): Weight of reconstruction loss at training time. Should be between 0 and 1.
    """
    super(TensorFontStyleSAAE, self).__init__()

    self.encoder = encoder
    self.decoder = decoder
    self.prior_discriminator = prior_discriminator
    self.rec_loss_weight = min(max(reconstruction_loss_weight,0),1)

    self.prior_sampler = tf.random.normal
    # list of embedded models as instance attributes 
    self.model_list = ["encoder", "decoder", "prior_discriminator"]


  def prior_discriminator_loss(self,real,fake, n_fonts):
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss)/tf.cast(n_fonts, tf.float32)

  def compile(self,
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    **kwargs):

    # each sub-model gets a deep copy of the optimizer, so each keeps its own
    # optimizer state; train_step applies gradients through these separately
    self.encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
    self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

    super().compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      weighted_metrics=weighted_metrics,
      run_eagerly=run_eagerly,
      **kwargs)

  def __call__(self, x, training=True, mask=None):
    # inference: embed whole-font images with the encoder only
    return self.encoder(x, training=training)

  def train_step(self, inputs):
    x, _ = inputs

    x_shape = tf.shape(x)
    n_fonts = x_shape[0]

    # move characters to the channel axis: (n_fonts, height, width, n_chars)
    x = tf.transpose(x, [0,2,3,1])

    with tf.GradientTape(persistent=True) as tape:

      # apply autoencoder
      embedding = self.encoder(x, training=True)
      decoded = self.decoder(embedding,training=True)  

      # apply prior_discriminator model
      prior_samples = self.prior_sampler(shape=(n_fonts,embedding.shape[1]))
      real = self.prior_discriminator(prior_samples,training=True)
      fake = self.prior_discriminator(embedding,training=True)

      # compute losses for the models
      reconstruction_loss = tf.keras.losses.MSE(x,decoded) 
      classification_loss = self.prior_discriminator_loss(real,fake, n_fonts)

      mixed_loss = -(1-self.rec_loss_weight)*classification_loss + self.rec_loss_weight*reconstruction_loss

    # Compute gradients
    discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
    decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
    encoder_gradients = tape.gradient(mixed_loss, self.encoder.trainable_variables)

    #apply gradients
    self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
    self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
    self.encoder.optimizer.apply_gradients(zip(encoder_gradients,self.encoder.trainable_variables))

    # compute metrics
    self.mse_metric.update_state(x,decoded)

    discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
    discr_predicted = tf.concat([real,fake],axis=0)
    self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

    return {
    "reconstruction MSE": self.mse_metric.result(), 
    "discriminator accuracy": self.prior_accuracy_metric.result()}



  def save(self,output_dir: str):
    """Save the model to an output folder
    
    Args:
        output_dir (str): Target output folder
    """

    self.encoder.save(str(BytestreamPath(output_dir) / "encoder"))
    self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
    self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

    d = {
      "reconstruction_loss_weight":self.rec_loss_weight
      }

    (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())

  @classmethod
  def load(cls, input_dir: str):
    """Loads a saved instance of this class
    
    Args:
        input_dir (str): Target input folder
    
    Returns:
        TensorFontStyleSAAE: Loaded model
    """
    encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "encoder"))
    decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
    prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

    d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
    d = json.loads(d_string)

    return cls(
      encoder = encoder, 
      decoder = decoder, 
      prior_discriminator = prior_discriminator, 
      **d)

Ancestors

  • tensorflow.python.keras.engine.training.Model
  • tensorflow.python.keras.engine.network.Network
  • tensorflow.python.keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.training.tracking.tracking.AutoTrackable
  • tensorflow.python.training.tracking.base.Trackable
  • tensorflow.python.keras.utils.version_utils.LayerVersionSelector
  • tensorflow.python.keras.utils.version_utils.ModelVersionSelector

Class variables

var cross_entropy
var cross_entropy_metric
var mse_metric
var prior_accuracy_metric

Static methods

def load(input_dir: str)

Loads a saved instance of this class

Args

input_dir : str
Target input folder

Returns

TensorFontStyleSAAE
Loaded model
Expand source code
@classmethod
def load(cls, input_dir: str):
  """Loads a saved instance of this class
  
  Args:
      input_dir (str): Target input folder
  
  Returns:
      TensorFontStyleSAAE: Loaded model
  """
  encoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "encoder"))
  decoder = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "decoder"))
  prior_discriminator = tf.keras.models.load_model(str(BytestreamPath(input_dir) / "prior_discriminator"))

  d_string = (BytestreamPath(input_dir) / "aae-params.json").read_bytes().decode("utf-8")
  d = json.loads(d_string)

  return cls(
    encoder = encoder, 
    decoder = decoder, 
    prior_discriminator = prior_discriminator, 
    **d)

Methods

def compile(self, optimizer='rmsprop', loss=None, metrics=None, loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs)

Configures the model for training.

Arguments

optimizer: String (name of optimizer) or optimizer instance. See tf.keras.optimizers.

loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature loss = fn(y_true, y_pred), where y_true = ground truth values with shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]. y_pred = predicted values with shape = [batch_size, d0, .. dN]. It returns a weighted loss float tensor. If a custom Loss instance is used and reduction is set to NONE, return value has the shape [batch_size, d0, .. dN-1] i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.

metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}. You can also pass a list (len = len(outputs)) of lists of metrics such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings 'accuracy' or 'acc', we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the loss function used and the model output shape. We do a similar conversion for the strings 'crossentropy' and 'ce' as well.

loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model's outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

sample_weight_mode: If you need to do timestep-wise sample weighting (2D weights), set this to "temporal". None defaults to sample-wise weights (1D). If the model has multiple outputs, you can use a different sample_weight_mode on each output by passing a dictionary or a list of modes.

weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

**kwargs: Any additional arguments. For eager execution, pass run_eagerly=True.

Raises

ValueError
In case of invalid arguments for optimizer, loss, metrics or sample_weight_mode.
Expand source code
def compile(self,
  optimizer='rmsprop',
  loss=None,
  metrics=None,
  loss_weights=None,
  weighted_metrics=None,
  run_eagerly=None,
  **kwargs):

  # each sub-model gets a deep copy of the optimizer, so each keeps its own
  # optimizer state; train_step applies gradients through these separately
  self.encoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.decoder.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)
  self.prior_discriminator.compile(optimizer = copy.deepcopy(optimizer),run_eagerly=run_eagerly)  

  super().compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    loss_weights=loss_weights,
    weighted_metrics=weighted_metrics,
    run_eagerly=run_eagerly,
    **kwargs)
def prior_discriminator_loss(self, real, fake, n_fonts)
Expand source code
def prior_discriminator_loss(self,real,fake, n_fonts):
  real_loss = self.cross_entropy(tf.ones_like(real), real)
  fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
  return (real_loss + fake_loss)/tf.cast(n_fonts, tf.float32)
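
Unlike PureFontStyleSAAE.prior_discriminator_loss, which scales by 2 * prior_batch_size, this variant scales the summed cross entropies by the number of fonts in the current batch:

$$\mathcal{L}_{D} = \frac{\mathrm{BCE}\big(\mathbf{1},\, D(z_{\text{prior}})\big) + \mathrm{BCE}\big(\mathbf{0},\, D(z_{\text{font}})\big)}{n_{\text{fonts}}}$$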
def save(self, output_dir: str)

Save the model to an output folder

Args

output_dir : str
Target output folder
Expand source code
def save(self,output_dir: str):
  """Save the model to an output folder
  
  Args:
      output_dir (str): Target output folder
  """

  self.encoder.save(str(BytestreamPath(output_dir) / "encoder"))
  self.decoder.save(str(BytestreamPath(output_dir) / "decoder"))
  self.prior_discriminator.save(str(BytestreamPath(output_dir) / "prior_discriminator"))

  d = {
    "reconstruction_loss_weight":self.rec_loss_weight
    }

  (BytestreamPath(output_dir) / "aae-params.json").write_bytes(json.dumps(d).encode())
def train_step(self, inputs)

The logic for one training step.

This method can be overridden to support custom training logic. This method is called by Model.make_train_function.

This method should contain the mathematical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, and metric updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings), should be left to Model.make_train_function, which can also be overridden.

Arguments

data: A nested structure of Tensors.

Returns

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_train_batch_end. Typically, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

Expand source code
def train_step(self, inputs):
  x, _ = inputs

  x_shape = tf.shape(x)
  n_fonts = x_shape[0]

  # move characters to the channel axis: (n_fonts, height, width, n_chars)
  x = tf.transpose(x, [0,2,3,1])

  with tf.GradientTape(persistent=True) as tape:

    # apply autoencoder
    embedding = self.encoder(x, training=True)
    decoded = self.decoder(embedding,training=True)  

    # apply prior_discriminator model
    prior_samples = self.prior_sampler(shape=(n_fonts,embedding.shape[1]))
    real = self.prior_discriminator(prior_samples,training=True)
    fake = self.prior_discriminator(embedding,training=True)

    # compute losses for the models
    reconstruction_loss = tf.keras.losses.MSE(x,decoded) 
    classification_loss = self.prior_discriminator_loss(real,fake, n_fonts)

    mixed_loss = -(1-self.rec_loss_weight)*classification_loss + self.rec_loss_weight*reconstruction_loss

  # Compute gradients
  discr_gradients = tape.gradient(classification_loss,self.prior_discriminator.trainable_variables)
  decoder_gradients = tape.gradient(reconstruction_loss, self.decoder.trainable_variables)
  encoder_gradients = tape.gradient(mixed_loss, self.encoder.trainable_variables)

  #apply gradients
  self.prior_discriminator.optimizer.apply_gradients(zip(discr_gradients,self.prior_discriminator.trainable_variables))
  self.decoder.optimizer.apply_gradients(zip(decoder_gradients,self.decoder.trainable_variables))
  self.encoder.optimizer.apply_gradients(zip(encoder_gradients,self.encoder.trainable_variables))

  # compute metrics
  self.mse_metric.update_state(x,decoded)

  discr_true = tf.concat([tf.ones_like(real),tf.zeros_like(fake)],axis=0)
  discr_predicted = tf.concat([real,fake],axis=0)
  self.prior_accuracy_metric.update_state(discr_true,tf.round(discr_predicted))

  return {
  "reconstruction MSE": self.mse_metric.result(), 
  "discriminator accuracy": self.prior_accuracy_metric.result()}