Enforcing Type Safety in C++

Libraries and Custom Strong Types

Correction:

The examples here are all in microvolts but include multiplication and division, which is dimensionally incorrect: multiplying two voltages does not yield a voltage. I prefer to remove operator* and operator/ in favor of explicit functions for those calculations, which can be done with all of the referenced libraries.
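
As a rough sketch of what that could look like, the wrapper below has no operator* or operator/ and offers two free functions instead. The names Quantity, scale, and ratio are hypothetical and not taken from any of the libraries discussed in this post. Scaling by a plain number keeps the unit, while dividing two voltages returns a bare number because the units cancel.

```cpp
// Hypothetical sketch: a microvolt wrapper without operator* and operator/.
struct MicrovoltTag {};

template <typename tag_t, typename base_t>
class Quantity {
public:
    explicit constexpr Quantity(base_t value) : value_(value) {}

    // Addition and subtraction keep the unit, so they stay as operators.
    constexpr Quantity operator+(Quantity other) const { return Quantity(value_ + other.value_); }
    constexpr Quantity operator-(Quantity other) const { return Quantity(value_ - other.value_); }

    constexpr explicit operator base_t() const { return value_; }

private:
    base_t value_;
};

using microvolt_t = Quantity<MicrovoltTag, int>;

// Scaling by a unitless factor keeps the unit: uV * n is still uV.
constexpr microvolt_t scale(microvolt_t v, int factor) {
    return microvolt_t{static_cast<int>(v) * factor};
}

// Dividing two voltages cancels the unit, so the result is a plain double.
constexpr double ratio(microvolt_t num, microvolt_t den) {
    return static_cast<double>(static_cast<int>(num)) / static_cast<int>(den);
}

static_assert(static_cast<int>(scale(microvolt_t{3}, 4)) == 12);
static_assert(ratio(microvolt_t{1}, microvolt_t{4}) == 0.25);
```

With this shape, microvolt_t{3} * microvolt_t{4} simply does not compile, and the caller has to say which of the two operations they meant.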

Introduction

C++ is a statically typed language with relatively weak type safety. This can be convenient and definitely cuts down on the lines of code. Of course, so does removing all newlines and comments. Operations like multiplying floats and ints almost always behave as expected, allowing you to ignore the type differences (provided you turn off enough of your compiler warnings). Strong typing forces the author’s intent to be spelled out, improving readability while adding compile-time checks. One example I’ve been burned by before is changing the order of operations in an arithmetic statement, which silently introduced precision loss.
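
To make the order-of-operations problem concrete, here is a small sketch of the kind of reordering that bites. The 12-bit ADC and the 3300 mV reference are made up for illustration and are not from any real project. The two functions are the same formula on paper and differ only in evaluation order.

```cpp
// Hypothetical conversion from raw ADC counts to millivolts.
// Assumed for illustration: 12-bit ADC (4096 steps), 3300 mV reference.

// Divides first: counts / 4096 is int / int, so the fraction is truncated
// before the double multiply ever happens.
constexpr double to_millivolts_divide_first(int counts) {
    return counts / 4096 * 3300.0;
}

// Multiplies first: counts is promoted to double before the division,
// so nothing is lost.
constexpr double to_millivolts_multiply_first(int counts) {
    return counts * 3300.0 / 4096;
}

// Half scale should read 1650 mV; the divide-first version reads 0.
static_assert(to_millivolts_divide_first(2048) == 0.0);
static_assert(to_millivolts_multiply_first(2048) == 1650.0);
```

Both versions compile without complaint under default warning settings, which is exactly the problem.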

While disciplined use of compiler warnings and explicit casts can help, strong typing lets the compiler enforce these constraints for you.

Strong types also carry contextual meaning through your codebase. The classic example is units: if you’re working on a physical system you don’t want to multiply meters by volts and treat the result as liters. Your compiler and static analysis tools can help keep your units dimensionally correct in ways they can’t with generic types and weak typing.

One elegant example is the loved and hated Ada language. If Verilog is C then VHDL is Ada (sorry Ada). Ada is extremely strongly typed; all conversions must be explicitly defined. This can be verbose, but when you really, REALLY want to know what’s going on in your code base it’s hard to beat.

Recently, while updating our LC120 laser controller board, I introduced a bug by switching positional arguments, mixing amps and volts in a calculation. The tests caught it, but I would have preferred if the compiler had. Enter strong typing.
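
A minimal sketch of how that class of bug turns into a compile error. This is not the LC120 code: the Strong wrapper, the tag names, and the resistance functions are hypothetical stand-ins, and std::is_invocable_v assumes C++17.

```cpp
#include <type_traits>

struct VoltTag {};
struct AmpTag {};

// Hypothetical minimal wrapper: the tag makes each alias a distinct type.
template <typename tag_t, typename base_t>
struct Strong {
    explicit constexpr Strong(base_t v) : value(v) {}
    base_t value;
};

using volt_t = Strong<VoltTag, double>;
using amp_t  = Strong<AmpTag, double>;

// R = V / I with plain doubles: nothing ties an argument to its unit.
constexpr double resistance_weak(double volts, double amps) { return volts / amps; }

// The same calculation with the units in the signature.
constexpr double resistance_strong(volt_t v, amp_t i) { return v.value / i.value; }

// 8 V across 2 A is 4 ohms. Swapping the arguments still compiles
// and quietly returns 0.25.
static_assert(resistance_weak(8.0, 2.0) == 4.0);
static_assert(resistance_weak(2.0, 8.0) == 0.25);

// With strong types the correct order is callable and the swapped order is not.
static_assert(std::is_invocable_v<decltype(&resistance_strong), volt_t, amp_t>);
static_assert(!std::is_invocable_v<decltype(&resistance_strong), amp_t, volt_t>);
static_assert(resistance_strong(volt_t{8.0}, amp_t{2.0}) == 4.0);
```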

Goals & Use Case

The use case I have in mind is something like this:

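// Empty tag types, used only to make each alias a distinct type
struct MicrovoltTag {};
struct OhmsTag {};
struct WattTag {};

using microvolt_t = StrongArithmeticType<MicrovoltTag, int>;
using ohms_t = StrongArithmeticType<OhmsTag, double>;
using watt_t = StrongArithmeticType<WattTag, double>;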

const auto value = microvolt_t{3} + microvolt_t{6};  //  microvolt_t{9}

watt_t calculate_power(microvolt_t uv, ohms_t res) {
    /*
        P = V**2 / R
    */
    const double MICRO = 1e6;
    const double volts = static_cast<double>(static_cast<int>(uv))/MICRO;
    auto power = (volts*volts) / static_cast<double>(res);
    return watt_t{power};
}

microvolt_t voltage{500000};   // 500,000 µV = 0.5 V
ohms_t resistance{10.0};       // 10 Ω
watt_t power = calculate_power(voltage, resistance);

The goals are:

  • Wrap a primitive type and prevent implicit conversion to different types
  • Support arithmetic with the same type
  • Prevent implicit mixing of units
  • Allow explicit conversion with static_cast when the cast target matches the underlying primitive
  • Minimize abstraction cost
  • Types should be distinct even if the underlying type is the same

Custom StrongArithmeticType Definition

The following defines a StrongArithmeticType which is a template around a primitive type base_t with a tag type tag_t to set it apart from other types with the same base_t. The tag_t avoids the following issue:

using microvolt_t = StrongArithmeticType<int>;
using microamp_t = StrongArithmeticType<int>;

as this is equivalent to:

using microvolt_t = StrongArithmeticType<int>;
using microamp_t = microvolt_t;

The tag_t is only used to create a unique type signature.
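
The difference can be checked directly with type traits. The sketch below uses two hypothetical stand-in templates, Untagged and Tagged, rather than the real StrongArithmeticType, and assumes C++17 for the _v trait helpers.

```cpp
#include <type_traits>

// Stand-in for a wrapper without a tag parameter.
template <typename base_t>
struct Untagged { base_t value; };

// Stand-in for a wrapper with a tag parameter.
template <typename tag_t, typename base_t>
struct Tagged { base_t value; };

struct MicrovoltTag {};
struct MicroampTag {};

// Without a tag, both aliases name the very same type.
using untagged_microvolt_t = Untagged<int>;
using untagged_microamp_t  = Untagged<int>;
static_assert(std::is_same_v<untagged_microvolt_t, untagged_microamp_t>);

// With a tag, the aliases are unrelated types with no conversion between them.
using microvolt_t = Tagged<MicrovoltTag, int>;
using microamp_t  = Tagged<MicroampTag, int>;
static_assert(!std::is_same_v<microvolt_t, microamp_t>);
static_assert(!std::is_convertible_v<microamp_t, microvolt_t>);
```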

Using our own types instead of a library keeps the quantity of template magic to a minimum.

template<typename tag_t, typename base_t>
class StrongArithmeticType {
public:
    using type = base_t;
    StrongArithmeticType(const StrongArithmeticType& other) = default;
    explicit constexpr StrongArithmeticType(base_t value) : value_(value) {}

    // Arithmetic within the same type
    constexpr StrongArithmeticType operator+(const StrongArithmeticType& other) const { return StrongArithmeticType(value_ + other.value_); }
    constexpr StrongArithmeticType operator-(const StrongArithmeticType& other) const { return StrongArithmeticType(value_ - other.value_); }
    constexpr StrongArithmeticType operator*(const StrongArithmeticType& other) const { return StrongArithmeticType(value_ * other.value_); }
    constexpr StrongArithmeticType operator/(const StrongArithmeticType& other) const { return StrongArithmeticType(value_ / other.value_); }

    // Comparison
    constexpr bool operator==(const StrongArithmeticType& other) const { return value_ == other.value_; }
    constexpr bool operator!=(const StrongArithmeticType& other) const { return value_ != other.value_; }
    constexpr bool operator<(const StrongArithmeticType& other) const { return value_ < other.value_; }
    constexpr bool operator>(const StrongArithmeticType& other) const { return value_ > other.value_; }
    constexpr bool operator<=(const StrongArithmeticType& other) const { return value_ <= other.value_; }
    constexpr bool operator>=(const StrongArithmeticType& other) const { return value_ >= other.value_; }

    // Explicit conversion to underlying primitive type
    constexpr explicit operator base_t() const { return value_; }

protected:
    base_t value_;
};

struct MicrovoltTag {};
using microvolt_t = StrongArithmeticType<MicrovoltTag, int>;

This Godbolt example shows that we get a zero cost abstraction once -O1 is turned on, with the exception of func_capped:

microvolt_t func_capped(microvolt_t val)
{
    const microvolt_t max = microvolt_t{5000};
    if (val > max)
        return max;
    return val;
}

Changing to -O2 produces the same assembly as using the bare underlying type. Since there’s no penalty, we can use the type safe abstraction even in performance critical sections.
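
The assembly comparison is the real evidence, but a few structural properties can also be pinned down at compile time. The sketch below uses a cut-down, hypothetical Wrapper rather than the full StrongArithmeticType and assumes C++17. These checks do not prove the generated code is identical; they only confirm the wrapper has the size and triviality of a bare int, which is what typically lets the compiler treat it like one.

```cpp
#include <type_traits>

struct MicrovoltTag {};

// Cut-down stand-in for the wrapper: one member, no virtual functions.
template <typename tag_t, typename base_t>
class Wrapper {
public:
    explicit constexpr Wrapper(base_t value) : value_(value) {}
    constexpr bool operator>(const Wrapper& other) const { return value_ > other.value_; }
    constexpr explicit operator base_t() const { return value_; }

protected:
    base_t value_;
};

using microvolt_t = Wrapper<MicrovoltTag, int>;

// Same size as the primitive and trivially copyable, like a plain int.
static_assert(sizeof(microvolt_t) == sizeof(int));
static_assert(std::is_trivially_copyable_v<microvolt_t>);
static_assert(std::is_standard_layout_v<microvolt_t>);

// The capping function can even be evaluated entirely at compile time.
constexpr microvolt_t func_capped(microvolt_t val) {
    const microvolt_t max{5000};
    return (val > max) ? max : val;
}

static_assert(static_cast<int>(func_capped(microvolt_t{7000})) == 5000);
static_assert(static_cast<int>(func_capped(microvolt_t{1200})) == 1200);
```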

Libraries

There are several interesting libraries out there that are worth looking at.

doom/strong_types

We can use the strong_types library to create an equivalent type:

#include "st/st.hpp"

using base_t = double;
struct MicrovoltTag{};
using microvolt_t = st::type<base_t, MicrovoltTag, st::arithmetic, st::equality_comparable>;

Godbolt example

The emitted assembly is the same as before at -O2 and nearly identical at -O1.

rollbear/strong_type

rollbear/strong_type is just as easy to use as doom/strong_types and also doesn’t allow casting to the base type. I’d be interested to hear arguments about why casting to the base type shouldn’t be allowed by default or if it’s just an abundance of caution.

#include "strong_type/strong_type.hpp"

using base_t = double;
struct MicrovoltTag{};
using microvolt_t = strong::type<base_t, struct MicrovoltTag, strong::arithmetic, strong::equality, strong::ordered>;


// Shift division - microvolt_t
microvolt_t func_shift_division(microvolt_t val)
{
    return val / microvolt_t{4};
}

base_t cast(microvolt_t val) {  // Fails to compile
    return static_cast<base_t>(val);
}

Godbolt example

We again get the same results as before at both -O1 and -O2 which instills a bit of confidence.

foonathan/type_safe

foonathan/type_safe is especially interesting if you are mostly concerned with conversion between primitive types. It isn’t focused on unit like types but rather moving runtime errors to compile time. Worth checking out!

NamedTypes

This was the last library I got to. It’s not immediately obvious how to use it so I didn’t.

Appendix

To make the Godbolt examples I combined each of these libraries into gists using the following script. It searches for matching include statements and recursively follows each one to determine the include order. Once the include order is established, all of the headers are written to a single output file.

import os
import re
import logging
import click

log_ = logging.getLogger("merge_headers")


def parse_includes(file_path, regex):
    includes = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            match = regex.match(line)
            if match:
                includes.append(match.group(1))
    return includes


def process_header(file_path, include_dirs, seen, ordered, regex):
    abs_path = os.path.abspath(file_path)
    if abs_path in seen:
        log_.info(f"Skipping already processed: {file_path}")
        return

    seen.add(abs_path)
    log_.info(f"Processing: {file_path}")

    includes = parse_includes(file_path, regex=regex)
    search_dirs = [os.path.dirname(file_path)] + include_dirs

    for inc in includes:
        for directory in search_dirs:
            candidate = os.path.join(directory, inc)
            if os.path.exists(candidate):
                process_header(candidate, include_dirs, seen, ordered, regex)
                break
        else:
            log_.error(f"Include '{inc}' in '{file_path}' not found.")

    # After processing dependencies, append self to ordered list
    ordered.append(abs_path)


def get_include_regex(libname=None):
    if libname:
        return re.compile(fr'^\s*#include\s+<{re.escape(libname)}/([^>]+)>')
    else:
        return re.compile(r'^\s*#include\s+"([^"]+)"')


@click.command()
@click.argument('entry_point', type=click.Path(exists=True, readable=True))
@click.argument('output_path', type=click.Path(writable=True))
@click.argument('include_dirs', nargs=-1, type=click.Path(exists=True, file_okay=False))
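# libname is an option rather than a trailing positional, so the last
# include dir on the command line is not taken as the library name.
@click.option('--libname', default=None, help='Only follow includes of the form <LIBNAME/...>.')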
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose output.')
def main(entry_point, output_path, include_dirs, libname, verbose):
    """
    Merges C++ headers starting from ENTRY_POINT into OUTPUT_PATH.
    Resolves includes recursively with proper ordering.
    """

    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    seen = set()      # Fast cycle checking
    ordered = []      # Topologically sorted output order

    regex = get_include_regex(libname=libname)

    process_header(entry_point, list(include_dirs), seen, ordered, regex=regex)

    # Emit combined header in correct order
    with open(output_path, 'w', encoding='utf-8') as f_out:
        for file_path in ordered:
            with open(file_path, 'r', encoding='utf-8') as f_in:
                for line in f_in:
                    if not regex.match(line):
                        f_out.write(line.rstrip('\r\n') + '\n')

    click.echo(f"Combined header written to {output_path}")


if __name__ == '__main__':
    main()