diff --git a/include/runtime/runtime_ir.hpp b/include/runtime/runtime_ir.hpp index 612474f0..b225d59c 100644 --- a/include/runtime/runtime_ir.hpp +++ b/include/runtime/runtime_ir.hpp @@ -51,7 +51,7 @@ class RuntimeGraph { * @param param_path Path to the parameter file defining the graph structure * @param bin_path Path to the bin file containing the graph weights */ - explicit RuntimeGraph(std::string param_path, std::string bin_path); + RuntimeGraph(std::string param_path, std::string bin_path); /** * @brief Sets the inputs to the graph @@ -61,8 +61,7 @@ class RuntimeGraph { * @param input_name Name of the input * @param inputs Vector of input tensors */ - template - void set_inputs(const std::string& input_name, const std::vector>& inputs); + void set_inputs(const std::string& input_name, const std::vector& inputs); /** * @brief Gets output tensors from the graph @@ -72,9 +71,7 @@ class RuntimeGraph { * @param output_name Name of the graph output * @return Vector of output tensors */ - - template - std::vector> get_outputs(const std::string& output_name) const; + std::vector get_outputs(const std::string& output_name) const; /** * @brief Checks if an op is an input op diff --git a/source/runtime/runtime_ir.cpp b/source/runtime/runtime_ir.cpp index ebca2d14..9c2156be 100644 --- a/source/runtime/runtime_ir.cpp +++ b/source/runtime/runtime_ir.cpp @@ -41,6 +41,8 @@ const std::string& RuntimeGraph::param_path() const { return this->param_path_; const std::string& RuntimeGraph::bin_path() const { return this->bin_path_; } +static bool IsQuantizeOp(const pnnx::Operator* op) { return false; } + bool RuntimeGraph::Init() { if (this->bin_path_.empty() || this->param_path_.empty()) { LOG(ERROR) << "The bin path or param path is empty"; @@ -63,26 +65,31 @@ bool RuntimeGraph::Init() { operators_.clear(); for (const pnnx::Operator* op : operators) { if (!op) { - LOG(ERROR) << "Meet the empty node"; + LOG(ERROR) << "Meet the empty node in the model"; continue; } else { - std::shared_ptr runtime_operator = std::make_shared(); - // 初始化算子的名称 - runtime_operator->name = op->name; - runtime_operator->type = op->type; + if (!IsQuantizeOp(op)) { + std::shared_ptr runtime_operator = std::make_shared(); + // 初始化算子的名称 + runtime_operator->name = op->name; + runtime_operator->type = op->type; - // 初始化算子中的input - InitGraphOperatorsInput(op->inputs, runtime_operator); + // 初始化算子中的input + InitGraphOperatorsInput(op->inputs, runtime_operator); - // 记录输出operand中的名称 - InitGraphOperatorsOutput(op->outputs, runtime_operator); + // 记录输出operand中的名称 + InitGraphOperatorsOutput(op->outputs, runtime_operator); - // 初始化算子中的attribute(权重) - InitGraphAttrs(op->attrs, runtime_operator); + // 初始化算子中的attribute(权重) + InitGraphAttrs(op->attrs, runtime_operator); - // 初始化算子中的parameter - InitGraphParams(op->params, runtime_operator); - this->operators_.push_back(runtime_operator); + // 初始化算子中的parameter + InitGraphParams(op->params, runtime_operator); + this->operators_.push_back(runtime_operator); + } else { + LOG(FATAL) << "UnSupported quantize operator in the model " << op->name + << " type: " << op->type; + } } } @@ -199,31 +206,34 @@ void RuntimeGraph::InitGraphOperatorsInput( if (!input) { continue; } + + std::vector dims; const pnnx::Operator* producer = input->producer; - std::shared_ptr runtime_operand = std::make_shared(); - runtime_operand->name = producer->name; for (int32_t dim : input->shape) { - runtime_operand->shapes.push_back(dim); + dims.push_back(dim); } - CHECK(!runtime_operand->shapes.empty()); + CHECK(!dims.empty()); + std::shared_ptr> runtime_operand = + std::make_shared>(); + runtime_operand->name = producer->name; + runtime_operand->shapes = dims; + runtime_operator->input_operands.insert({producer->name, runtime_operand}); + runtime_operator->input_operands_seq.push_back(runtime_operand); switch (input->type) { case 1: { runtime_operand->type = RuntimeDataType::kTypeFloat32; break; } - case 0: { - runtime_operand->type = RuntimeDataType::kTypeUnknown; + case 7: { + runtime_operand->type = RuntimeDataType::kTypeInt8; break; } default: { LOG(FATAL) << "Unknown input operand type: " << input->type; } } - - runtime_operator->input_operands.insert({producer->name, runtime_operand}); - runtime_operator->input_operands_seq.push_back(runtime_operand); } } @@ -441,9 +451,7 @@ void RuntimeGraph::CreateNodeRelation() { RuntimeGraph::GraphState RuntimeGraph::graph_state() const { return this->graph_state_; } -template -void RuntimeGraph::set_inputs(const std::string& input_name, - const std::vector>& inputs) { +void RuntimeGraph::set_inputs(const std::string& input_name, const std::vector& inputs) { CHECK(this->graph_state_ == GraphState::Complete); std::shared_ptr input_op; for (auto op : this->input_ops_) { @@ -456,8 +464,7 @@ void RuntimeGraph::set_inputs(const std::string& input_name, PropagateLayerOutputs(input_op, inputs); } -template -std::vector> RuntimeGraph::get_outputs(const std::string& output_name) const { +std::vector RuntimeGraph::get_outputs(const std::string& output_name) const { CHECK(this->graph_state_ == GraphState::Complete); std::shared_ptr output_op; for (auto op : this->output_ops_) {