diff --git a/source/layer/details/matmul.cpp b/source/layer/details/matmul.cpp
index 21bef513..9ab08bbb 100644
--- a/source/layer/details/matmul.cpp
+++ b/source/layer/details/matmul.cpp
@@ -114,8 +114,13 @@ StatusCode LLamaMatmulLayer::Forward(const std::vector<std::shared_ptr<Tensor<float>>>& inputs,
-    arma::fmat output_mat(output->raw_ptr(), weight_dim0_, input_dim1, false, true);
-    output_mat = (input_vec * weight_data).t();
+    if (input_dim1 == 1 || weight_dim0_ == 1) {
+      arma::fmat output_mat(output->raw_ptr(), input_dim1, weight_dim0_, false, true);
+      output_mat = input_vec * weight_data;
+    } else {
+      arma::fmat output_mat(output->raw_ptr(), weight_dim0_, input_dim1, false, true);
+      output_mat = (input_vec * weight_data).t();
+    }
   }
   return StatusCode::kSuccess;
 }
diff --git a/test/test_layer/test_matmul.cpp b/test/test_layer/test_matmul.cpp
index fa157188..0c6140b7 100644
--- a/test/test_layer/test_matmul.cpp
+++ b/test/test_layer/test_matmul.cpp
@@ -65,3 +65,28 @@ TEST(test_layer, forward_matmul2) {
   ASSERT_EQ(output->at(0, 1, 0), 38);
   ASSERT_EQ(output->at(0, 2, 0), 62);
 }
+
+
+TEST(test_layer, forward_matmul3) {
+  using namespace kuiper_infer;
+  std::vector<float> weights;
+  for (int i = 0; i < 3 * 4; ++i) {
+    weights.push_back(float(i));
+  }
+  LLamaMatmulLayer llama_matmul(1, 4);
+  std::shared_ptr<Tensor<float>> weight = std::make_shared<Tensor<float>>(weights.data(), 1, 4);
+  llama_matmul.set_weights({weight});
+
+  std::vector<float> inputs;
+  for (int i = 0; i < 4; ++i) {
+    inputs.push_back(float(i));
+  }
+
+  std::shared_ptr<Tensor<float>> input = std::make_shared<Tensor<float>>(inputs.data(), 4, 1);
+  std::shared_ptr<Tensor<float>> output = std::make_shared<Tensor<float>>(1, 1);
+  std::vector<std::shared_ptr<Tensor<float>>> outputs;
+  outputs.push_back(output);
+
+  llama_matmul.Forward({input}, outputs);
+  ASSERT_EQ(output->at(0, 0, 0), 14);
+}
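
A note on why the new fast path is safe to take (this sketch is not part of the patch): when input_dim1 == 1 or weight_dim0_ == 1, the product input_vec * weight_data is a row or column vector, and Armadillo's column-major storage puts a 1 x N matrix and its N x 1 transpose in an identical flat layout, so writing the untransposed product straight into the aliased output buffer yields the same bytes while skipping the transpose copy. A minimal standalone check of that layout claim, assuming only that the Armadillo headers are available; the 1 x 4 and 4 x 3 shapes are illustrative, not taken from the patch:

// Standalone sketch (not from the patch): for a vector-shaped result,
// skipping .t() leaves the flat memory layout unchanged, which is what
// lets the new branch alias the output buffer directly.
#include <armadillo>
#include <cassert>
#include <cstring>

int main() {
  // Mimic the input_dim1 == 1 case: a 1 x 4 activation row vector times a
  // 4 x 3 weight matrix (i.e. weight_dim1_ = 4, weight_dim0_ = 3).
  arma::fmat input_vec(1, 4, arma::fill::randu);
  arma::fmat weight_data(4, 3, arma::fill::randu);

  arma::fmat direct = input_vec * weight_data;            // 1 x 3, no copy
  arma::fmat transposed = (input_vec * weight_data).t();  // 3 x 1, extra copy

  // Column-major 1 x 3 and 3 x 1 store the same three floats in the same
  // order, so both variants fill the output buffer identically.
  assert(std::memcmp(direct.memptr(), transposed.memptr(),
                     3 * sizeof(float)) == 0);
  return 0;
}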