Encapsulate the conv computation method
zjhellofss committed Aug 27, 2023
1 parent 765db5f commit 21d0b60
Showing 2 changed files with 36 additions and 21 deletions.
52 changes: 31 additions & 21 deletions source/layer/details/convolution.cpp
@@ -247,27 +247,9 @@ InferStatus ConvolutionLayer::Forward(
CHECK(input_c_group == kernel_c)
<< "The number of channel for the kernel "
"matrix and input tensor do not match";

- if (conv_type_ == ConvType::OpConv) {
-   const arma::fmat& input_matrix = ConvIm2Col(
-       input, kernel_h, kernel_w, input_h, input_w, input_c_group,
-       output_h, output_w, g, kernel_h * kernel_w, output_h * output_w);
- #pragma omp parallel for
-   for (uint32_t k = 0; k < kernel_count_group; ++k) {
-     ConvGemmBias(input_matrix, output_tensor, g, k, kernel_count_group,
-                  output_h, output_w);
-   }
- } else {
-   CHECK(conv_type_ == ConvType::OpDeconv);
- #pragma omp parallel for
-   for (uint32_t k = 0; k < kernel_count_group; ++k) {
-     const arma::fmat& gemm_result = DeconvGemm(
-         input, input_h, input_w, input_c_group, g, k, kernel_count_group);
-     DeconvCol2ImBias(gemm_result, output_tensor, input_h, input_w, g, k,
-                      kernel_count_group, kernel_h, kernel_w, output_h,
-                      output_w);
-   }
- }
+ ComputeOutput(input, output_tensor, kernel_h, kernel_w,
+               kernel_count_group, input_h, input_w, input_c_group,
+               output_h, output_w, g);
}
}
return InferStatus::kInferSuccess;
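
Note: the code removed above is now centralized in ComputeOutput (shown in the next hunk). Its convolution branch follows the common im2col + GEMM pattern: ConvIm2Col rearranges each receptive field of the input into one column (or row) of a dense matrix so the convolution reduces to a matrix product. As a reference only, here is a minimal single-channel sketch of the im2col step, assuming stride 1, no padding, and plain std::vector buffers instead of arma::fmat; it illustrates the idea, not the project's ConvIm2Col, whose layout and group handling may differ.

#include <cstdint>
#include <vector>

// Minimal im2col for one channel, stride 1, no padding (illustrative only).
// Each column of the result holds one kernel_h x kernel_w receptive field,
// so the convolution becomes a (1 x kh*kw) * (kh*kw x oh*ow) matrix product.
std::vector<float> Im2ColSingleChannel(const std::vector<float>& input,
                                       uint32_t input_h, uint32_t input_w,
                                       uint32_t kernel_h, uint32_t kernel_w) {
  const uint32_t output_h = input_h - kernel_h + 1;
  const uint32_t output_w = input_w - kernel_w + 1;
  const uint32_t rows = kernel_h * kernel_w;    // one row per kernel element
  const uint32_t cols = output_h * output_w;    // one column per output pixel
  std::vector<float> matrix(rows * cols, 0.f);  // row-major (rows x cols)
  for (uint32_t oh = 0; oh < output_h; ++oh) {
    for (uint32_t ow = 0; ow < output_w; ++ow) {
      const uint32_t col = oh * output_w + ow;
      for (uint32_t kh = 0; kh < kernel_h; ++kh) {
        for (uint32_t kw = 0; kw < kernel_w; ++kw) {
          matrix[(kh * kernel_w + kw) * cols + col] =
              input[(oh + kh) * input_w + (ow + kw)];
        }
      }
    }
  }
  return matrix;
}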
@@ -478,6 +460,34 @@ void ConvolutionLayer::InitIm2ColWeight() {
this->kernel_matrix_arr_ = std::move(kernel_matrix_arr);
}

+ void ConvolutionLayer::ComputeOutput(sftensor input, sftensor output_tensor,
+                                      uint32_t kernel_h, uint32_t kernel_w,
+                                      uint32_t kernel_count_group,
+                                      uint32_t input_h, uint32_t input_w,
+                                      uint32_t input_c_group, uint32_t output_h,
+                                      uint32_t output_w, uint32_t group) {
+   if (conv_type_ == ConvType::OpConv) {
+     const arma::fmat& input_matrix = ConvIm2Col(
+         input, kernel_h, kernel_w, input_h, input_w, input_c_group, output_h,
+         output_w, group, kernel_h * kernel_w, output_h * output_w);
+ #pragma omp parallel for
+     for (uint32_t k = 0; k < kernel_count_group; ++k) {
+       ConvGemmBias(input_matrix, output_tensor, group, k, kernel_count_group,
+                    output_h, output_w);
+     }
+   } else {
+     CHECK(conv_type_ == ConvType::OpDeconv);
+ #pragma omp parallel for
+     for (uint32_t k = 0; k < kernel_count_group; ++k) {
+       const arma::fmat& gemm_result = DeconvGemm(
+           input, input_h, input_w, input_c_group, group, k, kernel_count_group);
+       DeconvCol2ImBias(gemm_result, output_tensor, input_h, input_w, group, k,
+                        kernel_count_group, kernel_h, kernel_w, output_h,
+                        output_w);
+     }
+   }
+ }

std::pair<uint32_t, uint32_t> ConvolutionLayer::ComputeOutputSize(
const uint32_t input_h, const uint32_t input_w, const uint32_t kernel_h,
const uint32_t kernel_w) {
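
Note: in the conv path, the per-channel GEMM + bias step amounts to multiplying the flattened kernel of output channel k with the im2col matrix and adding that channel's bias; the #pragma omp parallel for loop over kernel_count_group runs this once per output channel. A minimal sketch of that step under the same simplified assumptions as above (single channel and group, row-major std::vector buffers, hypothetical names); the project's ConvGemmBias works on tensors and group/channel indices instead.

#include <cstdint>
#include <vector>

// Minimal GEMM + bias for one output channel (illustrative only):
// kernel_row is the flattened (kernel_h * kernel_w) kernel, im2col is the
// rows x cols matrix from the previous sketch, and the result is the
// channel's output feature map laid out as output_h * output_w values.
std::vector<float> GemmBiasSingleKernel(const std::vector<float>& kernel_row,
                                        const std::vector<float>& im2col,
                                        uint32_t rows, uint32_t cols,
                                        float bias) {
  std::vector<float> output(cols, bias);  // start every element from the bias
  for (uint32_t r = 0; r < rows; ++r) {
    for (uint32_t c = 0; c < cols; ++c) {
      output[c] += kernel_row[r] * im2col[r * cols + c];
    }
  }
  return output;
}

One call per output channel reproduces the structure of the parallel loop inside ComputeOutput above.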
5 changes: 5 additions & 0 deletions source/layer/details/convolution.hpp
@@ -61,6 +61,11 @@ class ConvolutionLayer : public ParamLayer {
void InitIm2ColWeight();

private:
+ void ComputeOutput(sftensor input, sftensor output_tensor, uint32_t kernel_h,
+                    uint32_t kernel_w, uint32_t kernel_count_group,
+                    uint32_t input_h, uint32_t input_w, uint32_t input_c_group,
+                    uint32_t output_h, uint32_t output_w, uint32_t group);

std::pair<uint32_t, uint32_t> ComputeOutputSize(const uint32_t input_h,
const uint32_t input_w,
const uint32_t kernel_h,
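
Note: the deconvolution branch of ComputeOutput (DeconvGemm followed by DeconvCol2ImBias) goes in the opposite direction: instead of gathering receptive fields into columns, each input value is scattered into the output weighted by the kernel, and overlapping contributions accumulate (the col2im step). Below is a minimal single-channel sketch, assuming stride 1 and no padding; it illustrates the idea only, while the project's version also works per group and adds the bias.

#include <cstdint>
#include <vector>

// Minimal transposed convolution ("deconvolution") for one channel,
// stride 1, no padding (illustrative only). Every input value is scattered
// into the output weighted by the kernel; overlapping writes accumulate.
// Output size is (input_h + kernel_h - 1) x (input_w + kernel_w - 1).
std::vector<float> DeconvSingleChannel(const std::vector<float>& input,
                                       uint32_t input_h, uint32_t input_w,
                                       const std::vector<float>& kernel,
                                       uint32_t kernel_h, uint32_t kernel_w) {
  const uint32_t output_h = input_h + kernel_h - 1;
  const uint32_t output_w = input_w + kernel_w - 1;
  std::vector<float> output(output_h * output_w, 0.f);
  for (uint32_t ih = 0; ih < input_h; ++ih) {
    for (uint32_t iw = 0; iw < input_w; ++iw) {
      const float value = input[ih * input_w + iw];
      for (uint32_t kh = 0; kh < kernel_h; ++kh) {
        for (uint32_t kw = 0; kw < kernel_w; ++kw) {
          output[(ih + kh) * output_w + (iw + kw)] +=
              value * kernel[kh * kernel_w + kw];
        }
      }
    }
  }
  return output;
}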
