Skip to content

Commit

Permalink
修改模型runtimeir中的模板实现
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Feb 8, 2024
1 parent 8c66f7c commit 5ccdb71
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
9 changes: 3 additions & 6 deletions include/runtime/runtime_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class RuntimeGraph {
* @param param_path Path to the parameter file defining the graph structure
* @param bin_path Path to the bin file containing the graph weights
*/
explicit RuntimeGraph(std::string param_path, std::string bin_path);
RuntimeGraph(std::string param_path, std::string bin_path);

/**
* @brief Sets the inputs to the graph
Expand All @@ -61,8 +61,7 @@ class RuntimeGraph {
* @param input_name Name of the input
* @param inputs Vector of input tensors
*/
template <typename T>
void set_inputs(const std::string& input_name, const std::vector<stensor<T>>& inputs);
void set_inputs(const std::string& input_name, const std::vector<sftensor>& inputs);

/**
* @brief Gets output tensors from the graph
Expand All @@ -72,9 +71,7 @@ class RuntimeGraph {
* @param output_name Name of the graph output
* @return Vector of output tensors
*/

template <typename T>
std::vector<stensor<T>> get_outputs(const std::string& output_name) const;
std::vector<sftensor> get_outputs(const std::string& output_name) const;

/**
* @brief Checks if an op is an input op
Expand Down
63 changes: 35 additions & 28 deletions source/runtime/runtime_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ const std::string& RuntimeGraph::param_path() const { return this->param_path_;

const std::string& RuntimeGraph::bin_path() const { return this->bin_path_; }

static bool IsQuantizeOp(const pnnx::Operator* op) { return false; }

bool RuntimeGraph::Init() {
if (this->bin_path_.empty() || this->param_path_.empty()) {
LOG(ERROR) << "The bin path or param path is empty";
Expand All @@ -63,26 +65,31 @@ bool RuntimeGraph::Init() {
operators_.clear();
for (const pnnx::Operator* op : operators) {
if (!op) {
LOG(ERROR) << "Meet the empty node";
LOG(ERROR) << "Meet the empty node in the model";
continue;
} else {
std::shared_ptr<RuntimeOperator> runtime_operator = std::make_shared<RuntimeOperator>();
// 初始化算子的名称
runtime_operator->name = op->name;
runtime_operator->type = op->type;
if (!IsQuantizeOp(op)) {
std::shared_ptr<RuntimeOperator> runtime_operator = std::make_shared<RuntimeOperator>();
// 初始化算子的名称
runtime_operator->name = op->name;
runtime_operator->type = op->type;

// 初始化算子中的input
InitGraphOperatorsInput(op->inputs, runtime_operator);
// 初始化算子中的input
InitGraphOperatorsInput(op->inputs, runtime_operator);

// 记录输出operand中的名称
InitGraphOperatorsOutput(op->outputs, runtime_operator);
// 记录输出operand中的名称
InitGraphOperatorsOutput(op->outputs, runtime_operator);

// 初始化算子中的attribute(权重)
InitGraphAttrs(op->attrs, runtime_operator);
// 初始化算子中的attribute(权重)
InitGraphAttrs(op->attrs, runtime_operator);

// 初始化算子中的parameter
InitGraphParams(op->params, runtime_operator);
this->operators_.push_back(runtime_operator);
// 初始化算子中的parameter
InitGraphParams(op->params, runtime_operator);
this->operators_.push_back(runtime_operator);
} else {
LOG(FATAL) << "UnSupported quantize operator in the model " << op->name
<< " type: " << op->type;
}
}
}

Expand Down Expand Up @@ -199,31 +206,34 @@ void RuntimeGraph::InitGraphOperatorsInput(
if (!input) {
continue;
}

std::vector<int32_t> dims;
const pnnx::Operator* producer = input->producer;
std::shared_ptr<RuntimeOperand> runtime_operand = std::make_shared<RuntimeOperand>();

runtime_operand->name = producer->name;
for (int32_t dim : input->shape) {
runtime_operand->shapes.push_back(dim);
dims.push_back(dim);
}
CHECK(!runtime_operand->shapes.empty());
CHECK(!dims.empty());
std::shared_ptr<RuntimeOperandBase<T>> runtime_operand =
std::make_shared<RuntimeOperandBase<T>>();
runtime_operand->name = producer->name;
runtime_operand->shapes = dims;
runtime_operator->input_operands.insert({producer->name, runtime_operand});
runtime_operator->input_operands_seq.push_back(runtime_operand);

switch (input->type) {
case 1: {
runtime_operand->type = RuntimeDataType::kTypeFloat32;
break;
}
case 0: {
runtime_operand->type = RuntimeDataType::kTypeUnknown;
case 7: {
runtime_operand->type = RuntimeDataType::kTypeInt8;
break;
}
default: {
LOG(FATAL) << "Unknown input operand type: " << input->type;
}
}

runtime_operator->input_operands.insert({producer->name, runtime_operand});
runtime_operator->input_operands_seq.push_back(runtime_operand);
}
}

Expand Down Expand Up @@ -441,9 +451,7 @@ void RuntimeGraph::CreateNodeRelation() {

RuntimeGraph::GraphState RuntimeGraph::graph_state() const { return this->graph_state_; }

template <typename T>
void RuntimeGraph::set_inputs(const std::string& input_name,
const std::vector<stensor<T>>& inputs) {
void RuntimeGraph::set_inputs(const std::string& input_name, const std::vector<sftensor>& inputs) {
CHECK(this->graph_state_ == GraphState::Complete);
std::shared_ptr<RuntimeOperator> input_op;
for (auto op : this->input_ops_) {
Expand All @@ -456,8 +464,7 @@ void RuntimeGraph::set_inputs(const std::string& input_name,
PropagateLayerOutputs(input_op, inputs);
}

template <typename T>
std::vector<stensor<T>> RuntimeGraph::get_outputs(const std::string& output_name) const {
std::vector<sftensor> RuntimeGraph::get_outputs(const std::string& output_name) const {
CHECK(this->graph_state_ == GraphState::Complete);
std::shared_ptr<RuntimeOperator> output_op;
for (auto op : this->output_ops_) {
Expand Down

0 comments on commit 5ccdb71

Please sign in to comment.