Skip to content

Commit

Permalink
Implemented Adbc.Column.f16/2 (#83)
Browse files Browse the repository at this point in the history
* implemented `Adbc.Column.f16/2`

* updated module name in inline docs examples

* restore `storage_type`

* handle inf/-inf/nan properly

* fixed metadata

* test inf/-inf/nan

* test with postgres 14

* fix function specs: add `| nil`
  • Loading branch information
cocoa-xu authored May 18, 2024
1 parent c5731af commit c4c4ada
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:

services:
pg:
image: postgres:11
image: postgres:14
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
Expand Down
62 changes: 59 additions & 3 deletions c_src/adbc_arrow_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#pragma once

#include <stdio.h>
#include <cmath>
#include <cstdbool>
#include <cstdint>
#include <vector>
#include <adbc.h>
#include <erl_nif.h>
#include "adbc_half_float.hpp"

static int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector<ERL_NIF_TERM> &out_terms, ERL_NIF_TERM &value_type, ERL_NIF_TERM &metadata, ERL_NIF_TERM &error, bool *end_of_series = nullptr);
static int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, int64_t offset, int64_t count, int64_t level, std::vector<ERL_NIF_TERM> &out_terms, ERL_NIF_TERM &value_type, ERL_NIF_TERM &metadata, ERL_NIF_TERM &error, bool *end_of_series = nullptr);
Expand Down Expand Up @@ -503,7 +505,7 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
while (ArrowMetadataReaderRead(&metadata_reader, &key, &value) == NANOARROW_OK) {
// printf("key: %.*s, value: %.*s\n", (int)key.size_bytes, key.data, (int)value.size_bytes, value.data);
metadata_keys.push_back(erlang::nif::make_binary(env, key.data, (size_t)key.size_bytes));
metadata_values.push_back(erlang::nif::make_binary(env, value.data, (size_t)key.size_bytes));
metadata_values.push_back(erlang::nif::make_binary(env, value.data, (size_t)value.size_bytes));
}
if (metadata_keys.size() > 0) {
enif_make_map_from_arrays(env, metadata_keys.data(), metadata_values.data(), (unsigned)metadata_keys.size(), &arrow_metadata);
Expand Down Expand Up @@ -651,6 +653,36 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
(const value_type *)values->buffers[data_buffer_index],
enif_make_uint64
);
} else if (format[0] == 'e') {
// NANOARROW_TYPE_HALF_FLOAT
using value_type = uint16_t;
term_type = kAdbcColumnTypeF16;
if (count == -1) count = values->length;
if (values->n_buffers != 2) {
error = erlang::nif::error(env, "invalid n_buffers value for ArrowArray (format=e), values->n_buffers != 2");
return 1;
}
current_term = values_from_buffer(
env,
offset,
count,
(const uint8_t *)values->buffers[bitmap_buffer_index],
(const value_type *)values->buffers[data_buffer_index],
[](ErlNifEnv *env, const uint16_t u16) -> ERL_NIF_TERM {
float val = float16_to_float(u16);
if (std::isnan(val)) {
return kAtomNaN;
} else if (std::isinf(val)) {
if (val > 0) {
return kAtomInfinity;
} else {
return kAtomNegInfinity;
}
} else {
return enif_make_double(env, val);
}
}
);
} else if (format[0] == 'f') {
// NANOARROW_TYPE_FLOAT
using value_type = float;
Expand All @@ -666,7 +698,19 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
count,
(const uint8_t *)values->buffers[bitmap_buffer_index],
(const value_type *)values->buffers[data_buffer_index],
enif_make_double
[](ErlNifEnv *env, double val) -> ERL_NIF_TERM {
if (std::isnan(val)) {
return kAtomNaN;
} else if (std::isinf(val)) {
if (val > 0) {
return kAtomInfinity;
} else {
return kAtomNegInfinity;
}
} else {
return enif_make_double(env, val);
}
}
);
} else if (format[0] == 'g') {
// NANOARROW_TYPE_DOUBLE
Expand All @@ -683,7 +727,19 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
count,
(const uint8_t *)values->buffers[bitmap_buffer_index],
(const value_type *)values->buffers[data_buffer_index],
enif_make_double
[](ErlNifEnv *env, double val) -> ERL_NIF_TERM {
if (std::isnan(val)) {
return kAtomNaN;
} else if (std::isinf(val)) {
if (val > 0) {
return kAtomInfinity;
} else {
return kAtomNegInfinity;
}
} else {
return enif_make_double(env, val);
}
}
);
} else if (format[0] == 'b') {
// NANOARROW_TYPE_BOOL
Expand Down
38 changes: 37 additions & 1 deletion c_src/adbc_column.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <adbc.h>
#include <erl_nif.h>
#include "adbc_consts.h"
#include "adbc_half_float.hpp"
#include "nif_utils.hpp"

ERL_NIF_TERM make_adbc_column(ErlNifEnv *env, ERL_NIF_TERM name_term, ERL_NIF_TERM type_term, bool nullable, ERL_NIF_TERM metadata, ERL_NIF_TERM data) {
Expand Down Expand Up @@ -111,15 +112,48 @@ int get_list_float(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, const std::
if (!erlang::nif::get(env, head, &val)) {
if (nullable && enif_is_identical(head, kAtomNil)) {
callback(0, true);
} else if (enif_is_identical(head, kAtomInfinity)) {
callback(std::numeric_limits<double>::infinity(), false);
} else if (enif_is_identical(head, kAtomNegInfinity)) {
callback(-std::numeric_limits<double>::infinity(), false);
} else if (enif_is_identical(head, kAtomNaN)) {
callback(std::numeric_limits<double>::quiet_NaN(), false);
} else {
return 1;
}
} else {
callback(val, false);
}
callback(val, false);
}
return 0;
}

// Builds a half-precision float Arrow array from an Erlang list.
//
// Each list element is read through get_list_float(), converted to its
// 16-bit half-float bit pattern with float_to_float16(), and appended to
// `array_out`. A nil element (only reported when `nullable` is true)
// becomes exactly one null slot.
//
// Parameters:
//   env            - NIF environment used to read `list`.
//   list           - Erlang list holding the column data.
//   nullable       - whether nil entries are accepted.
//   nanoarrow_type - type set on `schema_out` (the visible caller passes
//                    NANOARROW_TYPE_HALF_FLOAT).
//   array_out      - array that receives the values.
//   schema_out     - schema that receives `nanoarrow_type`.
//   error_out      - forwarded to ArrowArrayInitFromSchema().
//
// Returns 0 on success; otherwise the first non-zero code reported by
// nanoarrow or by get_list_float().
int do_get_list_half_float(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, ArrowType nanoarrow_type, struct ArrowArray* array_out, struct ArrowSchema* schema_out, struct ArrowError* error_out) {
    NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema_out, nanoarrow_type));
    NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromSchema(array_out, schema_out, error_out));

    // The storage type is switched to UINT16 only while appending and is put
    // back on every exit path below.
    // NOTE(review): presumably needed so ArrowArrayAppendUInt() accepts the raw
    // 16-bit patterns for a half-float array - confirm against the bundled
    // nanoarrow version.
    struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array_out->private_data;
    const auto storage_type = private_data->storage_type;
    private_data->storage_type = NANOARROW_TYPE_UINT16;

    int ret = ArrowArrayStartAppending(array_out);
    if (ret == 0) {
        // The callback cannot return a value, so the first append failure is
        // remembered here and later elements are skipped.
        int append_status = 0;
        ret = get_list_float(env, list, nullable, [&array_out, &append_status](double val, bool is_nil) -> void {
            if (append_status != 0) {
                return;
            }
            if (is_nil) {
                // One null slot per nil entry; no placeholder value is
                // appended in front of it.
                append_status = ArrowArrayAppendNull(array_out, 1);
            } else {
                append_status = ArrowArrayAppendUInt(array_out, float_to_float16(static_cast<float>(val)));
            }
        });
        if (ret == 0) {
            ret = append_status;
        }
    }

    private_data->storage_type = storage_type;
    return ret;
}

