Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ void CodeGen_X86::visit(const Cast *op) {
if (target.has_feature(Target::F16C) &&
dst.code() == Type::Float &&
src.code() == Type::Float &&
(dst.bits() == 16 || src.bits() == 16)) {
(dst.bits() == 16 || src.bits() == 16) &&
src.bits() <= 32) { // Don't use for narrowing casts from double - it results in a libm call
// Node we use code() == Type::Float instead of is_float(), because we
// don't want to catch bfloat casts.

Expand Down
108 changes: 80 additions & 28 deletions src/EmulateFloat16Math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,44 @@ namespace Halide {
namespace Internal {

Expr bfloat16_to_float32(Expr e) {
const int lanes = e.type().lanes();
if (e.type().is_bfloat()) {
e = reinterpret(e.type().with_code(Type::UInt), e);
}
e = cast(UInt(32, e.type().lanes()), e);
e = cast(UInt(32, lanes), e);
e = e << 16;
e = reinterpret(Float(32, e.type().lanes()), e);
e = reinterpret(Float(32, lanes), e);
e = strict_float(e);
return e;
}

Expr float32_to_bfloat16(Expr e) {
internal_assert(e.type().bits() == 32);
Expr float_to_bfloat16(Expr e) {
const int lanes = e.type().lanes();
e = strict_float(e);
e = reinterpret(UInt(32, e.type().lanes()), e);
// We want to round ties to even, so before truncating either
// add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to
// even numbers.
e += 0x7fff + ((e >> 16) & 1);

Expr err;
// First round to float and record any gain of loss of magnitude
if (e.type().bits() == 64) {
Expr f = cast(Float(32, lanes), e);
err = abs(e) - abs(f);
e = f;
} else {
internal_assert(e.type().bits() == 32);
}
e = reinterpret(UInt(32, lanes), e);

// We want to round ties to even, so if we have no error recorded above,
// before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff
// (0.499999) to even numbers. If we have error, break ties using that
// instead.
Expr tie_breaker = (e >> 16) & 1; // 1 when rounding down would go to odd
if (err.defined()) {
tie_breaker = ((err == 0) & tie_breaker) | (err > 0);
}
e += tie_breaker + 0x7fff;
e = (e >> 16);
e = cast(UInt(16, e.type().lanes()), e);
e = reinterpret(BFloat(16, e.type().lanes()), e);
e = cast(UInt(16, lanes), e);
e = reinterpret(BFloat(16, lanes), e);
return e;
}

Expand Down Expand Up @@ -63,51 +80,75 @@ Expr float16_to_float32(Expr value) {
return f32;
}

Expr float32_to_float16(Expr value) {
Expr float_to_float16(Expr value) {
// We're about the sniff the bits of a float, so we should
// guard it with strict float to ensure we don't do things
// like assume it can't be denormal.
value = strict_float(value);

Type f32_t = Float(32, value.type().lanes());
const int src_bits = value.type().bits();

Type float_t = Float(src_bits, value.type().lanes());
Type f16_t = Float(16, value.type().lanes());
Type u32_t = UInt(32, value.type().lanes());
Type bits_t = UInt(src_bits, value.type().lanes());
Type u16_t = UInt(16, value.type().lanes());

Expr bits = reinterpret(u32_t, value);
Expr bits = reinterpret(bits_t, value);

// Extract the sign bit
Expr sign = bits & make_const(u32_t, 0x80000000);
Expr sign = bits & make_const(bits_t, (uint64_t)1 << (src_bits - 1));
bits = bits ^ sign;

// Test the endpoints
Expr is_denorm = (bits < make_const(u32_t, 0x38800000));
Expr is_inf = (bits >= make_const(u32_t, 0x47800000));
Expr is_nan = (bits > make_const(u32_t, 0x7f800000));

// Smallest input representable as normal float16 (2^-14)
Expr two_to_the_minus_14 = src_bits == 32 ?
make_const(bits_t, 0x38800000) :
make_const(bits_t, (uint64_t)0x3f10000000000000ULL);
Expr is_denorm = bits < two_to_the_minus_14;

// Smallest input too big to represent as a float16 (2^16)
Expr two_to_the_16 = src_bits == 32 ?
make_const(bits_t, 0x47800000) :
make_const(bits_t, (uint64_t)0x40f0000000000000ULL);
Expr is_inf = bits >= two_to_the_16;

// Check if the input is a nan, which is anything bigger than an infinity bit pattern
Expr input_inf_bits = src_bits == 32 ?
make_const(bits_t, 0x7f800000) :
make_const(bits_t, (uint64_t)0x7ff0000000000000ULL);
Expr is_nan = bits > input_inf_bits;

// Denorms are linearly spaced, so we can handle them
// by scaling up the input as a float and using the
// existing int-conversion rounding instructions.
Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f32_t, bits + 0x0c000000)))));
Expr two_to_the_24 = src_bits == 32 ?
make_const(bits_t, 0x0c000000) :
make_const(bits_t, (uint64_t)0x0180000000000000ULL);
Expr denorm_bits = cast(u16_t, strict_float(round(reinterpret(float_t, bits + two_to_the_24))));
Expr inf_bits = make_const(u16_t, 0x7c00);
Expr nan_bits = make_const(u16_t, 0x7fff);

// We want to round to nearest even, so we add either
// 0.5 if the integer part is odd, or 0.4999999 if the
// integer part is even, then truncate.
bits += (bits >> 13) & 1;
bits += 0xfff;
bits = bits >> 13;
const int float16_mantissa_bits = 10;
const int input_mantissa_bits = src_bits == 32 ? 23 : 52;
const int bits_lost = input_mantissa_bits - float16_mantissa_bits;
bits += (bits >> bits_lost) & 1;
bits += make_const(bits_t, ((uint64_t)1 << (bits_lost - 1)) - 1);
bits = cast(u16_t, bits >> bits_lost);

// Rebias the exponent
bits -= 0x1c000;
bits -= 0x4000;
// Truncate the top bits of the exponent
bits = bits & 0x7fff;
bits = select(is_denorm, denorm_bits,
is_inf, inf_bits,
is_nan, nan_bits,
cast(u16_t, bits));
// Recover the sign bit
bits = bits | cast(u16_t, sign >> 16);
bits = bits | cast(u16_t, sign >> (src_bits - 16));
return common_subexpression_elimination(reinterpret(f16_t, bits));
}

Expand Down Expand Up @@ -157,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) {
Expr e = Call::make(t, it->second, new_args, op->call_type,
op->func, op->value_index, op->image, op->param);
if (op->type.is_float()) {
e = float32_to_float16(e);
e = float_to_float16(e);
}
internal_assert(e.type() == op->type);
return e;
Expand All @@ -171,6 +212,7 @@ Expr lower_float16_cast(const Cast *op) {
Type src = op->value.type();
Type dst = op->type;
Type f32 = Float(32, dst.lanes());
Type f64 = Float(64, dst.lanes());
Expr val = op->value;

if (src.is_bfloat()) {
Expand All @@ -183,10 +225,20 @@ Expr lower_float16_cast(const Cast *op) {

if (dst.is_bfloat()) {
internal_assert(dst.bits() == 16);
val = float32_to_bfloat16(cast(f32, val));
if (src.bits() > 32) {
val = cast(f64, val);
} else {
val = cast(f32, val);
}
val = float_to_bfloat16(val);
} else if (dst.is_float() && dst.bits() < 32) {
internal_assert(dst.bits() == 16);
val = float32_to_float16(cast(f32, val));
if (src.bits() > 32) {
val = cast(f64, val);
} else {
val = cast(f32, val);
}
val = float_to_float16(val);
}

return cast(dst, val);
Expand Down
4 changes: 2 additions & 2 deletions src/EmulateFloat16Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *);

/** Cast to/from float and bfloat using bitwise math. */
//@{
Expr float32_to_bfloat16(Expr e);
Expr float32_to_float16(Expr e);
Expr float_to_bfloat16(Expr e);
Expr float_to_float16(Expr e);
Expr float16_to_float32(Expr e);
Expr bfloat16_to_float32(Expr e);
Expr lower_float16_cast(const Cast *op);
Expand Down
46 changes: 40 additions & 6 deletions src/Float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ namespace Internal {

// Conversion routines to and from float cribbed from Christian Rau's
// half library (half.sourceforge.net)
uint16_t float_to_float16(float value) {
template<typename T>
uint16_t float_to_float16(T value) {
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"float_to_float16 only supports float and double types");
// Start by copying over the sign bit
uint16_t bits = std::signbit(value) << 15;

Expand Down Expand Up @@ -40,14 +43,14 @@ uint16_t float_to_float16(float value) {

// We've normalized value as much as possible. Put the integer
// portion of it into the mantissa.
float ival;
float frac = std::modf(value, &ival);
T ival;
T frac = std::modf(value, &ival);
bits += (uint16_t)(std::abs((int)ival));

// Now consider the fractional part. We round to nearest with ties
// going to even.
frac = std::abs(frac);
bits += (frac > 0.5f) | ((frac == 0.5f) & bits);
bits += (frac > T(0.5)) | ((frac == T(0.5)) & bits);

return bits;
}
Expand Down Expand Up @@ -341,6 +344,19 @@ uint16_t float_to_bfloat16(float f) {
return ret >> 16;
}

uint16_t float_to_bfloat16(double f) {
// Coming from double is a little tricker. We first narrow to float and
// record if any magnitude was lost or gained in the process. If so we'll
// use that to break ties instead of testing whether or not truncation would
// return odd.
float f32 = (float)f;
const double err = std::abs(f) - (double)std::abs(f32);
uint32_t ret;
memcpy(&ret, &f32, sizeof(float));
ret += 0x7fff + (((err >= 0) & ((ret >> 16) & 1)) | (err > 0));
return ret >> 16;
}

float bfloat16_to_float(uint16_t b) {
// Assume little-endian floats
uint16_t bits[2] = {0, b};
Expand All @@ -362,7 +378,17 @@ float16_t::float16_t(double value)
}

float16_t::float16_t(int value)
: data(float_to_float16(value)) {
: data(float_to_float16((float)value)) {
// integers of any size that map to finite float16s are all representable as
// float, so we can go via the float conversion method.
}

float16_t::float16_t(int64_t value)
: data(float_to_float16((float)value)) {
}

float16_t::float16_t(uint64_t value)
: data(float_to_float16((float)value)) {
}

float16_t::operator float() const {
Expand Down Expand Up @@ -464,7 +490,15 @@ bfloat16_t::bfloat16_t(double value)
}

bfloat16_t::bfloat16_t(int value)
: data(float_to_bfloat16(value)) {
: data(float_to_bfloat16((double)value)) {
}

bfloat16_t::bfloat16_t(int64_t value)
: data(float_to_bfloat16((double)value)) {
}

bfloat16_t::bfloat16_t(uint64_t value)
: data(float_to_bfloat16((double)value)) {
}

bfloat16_t::operator float() const {
Expand Down
4 changes: 4 additions & 0 deletions src/Float16.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct float16_t {
explicit float16_t(float value);
explicit float16_t(double value);
explicit float16_t(int value);
explicit float16_t(int64_t value);
explicit float16_t(uint64_t value);
// @}

/** Construct a float16_t with the bits initialised to 0. This represents
Expand Down Expand Up @@ -175,6 +177,8 @@ struct bfloat16_t {
explicit bfloat16_t(float value);
explicit bfloat16_t(double value);
explicit bfloat16_t(int value);
explicit bfloat16_t(int64_t value);
explicit bfloat16_t(uint64_t value);
// @}

/** Construct a bfloat16_t with the bits initialised to 0. This represents
Expand Down
1 change: 1 addition & 0 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ const char *const intrinsic_op_names[] = {
"sliding_window_marker",
"sorted_avg",
"strict_add",
"strict_cast",
"strict_div",
"strict_eq",
"strict_le",
Expand Down
2 changes: 2 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ struct Call : public ExprNode<Call> {
// them as reals and ignoring the existence of nan and inf. Using these
// intrinsics instead prevents any such optimizations.
strict_add,
strict_cast,
strict_div,
strict_eq,
strict_le,
Expand Down Expand Up @@ -792,6 +793,7 @@ struct Call : public ExprNode<Call> {
bool is_strict_float_intrinsic() const {
return is_intrinsic(
{Call::strict_add,
Call::strict_cast,
Call::strict_div,
Call::strict_max,
Call::strict_min,
Expand Down
12 changes: 12 additions & 0 deletions src/StrictifyFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ class Strictify : public IRMutator {
return IRMutator::visit(op);
}
}

Expr visit(const Cast *op) override {
if (op->value.type().is_float() &&
op->type.is_float()) {
return Call::make(op->type, Call::strict_cast,
{mutate(op->value)}, Call::PureIntrinsic);
} else {
return IRMutator::visit(op);
}
}
};

const std::set<std::string> strict_externs = {
Expand Down Expand Up @@ -142,6 +152,8 @@ Expr unstrictify_float(const Call *op) {
return op->args[0] <= op->args[1];
} else if (op->is_intrinsic(Call::strict_eq)) {
return op->args[0] == op->args[1];
} else if (op->is_intrinsic(Call::strict_cast)) {
return cast(op->type, op->args[0]);
} else {
internal_error << "Missing lowering of strict float intrinsic: "
<< Expr(op) << "\n";
Expand Down
Loading
Loading