6.3. Simple Quantized network

6.3.1. Quantize network example code

from functools import partial

import tensorflow as tf

from blueoil.networks.classification.base import Base
from blueoil.layers import batch_norm, conv2d


class SampleNetwork(Base):
    """Sample network with simple layer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = lambda x: tf.nn.leaky_relu(x, alpha=0.1, name="leaky_relu")

    def base(self, images, is_training):
        assert self.data_format == "NHWC"
        channel_data_format = "channels_last"

        self.inputs = self.images = images

        with tf.compat.v1.variable_scope("block_1"):
            conv = conv2d("conv", self.inputs, filters=32, kernel_size=3,
                          activation=None, use_bias=False, data_format=channel_data_format,
                          kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
            batch_normed = batch_norm("bn", conv, is_training=is_training, decay=0.99, scale=True, center=True,
                                      data_format=self.data_format)
            self.block_1 = self.activation(batch_normed)

        self.block_last = conv2d("block_last", self.block_1, filters=self.num_classes, kernel_size=1,
                                 activation=None, use_bias=True, is_debug=self.is_debug,
                                 kernel_initializer=tf.random_normal_initializer(mean=0.0, stddev=0.01),
                                 data_format=channel_data_format)

        h = self.block_last.get_shape()[1].value
        w = self.block_last.get_shape()[2].value
        self.pool = tf.layers.average_pooling2d(name='global_average_pool', inputs=self.block_last,
                                                pool_size=[h, w], padding='VALID', strides=1,
                                                data_format=channel_data_format)
        self.base_output = tf.reshape(self.pool, [-1, self.num_classes], name="pool_reshape")

        return self.base_output


class SampleNetworkQuantize(SampleNetwork):
    """Quantize Sample Network."""

    def __init__(
            self,
            quantize_first_convolution=True,
            quantize_last_convolution=True,
            activation_quantizer=None,
            activation_quantizer_kwargs={},
            weight_quantizer=None,
            weight_quantizer_kwargs={},
            *args,
            **kwargs
    ):
        """
        Args:
            quantize_first_convolution(bool): use quantization in first conv.
            quantize_last_convolution(bool): use quantization in last conv.
            weight_quantizer (callable): weight quantizer.
            weight_quantizer_kwargs(dict): Initialize kwargs for weight quantizer.
            activation_quantizer (callable): activation quantizer
            activation_quantizer_kwargs(dict): Initialize kwargs for activation quantizer.
        """

        super().__init__(*args, **kwargs)

        self.quantize_first_convolution = quantize_first_convolution
        self.quantize_last_convolution = quantize_last_convolution

        assert callable(weight_quantizer)
        assert callable(activation_quantizer)

        self.weight_quantization = weight_quantizer(**weight_quantizer_kwargs)
        self.activation = activation_quantizer(**activation_quantizer_kwargs)

    @staticmethod
    def _quantized_variable_getter(
            weight_quantization,
            quantize_first_convolution,
            quantize_last_convolution,
            getter,
            name,
            *args,
            **kwargs):
        """Get the quantized variables.

        Use if to choose or skip the target should be quantized.

        Args:
            weight_quantization: Callable object which quantize variable.
            quantize_first_convolution(bool): Use quantization in first conv.
            quantize_last_convolution(bool): Use quantization in last conv.
            getter: Default from tensorflow.
            name: Default from tensorflow.
            args: Args.
            kwargs: Kwargs.
        """
        assert callable(weight_quantization)
        var = getter(name, *args, **kwargs)
        with tf.compat.v1.variable_scope(name):
            if "kernel" == var.op.name.split("/")[-1]:

                if not quantize_first_convolution:
                    if var.op.name.startswith("block_1/"):
                        return var

                if not quantize_last_convolution:
                    if var.op.name.startswith("block_last/"):
                        return var

                # Apply weight quantize to variable whose last word of name is "kernel".
                quantized_kernel = weight_quantization(var)
                tf.compat.v1.summary.histogram("quantized_kernel", quantized_kernel)
                return quantized_kernel

        return var

    def base(self, images, is_training):
        custom_getter = partial(
            self._quantized_variable_getter,
            self.weight_quantization,
            self.quantize_first_convolution,
            self.quantize_last_convolution,
        )
        with tf.compat.v1.variable_scope("", custom_getter=custom_getter):
            return super().base(images, is_training)

6.3.2. Quantize Network graph structure

../_images/quantize_example_minimal_graph_with_shape.png