int do_get_list_float(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, ArrowType nanoarrow_type, struct ArrowArray* array_out, struct ArrowSchema* schema_out, struct ArrowError* error_out) {
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema_out, nanoarrow_type));
NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromSchema(array_out, schema_out, error_out));
Expand Down Expand Up @@ -730,6 +764,8 @@ int adbc_column_to_adbc_field(ErlNifEnv *env, ERL_NIF_TERM adbc_buffer, struct A
ret = do_get_list_integer<uint32_t>(env, data_term, nullable, NANOARROW_TYPE_UINT32, array_out, schema_out, error_out);
} else if (enif_is_identical(type_term, kAdbcColumnTypeU64)) {
ret = do_get_list_integer<uint64_t>(env, data_term, nullable, NANOARROW_TYPE_UINT64, array_out, schema_out, error_out);
} else if (enif_is_identical(type_term, kAdbcColumnTypeF16)) {
ret = do_get_list_half_float(env, data_term, nullable, NANOARROW_TYPE_HALF_FLOAT, array_out, schema_out, error_out);
} else if (enif_is_identical(type_term, kAdbcColumnTypeF32)) {
ret = do_get_list_float(env, data_term, nullable, NANOARROW_TYPE_FLOAT, array_out, schema_out, error_out);
} else if (enif_is_identical(type_term, kAdbcColumnTypeF64)) {
Expand Down
4 changes: 4 additions & 0 deletions c_src/adbc_consts.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ static ERL_NIF_TERM kAtomAdbcError;
static ERL_NIF_TERM kAtomNil;
static ERL_NIF_TERM kAtomTrue;
static ERL_NIF_TERM kAtomFalse;
static ERL_NIF_TERM kAtomInfinity;
static ERL_NIF_TERM kAtomNegInfinity;
static ERL_NIF_TERM kAtomNaN;
static ERL_NIF_TERM kAtomEndOfSeries;
static ERL_NIF_TERM kAtomStructKey;
static ERL_NIF_TERM kAtomTime32;
Expand Down Expand Up @@ -51,6 +54,7 @@ static ERL_NIF_TERM kAdbcColumnTypeI8;
static ERL_NIF_TERM kAdbcColumnTypeI16;
static ERL_NIF_TERM kAdbcColumnTypeI32;
static ERL_NIF_TERM kAdbcColumnTypeI64;
static ERL_NIF_TERM kAdbcColumnTypeF16;
static ERL_NIF_TERM kAdbcColumnTypeF32;
static ERL_NIF_TERM kAdbcColumnTypeF64;
static ERL_NIF_TERM kAdbcColumnTypeStruct;
Expand Down
82 changes: 82 additions & 0 deletions c_src/adbc_half_float.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef ADBC_HALF_FLOAT_HPP
#pragma once

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Function to convert float16 (IEEE 754 half-precision) to float
// Converts an IEEE 754 half-precision bit pattern to a single-precision float.
//
// Zeros, subnormals, normal numbers, infinities and NaNs are all handled; the
// sign bit is always carried over and a NaN keeps its 10 payload bits (moved
// to the top of the float mantissa).
//
// Declared `inline` because this header is included from more than one other
// header; a plain definition here would be emitted once per translation unit.
//
// @param value raw 16-bit half-precision encoding
// @return the float with the same numeric value
inline float float16_to_float(uint16_t value) {
    static_assert(std::numeric_limits<float>::is_iec559, "IEEE 754 required");

    // All arithmetic is done in uint32_t so nothing is shifted into the sign
    // bit of a promoted signed int.
    const uint32_t sign = static_cast<uint32_t>(value >> 15) << 31;
    uint32_t exp = (value >> 10) & 0x1Fu;
    uint32_t mant = value & 0x3FFu;

    uint32_t fbits;
    if (exp == 0x1Fu) {
        // Infinity (mant == 0) or NaN (mant != 0)
        fbits = sign | 0x7F800000u | (mant << 13);
    } else if (exp != 0) {
        // Normalized number: re-bias the exponent from 15 to 127
        fbits = sign | ((exp + (127u - 15u)) << 23) | (mant << 13);
    } else if (mant == 0) {
        // Signed zero
        fbits = sign;
    } else {
        // Half subnormal: shift the mantissa up until the implicit leading
        // one (bit 10) appears, lowering the exponent once per shift.
        exp = 127u - 15u + 1u;
        while ((mant & 0x400u) == 0) {
            mant <<= 1;
            --exp;
        }
        fbits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }

    float result;
    std::memcpy(&result, &fbits, sizeof(result));
    return result;
}

// Function to convert float to float16 (IEEE 754 half-precision)
// Converts a single-precision float to an IEEE 754 half-precision bit pattern.
//
// Mantissa bits that do not fit are dropped (truncation, no rounding).
// Magnitudes too large for half precision become a signed infinity and
// magnitudes too small for a half subnormal become a signed zero.
// A NaN always stays a NaN, even when its payload lives only in the low
// mantissa bits that are dropped.
//
// Declared `inline` because this header is included from more than one other
// header; a plain definition here would be emitted once per translation unit.
//
// @param value float to convert
// @return raw 16-bit half-precision encoding
inline uint16_t float_to_float16(float value) {
    static_assert(std::numeric_limits<float>::is_iec559, "IEEE 754 required");

    uint32_t fbits;
    std::memcpy(&fbits, &value, sizeof(fbits));

    const uint32_t sign = (fbits >> 16) & 0x8000u;      // sign bit, already in half position
    const uint32_t biased_exp = (fbits >> 23) & 0xFFu;  // float exponent field
    uint32_t mant = fbits & 0x007FFFFFu;                // float mantissa

    if (biased_exp == 0xFFu) {
        if (mant == 0) {
            // Infinity
            return static_cast<uint16_t>(sign | 0x7C00u);
        }
        // NaN: keep the top 10 payload bits. If they are all zero the result
        // would read as infinity, so force the lowest mantissa bit on.
        uint32_t payload = mant >> 13;
        if (payload == 0) {
            payload = 0x0001u;
        }
        return static_cast<uint16_t>(sign | 0x7C00u | payload);
    }

    // Exponent re-biased from 127 to 15; may be zero or negative.
    const int exp = static_cast<int>(biased_exp) - 127 + 15;

    if (exp <= 0) {
        if (exp < -10) {
            // Too small even for a half subnormal: signed zero
            return static_cast<uint16_t>(sign);
        }
        // Half subnormal: make the implicit leading one explicit, then shift
        // it into place.
        mant = (mant | 0x00800000u) >> (1 - exp);
        return static_cast<uint16_t>(sign | (mant >> 13));
    }

    if (exp > 30) {
        // Overflow to infinity
        return static_cast<uint16_t>(sign | 0x7C00u);
    }

    // Normalized half-precision
    return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) | (mant >> 13));
}

