Skip to content

Commit

Permalink
hl: Refactor utilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
xlauko committed Mar 7, 2024
1 parent 6a91537 commit 0744287
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 47 deletions.
4 changes: 2 additions & 2 deletions include/vast/Dialect/HighLevel/HighLevelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ class RecordLikeDeclOp< string mnemonic, string concrete_name, list< Trait > tra
let extraClassDefinition = [{
// AggregateTypeDefinitionInterface

gap::generator<mlir::Type> }] # concrete_name # [{ ::getFieldTypes() {
gap::generator< mlir_type > }] # concrete_name # [{ ::getFieldTypes() {
return hl::get_field_types(*this);
}

gap::generator<std::tuple<std::string, mlir::Type>>}] # concrete_name # [{ ::getFieldsInfo() {
gap::generator< vast::field_info_t >}] # concrete_name # [{ ::getFieldsInfo() {
return hl::get_fields_info(*this);
}

Expand Down
90 changes: 48 additions & 42 deletions include/vast/Dialect/HighLevel/HighLevelUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@

#include <gap/core/generator.hpp>

#include <ranges>

/* Contains common utilities often needed to work with hl dialect. */

namespace vast::hl {

using aggregate_interface = AggregateTypeDefinitionInterface;

static inline gap::generator< mlir_type > get_field_types(auto op) {
for (auto [_, type] : get_fields_info(op)) {
for (auto &&[_, type] : get_fields_info(op)) {
co_yield type;
}
}

static inline gap::generator< std::tuple< std::string, mlir_type > > get_fields_info(auto op
) {
gap::generator< field_info_t > get_fields_info(auto op) {
for (auto &maybe_field : op.getOps()) {
// Definition of nested structure, we ignore not a field.
if (mlir::isa< aggregate_interface >(maybe_field)) {
Expand All @@ -36,12 +37,11 @@ namespace vast::hl {

auto field_decl = mlir::dyn_cast< hl::FieldDeclOp >(maybe_field);
VAST_ASSERT(field_decl);
co_yield std::make_tuple(field_decl.getName().str(), field_decl.getType());
co_yield { field_decl.getName().str(), field_decl.getType() };
}
}

static inline gap::generator< aggregate_interface >
get_nested_declarations(auto op) {
gap::generator< aggregate_interface > get_nested_declarations(auto op) {
for (auto &maybe_field : op.getOps()) {
if (auto casted = mlir::dyn_cast< aggregate_interface >(maybe_field)) {
co_yield casted;
Expand All @@ -52,78 +52,84 @@ namespace vast::hl {
// TODO(hl): This is a placeholder that works in our test cases so far.
// In general, we will need generic resolution for scoping that
// will be used instead of this function.
static inline auto definition_of(mlir::Type t, vast_module module_op)
-> aggregate_interface
{
auto type_name = hl::name_of_record(t);
VAST_CHECK(type_name, "hl::name_of_record failed with {0}", t);
aggregate_interface definition_of(mlir_type ty, auto scope) {
auto type_name = hl::name_of_record(ty);
VAST_CHECK(type_name, "hl::name_of_record failed with {0}", ty);

aggregate_interface out;
;
auto walker = [&](aggregate_interface op) {
if (op.getDefinedName() == type_name) {
out = op;
return mlir::WalkResult::interrupt();
return walk_result::interrupt();
}
return mlir::WalkResult::advance();
};
module_op->walk(walker);
scope->walk(walker);
return out;
}

static inline auto field_types(mlir::Type t, vast_module module_op)
-> gap::generator< mlir_type > {
auto def = definition_of(t, module_op);
VAST_CHECK(def, "Was not able to fetch definition of type: {0}", t);
gap::generator< mlir_type > field_types(mlir_type ty, auto scope) {
auto def = definition_of(ty, scope);
VAST_CHECK(def, "Was not able to fetch definition of type: {0}", ty);
return def.getFieldTypes();
}

static inline hl::ImplicitCastOp
implicit_cast_lvalue_to_rvalue(auto &rewriter, auto loc, auto lvalue_op) {
auto lvalue_type = mlir::dyn_cast< hl::LValueType >(lvalue_op.getType());
VAST_ASSERT(lvalue_type);
hl::ImplicitCastOp implicit_cast_lvalue_to_rvalue(auto &rewriter, auto loc, auto lvalue_op) {
auto value_type = mlir::dyn_cast< hl::LValueType >(lvalue_op.getType());
VAST_ASSERT(value_type);
return rewriter.template create< hl::ImplicitCastOp >(
loc, lvalue_type.getElementType(), lvalue_op, hl::CastKind::LValueToRValue
loc, value_type.getElementType(), lvalue_op, hl::CastKind::LValueToRValue
);
}

// Given record `root` emit `hl::RecordMemberOp` for each its member.
static inline auto generate_ptrs_to_record_members(operation root, auto loc, auto &bld)
-> gap::generator< hl::RecordMemberOp > {
auto module_op = root->getParentOfType< vast_module >();
VAST_ASSERT(module_op);
auto def = definition_of(root->getResultTypes()[0], module_op);
auto generate_ptrs_to_record_members(operation root, auto loc, auto &bld)
-> gap::generator< hl::RecordMemberOp >
{
auto scope = root->getParentOfType< vast_module >();
VAST_ASSERT(scope);
VAST_ASSERT(root->getNumResults() == 1);
auto def = definition_of(root->getResultTypes()[0], scope);
VAST_CHECK(def, "Was not able to fetch definition of type from: {0}", *root);

for (const auto &[name, type] : def.getFieldsInfo()) {
VAST_ASSERT(root->getNumResults() == 1);
auto as_val = root->getResult(0);
auto as_val = root->getResult(0);
// `hl.member` requires type to be an lvalue.
auto wrap_type = hl::LValueType::get(module_op.getContext(), type);
auto wrap_type = hl::LValueType::get(scope.getContext(), type);
co_yield bld.template create< hl::RecordMemberOp >(loc, wrap_type, as_val, name);
}
}

// Given record `root` emit `hl::RecordMemberOp` casted as rvalue for each
// its member.
static inline auto generate_values_of_record_members(operation root, auto &bld)
-> gap::generator< hl::ImplicitCastOp > {
auto generate_values_of_record_members(operation root, auto &bld)
-> gap::generator< hl::ImplicitCastOp >
{
for (auto member_ptr : generate_ptrs_to_members(root, bld)) {
co_yield implicit_cast_lvalue_to_rvalue(bld, member_ptr->getLoc(), member_ptr);
}
}

static inline std::optional< std::size_t >
field_idx(llvm::StringRef name, aggregate_interface agg) {
std::size_t idx = 0;
// `llvm::enumerate` is unhappy when coroutine is passed in.
for (const auto &[field_name, _] : agg.getFieldsInfo()) {
if (field_name == name) {
return { idx };
namespace detail
{
template <typename T>
std::vector<T> to_vector(gap::generator<T> &&gen) {
std::vector<T> result;
std::ranges::copy(gen, std::back_inserter(result));
return result;
}
} // namespace detail

static inline auto field_index(string_ref name, aggregate_interface agg)
-> std::optional< std::size_t >
{
for (const auto &field : llvm::enumerate(to_vector(agg.getFieldsInfo()))) {
if (field.value().name == name) {
return field.index();
}
++idx;
}
return {};

return std::nullopt;
}

template< typename yield_t >
Expand Down
13 changes: 13 additions & 0 deletions include/vast/Interfaces/AggregateTypeDefinitionInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,20 @@ VAST_RELAX_WARNINGS
#include <mlir/IR/OperationSupport.h>
VAST_RELAX_WARNINGS

#include "vast/Util/Common.hpp"

#include <gap/core/generator.hpp>

namespace vast {

struct field_info_t
{
std::string name;
mlir_type type;
};

} // namespace vast


/// Include the generated interface declarations.
#include "vast/Interfaces/AggregateTypeDefinitionInterface.h.inc"

Check failure on line 30 in include/vast/Interfaces/AggregateTypeDefinitionInterface.hpp

View workflow job for this annotation

GitHub Actions / cpp-linter (17, 22.04)

include/vast/Interfaces/AggregateTypeDefinitionInterface.hpp:30:10 [clang-diagnostic-error]

'vast/Interfaces/AggregateTypeDefinitionInterface.h.inc' file not found
3 changes: 1 addition & 2 deletions include/vast/Interfaces/AggregateTypeDefinitionInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ def AggregateTypeDefinition
"gap::generator< mlir::Type >", "getFieldTypes", (ins), [{}] >,

InterfaceMethod< "Return all elements in order of their declaration.",
"gap::generator< std::tuple< std::string, mlir::Type > >",
"getFieldsInfo", (ins), [{}] >,
"gap::generator< vast::field_info_t >", "getFieldsInfo", (ins), [{}] >,

InterfaceMethod< "Return all nested definitions",
"gap::generator< vast::AggregateTypeDefinitionInterface >",
Expand Down
2 changes: 1 addition & 1 deletion lib/vast/Conversion/FromHL/ToLLGEPs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace vast {
op_t op, typename op_t::Adaptor ops, conversion_rewriter &rewriter,
hl::StructDeclOp struct_decl
) const {
auto idx = hl::field_idx(op.getName(), struct_decl);
auto idx = hl::field_index(op.getName(), struct_decl);
if (!idx) {
return mlir::failure();
}
Expand Down

0 comments on commit 0744287

Please sign in to comment.