Skip to content

Commit

Permalink
增加matmul算子的单元测试
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Feb 10, 2024
1 parent 2458ba5 commit 3de632b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
9 changes: 7 additions & 2 deletions source/layer/details/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,13 @@ StatusCode LLamaMatmulLayer::Forward(const std::vector<std::shared_ptr<Tensor<fl
} else {
LOG(FATAL) << "The shape of output tensor need be equal to one or two";
}
arma::fmat output_mat(output->raw_ptr(), weight_dim0_,input_dim1, false, true);
output_mat = (input_vec * weight_data).t();
if (input_dim1 == 1 || weight_dim0_ == 1) {
arma::fmat output_mat(output->raw_ptr(), input_dim1, weight_dim0_, false, true);
output_mat = input_vec * weight_data;
} else {
arma::fmat output_mat(output->raw_ptr(), weight_dim0_, input_dim1, false, true);
output_mat = (input_vec * weight_data).t();
}
}
return StatusCode::kSuccess;
}
Expand Down
25 changes: 25 additions & 0 deletions test/test_layer/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,28 @@ TEST(test_layer, forward_matmul2) {
ASSERT_EQ(output->at(0, 1, 0), 38);
ASSERT_EQ(output->at(0, 2, 0), 62);
}


TEST(test_layer, forward_matmul3) {
using namespace kuiper_infer;
std::vector<float> weights;
for (int i = 0; i < 3 * 4; ++i) {
weights.push_back(float(i));
}
LLamaMatmulLayer llama_matmul(1, 4);
std::shared_ptr<Tensor<float>> weight = std::make_shared<Tensor<float>>(weights.data(), 1, 4);
llama_matmul.set_weights({weight});

std::vector<float> inputs;
for (int i = 0; i < 4; ++i) {
inputs.push_back(float(i));
}

std::shared_ptr<Tensor<float>> input = std::make_shared<Tensor<float>>(inputs.data(), 4, 1);
std::shared_ptr<Tensor<float>> output = std::make_shared<Tensor<float>>(1, 1);
std::vector<sftensor> outputs;
outputs.push_back(output);

llama_matmul.Forward({input}, outputs);
ASSERT_EQ(output->at(0, 0, 0), 14);
}

0 comments on commit 3de632b

Please sign in to comment.