Skip to content

Commit

Permalink
[stdlib] Add __init__(*, from_bits) and to_bits() to SIMD
Browse files Browse the repository at this point in the history
Signed-off-by: Yiwu Chen <[email protected]>
  • Loading branch information
soraros committed Oct 16, 2024
1 parent 15c7d50 commit 59422e7
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 74 deletions.
79 changes: 34 additions & 45 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,20 @@ struct SIMD[type: DType, size: Int](
)
)

fn __init__[
    int_type: DType, //
](inout self, *, from_bits: SIMD[int_type, size]):
    """Constructs the SIMD vector by reinterpreting the bits of an integral
    SIMD vector.

    The input is bitcast to this vector's element type; no value conversion
    is performed.

    Constraints:
        `int_type` must be an integral DType.

    Parameters:
        int_type: The integral element type of the input SIMD vector.

    Args:
        from_bits: The integral SIMD vector whose bits are copied.
    """
    constrained[
        int_type.is_integral(),
        "the SIMD type must be integral",
    ]()
    self = bitcast[type, size](from_bits)

# ===-------------------------------------------------------------------===#
# Operator dunders
# ===-------------------------------------------------------------------===#
Expand Down Expand Up @@ -778,9 +792,7 @@ struct SIMD[type: DType, size: Int](
# As a workaround, we roll our own implementation
@parameter
if has_neon() and type is DType.bfloat16:
var int_self = bitcast[_integral_type_of[type](), size](self)
var int_rhs = bitcast[_integral_type_of[type](), size](rhs)
return int_self == int_rhs
return self.to_bits() == rhs.to_bits()
else:
return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred eq>`](
self.value, rhs.value
Expand All @@ -803,9 +815,7 @@ struct SIMD[type: DType, size: Int](
# As a workaround, we roll our own implementation.
@parameter
if has_neon() and type is DType.bfloat16:
var int_self = bitcast[_integral_type_of[type](), size](self)
var int_rhs = bitcast[_integral_type_of[type](), size](rhs)
return int_self != int_rhs
return self.to_bits() != rhs.to_bits()
else:
return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred ne>`](
self.value, rhs.value
Expand Down Expand Up @@ -1505,9 +1515,8 @@ struct SIMD[type: DType, size: Int](
self
)

alias integral_type = FPUtils[type].integral_type
var m = self._float_to_bits[integral_type]()
return (m & (FPUtils[type].sign_mask() - 1))._bits_to_float[type]()
alias mask = FPUtils[type].exponent_mantissa_mask()
return Self(from_bits=self.to_bits() & mask)
else:
return (self < 0).select(-self, self)

Expand Down Expand Up @@ -1711,32 +1720,21 @@ struct SIMD[type: DType, size: Int](
if size > 1:
writer.write_str("]")

# FIXME: `_integral_type_of` doesn't work with `DType.bool`.
@always_inline
fn _bits_to_float[dest_type: DType](self) -> SIMD[dest_type, size]:
"""Bitcasts the integer value to a floating-point value.
fn to_bits[
int_dtype: DType = _integral_type_of[type]()
](self) -> SIMD[int_dtype, size]:
"""Bitcasts the SIMD vector to an integer SIMD vector
of the same bitwidth.
Parameters:
dest_type: DType to bitcast the input SIMD vector to.
Returns:
A floating-point representation of the integer value.
"""
alias integral_type = FPUtils[type].integral_type
return bitcast[dest_type, size](self.cast[integral_type]())

@always_inline
fn _float_to_bits[dest_type: DType](self) -> SIMD[dest_type, size]:
"""Bitcasts the floating-point value to an integer value.
Parameters:
dest_type: DType to bitcast the input SIMD vector to.
int_dtype: The integer type to cast to.
Returns:
An integer representation of the floating-point value.
"""
alias integral_type = FPUtils[type].integral_type
var v = bitcast[integral_type, size](self)
return v.cast[dest_type]()
return bitcast[int_dtype, size](self)

fn _floor_ceil_trunc_impl[intrinsic: StringLiteral](self) -> Self:
constrained[
Expand Down Expand Up @@ -2937,18 +2935,13 @@ alias _fp32_bf16_mantissa_diff = FPUtils[


@always_inline
fn _bfloat16_to_f32_scalar(
val: Scalar[DType.bfloat16],
) -> Scalar[DType.float32]:
fn _bfloat16_to_f32_scalar(val: Scalar[DType.bfloat16]) -> Float32:
@parameter
if has_neon():
# TODO(KERN-228): support BF16 on neon systems.
return _unchecked_zero[DType.float32, 1]()

var bfloat_bits = FPUtils[DType.bfloat16].bitcast_to_integer(val)
return FPUtils[DType.float32].bitcast_from_integer(
bfloat_bits << _fp32_bf16_mantissa_diff
)
return Float32(from_bits=val.to_bits() << _fp32_bf16_mantissa_diff)


@always_inline
Expand All @@ -2973,9 +2966,7 @@ fn _bfloat16_to_f32[


@always_inline
fn _f32_to_bfloat16_scalar(
val: Scalar[DType.float32],
) -> Scalar[DType.bfloat16]:
fn _f32_to_bfloat16_scalar(val: Float32) -> Scalar[DType.bfloat16]:
@parameter
if has_neon():
# TODO(KERN-228): support BF16 on neon systems.
Expand All @@ -2986,15 +2977,14 @@ fn _f32_to_bfloat16_scalar(
val
) else _nan[DType.bfloat16]()

var float_bits = FPUtils[DType.float32].bitcast_to_integer(val)
var float_bits = val.to_bits()

var lsb = (float_bits >> _fp32_bf16_mantissa_diff) & 1
var rounding_bias = 0x7FFF + lsb
float_bits += rounding_bias

var bfloat_bits = float_bits >> _fp32_bf16_mantissa_diff

return FPUtils[DType.bfloat16].bitcast_from_integer(bfloat_bits)
return Scalar[DType.bfloat16](from_bits=bfloat_bits)


@always_inline
Expand Down Expand Up @@ -3166,18 +3156,17 @@ fn _floor(x: SIMD) -> __type_of(x):
if x.type.is_integral():
return x

alias integral_type = FPUtils[x.type].integral_type
alias bitwidth = bitwidthof[x.type]()
alias exponent_width = FPUtils[x.type].exponent_width()
alias mantissa_width = FPUtils[x.type].mantissa_width()
alias mask = (1 << exponent_width) - 1
alias mask = FPUtils[x.type].exponent_mask()
alias bias = FPUtils[x.type].exponent_bias()
alias shift_factor = bitwidth - exponent_width - 1

var bits = bitcast[integral_type, x.size](x)
bits = x.to_bits()
var e = ((bits >> mantissa_width) & mask) - bias
bits = (e < shift_factor).select(
bits & ~((1 << (shift_factor - e)) - 1),
bits,
)
return bitcast[x.type, x.size](bits)
return __type_of(x)(from_bits=bits)
2 changes: 1 addition & 1 deletion stdlib/src/hashlib/_ahash.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct AHasher[key: U256](_Hasher):

@parameter
if new_data.type.is_floating_point():
v64 = new_data._float_to_bits[DType.uint64]()
v64 = new_data.to_bits().cast[DType.uint64]()
else:
v64 = new_data.cast[DType.uint64]()

Expand Down
42 changes: 17 additions & 25 deletions stdlib/src/math/math.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,9 @@ fn exp2[
if type not in (DType.float32, DType.float64):
return exp2(x.cast[DType.float32]()).cast[type]()

alias integral_type = FPUtils[type].integral_type

var xc = x.clamp(-126, 126)

var m = xc.cast[integral_type]()
var m = xc.to_bits()

xc -= m.cast[type]()

Expand All @@ -438,10 +436,9 @@ fn exp2[
),
](xc)

return (
r._float_to_bits[integral_type]()
+ (m << FPUtils[type].mantissa_width())
)._bits_to_float[type]()
return __type_of(x)(
from_bits=r.to_bits() + (m << FPUtils[type].mantissa_width())
)


# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -504,12 +501,9 @@ fn _ldexp_impl[

return res

alias integral_type = FPUtils[type].integral_type
var m: SIMD[integral_type, simd_width] = (
exp.cast[integral_type]() + FPUtils[type].exponent_bias()
)
var m = exp.to_bits() + FPUtils[type].exponent_bias()

return x * (m << FPUtils[type].mantissa_width())._bits_to_float[type]()
return x * __type_of(x)(from_bits=(m << FPUtils[type].mantissa_width()))


@always_inline
Expand Down Expand Up @@ -630,8 +624,8 @@ fn exp[

@always_inline
fn _frexp_mask1[
simd_width: Int, type: DType, integral_type: DType
]() -> SIMD[integral_type, simd_width]:
simd_width: Int, type: DType
]() -> SIMD[_integral_type_of[type](), simd_width]:
@parameter
if type is DType.float16:
return 0x7C00
Expand All @@ -646,8 +640,8 @@ fn _frexp_mask1[

@always_inline
fn _frexp_mask2[
simd_width: Int, type: DType, integral_type: DType
]() -> SIMD[integral_type, simd_width]:
simd_width: Int, type: DType
]() -> SIMD[_integral_type_of[type](), simd_width]:
@parameter
if type is DType.float16:
return 0x3800
Expand Down Expand Up @@ -682,22 +676,20 @@ fn frexp[
"""
# Based on the implementation in boost/simd/arch/common/simd/function/ifrexp.hpp
constrained[type.is_floating_point(), "must be a floating point value"]()
alias integral_type = _integral_type_of[type]()
alias zero = SIMD[type, simd_width](0)
alias T = SIMD[type, simd_width]
alias zero = T(0)
alias max_exponent = FPUtils[type].max_exponent() - 2
alias mantissa_width = FPUtils[type].mantissa_width()
var mask1 = _frexp_mask1[simd_width, type, integral_type]()
var mask2 = _frexp_mask2[simd_width, type, integral_type]()
var x_int = x._float_to_bits[integral_type]()
var mask1 = _frexp_mask1[simd_width, type]()
var mask2 = _frexp_mask2[simd_width, type]()
var x_int = x.to_bits()
var selector = x != zero
var exp = selector.select(
(((mask1 & x_int) >> mantissa_width) - max_exponent).cast[type](),
zero,
)
var frac = selector.select(
((x_int & ~mask1) | mask2)._bits_to_float[type](), zero
)
return StaticTuple[SIMD[type, simd_width], 2](frac, exp)
var frac = selector.select(T(from_bits=x_int & ~mask1 | mask2), zero)
return StaticTuple[size=2](frac, exp)


# ===----------------------------------------------------------------------=== #
Expand Down
1 change: 0 additions & 1 deletion stdlib/src/memory/memory.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ from sys import (
_libc as libc,
)
from collections import Optional
from builtin.dtype import _integral_type_of
from memory.pointer import AddressSpace, _GPUAddressSpace

# ===----------------------------------------------------------------------=== #
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/utils/numerics.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ struct FPUtils[
fn exponent_mantissa_mask() -> Int:
"""Returns the exponent and mantissa mask of a floating point type.
It is computed by `exponent_mask + mantissa_mask`.
It is computed by `exponent_mask | mantissa_mask`.
Returns:
The exponent and mantissa mask.
"""
return Self.exponent_mask() + Self.mantissa_mask()
return Self.exponent_mask() | Self.mantissa_mask()

@staticmethod
@always_inline
Expand Down

0 comments on commit 59422e7

Please sign in to comment.