#endif // ADBC_HALF_FLOAT_HPP
4 changes: 4 additions & 0 deletions c_src/adbc_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,9 @@ static int on_load(ErlNifEnv *env, void **, ERL_NIF_TERM) {
kAtomNil = erlang::nif::atom(env, "nil");
kAtomTrue = erlang::nif::atom(env, "true");
kAtomFalse = erlang::nif::atom(env, "false");
kAtomInfinity = erlang::nif::atom(env, "infinity");
kAtomNegInfinity = erlang::nif::atom(env, "neg_infinity");
kAtomNaN = erlang::nif::atom(env, "nan");
kAtomEndOfSeries = erlang::nif::atom(env, "end_of_series");
kAtomStructKey = erlang::nif::atom(env, "__struct__");
kAtomTime32 = erlang::nif::atom(env, "time32");
Expand Down Expand Up @@ -829,6 +832,7 @@ static int on_load(ErlNifEnv *env, void **, ERL_NIF_TERM) {
kAdbcColumnTypeI16 = erlang::nif::atom(env, "i16");
kAdbcColumnTypeI32 = erlang::nif::atom(env, "i32");
kAdbcColumnTypeI64 = erlang::nif::atom(env, "i64");
kAdbcColumnTypeF16 = erlang::nif::atom(env, "f16");
kAdbcColumnTypeF32 = erlang::nif::atom(env, "f32");
kAdbcColumnTypeF64 = erlang::nif::atom(env, "f64");
kAdbcColumnTypeStruct = erlang::nif::atom(env, "struct");
Expand Down
Loading

0 comments on commit c4c4ada

Please sign in to comment.