From 398c632b6f7c01656b908442df2b5a789a7b04e4 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Wed, 16 Oct 2024 12:09:01 -0700
Subject: [PATCH] [hlo] Do not overwrite derived instruction backend config

Derived instruction might set its own backend config, and it's not safe to overwrite it with an original one.

Fix for: https://github.com/openxla/xla/issues/18214

PiperOrigin-RevId: 686594662
---
 xla/hlo/ir/hlo_instruction.cc       |  8 +++++---
 xla/service/hlo_instruction_test.cc | 27 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc
index 652a08c5e75a6..4a1ba75535cec 100644
--- a/xla/hlo/ir/hlo_instruction.cc
+++ b/xla/hlo/ir/hlo_instruction.cc
@@ -2216,9 +2216,11 @@ void HloInstruction::SetupDerivedInstruction(
     derived_instruction->mutable_rare()->frontend_attributes.Clear();
     derived_instruction->mutable_rare()->statistics_viz.Clear();
   }
-  // If the derived instruction has the same opcode as current,
-  // then the backend config is also applicable.
-  if (opcode() == derived_instruction->opcode() && has_backend_config()) {
+  // If the derived instruction has the same opcode as current, then the backend
+  // config is also applicable (only if derived instruction doesn't have its own
+  // backend config which might be different from the original one).
+  if (opcode() == derived_instruction->opcode() && has_backend_config() &&
+      !derived_instruction->has_backend_config()) {
     derived_instruction->CopyBackendConfigFrom(this);
   }
 }
diff --git a/xla/service/hlo_instruction_test.cc b/xla/service/hlo_instruction_test.cc
index 1db4a17c8fc39..19ab4d15e8131 100644
--- a/xla/service/hlo_instruction_test.cc
+++ b/xla/service/hlo_instruction_test.cc
@@ -2898,6 +2898,33 @@ TEST_F(HloInstructionTest, BackendConfigNotCopiedToDerivedWithDiffOpcode) {
   EXPECT_FALSE(add2->has_backend_config());
 }
 
+// A derived instruction that already carries its own backend config must keep
+// it: SetupDerivedInstruction copies the config only when the target has none.
+TEST_F(HloInstructionTest, BackendConfigNotCopiedToDerivedWithConfig) {
+  HloComputation::Builder b(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+  auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  auto p1 = b.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
+  auto add = b.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1));
+
+  gpu::GpuBackendConfig gpu_config0;
+  gpu::GpuBackendConfig gpu_config1;
+  gpu_config0.set_operation_queue_id(2);
+  gpu_config1.set_operation_queue_id(3);
+
+  TF_ASSERT_OK(add->set_backend_config(gpu_config0));
+  auto add2 = b.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
+  TF_ASSERT_OK(add2->set_backend_config(gpu_config1));
+
+  add->SetupDerivedInstruction(add2);
+  auto backend_config = add2->backend_config<gpu::GpuBackendConfig>();
+  // Fatal check: `backend_config` is dereferenced right below.
+  ASSERT_TRUE(backend_config.ok());
+  EXPECT_EQ(backend_config->operation_queue_id(), 3);
+}
+
 TEST_F(HloInstructionTest,
        MergeMultiOutputProducerFusionIntoMultiOutputFusion) {
   const std::string& hlo_string = R"(