aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVictor Stone <victorstone@google.com>2024-02-08 16:51:27 -0800
committerTensorFlower Gardener <gardener@tensorflow.org>2024-02-08 16:57:01 -0800
commita4cb4209c8f3f9c9497c421fb2926a8feb63efe6 (patch)
treebc22cb8a960ac3ba870fc386d5caea595923f1b1
parentfa0f62b268dbd9687432e00d827db59cbe6c83f7 (diff)
downloadtensorflow-a4cb4209c8f3f9c9497c421fb2926a8feb63efe6.tar.gz
[XLA] Open source pass which converts external host memory offload annotations to internal annotations.
PiperOrigin-RevId: 605465947
-rw-r--r--third_party/xla/xla/service/BUILD44
-rw-r--r--third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc99
-rw-r--r--third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h44
-rw-r--r--third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc259
-rw-r--r--third_party/xla/xla/service/host_memory_offload_annotations.h42
-rw-r--r--third_party/xla/xla/service/host_offloader.cc23
-rw-r--r--third_party/xla/xla/service/host_offloader.h4
-rw-r--r--third_party/xla/xla/service/host_offloader_test.cc6
8 files changed, 508 insertions, 13 deletions
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 3bb989be129..8960d7e2f03 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -5853,6 +5853,48 @@ xla_cc_test(
)
cc_library(
+ name = "host_memory_offload_annotations_hdr",
+ hdrs = ["host_memory_offload_annotations.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_absl//absl/strings:string_view",
+ ],
+)
+
+cc_library(
+ name = "convert_memory_placement_to_internal_annotations",
+ srcs = ["convert_memory_placement_to_internal_annotations.cc"],
+ hdrs = ["convert_memory_placement_to_internal_annotations.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":host_memory_offload_annotations_hdr",
+ "//xla:side_effect_util",
+ "//xla:statusor",
+ "//xla:util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "//xla/service:hlo_pass",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings:string_view",
+ "@local_tsl//tsl/platform:errors",
+ ],
+)
+
+xla_cc_test(
+ name = "convert_memory_placement_to_internal_annotations_test",
+ srcs = ["convert_memory_placement_to_internal_annotations_test.cc"],
+ deps = [
+ ":convert_memory_placement_to_internal_annotations",
+ "//xla:statusor",
+ "//xla:xla_data_proto_cc",
+ "//xla/tests:hlo_test_base",
+ "//xla/tests:xla_internal_test_main",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
name = "host_memory_transfer_asyncifier",
srcs = ["host_memory_transfer_asyncifier.cc"],
hdrs = ["host_memory_transfer_asyncifier.h"],
@@ -5902,6 +5944,7 @@ cc_library(
":hlo_buffer",
":hlo_pass",
":hlo_value",
+ ":host_memory_offload_annotations_hdr",
":pattern_matcher",
"//xla:literal_util",
"//xla:shape_util",
@@ -5924,6 +5967,7 @@ xla_cc_test(
name = "host_offloader_test",
srcs = ["host_offloader_test.cc"],
deps = [
+ ":host_memory_offload_annotations_hdr",
":host_offloader",
":pattern_matcher",
":pattern_matcher_gmock",
diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc
new file mode 100644
index 00000000000..05b35a99b89
--- /dev/null
+++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc
@@ -0,0 +1,99 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+#include "xla/service/convert_memory_placement_to_internal_annotations.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/host_memory_offload_annotations.h"
+#include "xla/side_effect_util.h"
+#include "xla/statusor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+
+StatusOr<bool> ConvertMemoryPlacementToInternalAnnotations::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (HloComputation* c : module->MakeNonfusionComputations()) {
+ for (HloInstruction* instruction : c->MakeInstructionPostOrder()) {
+ if (instruction->IsCustomCall(
+ host_memory_offload_annotations::kDevicePlacement)) {
+ const auto& frontend_attributes = instruction->frontend_attributes();
+ const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr);
+ if (it == frontend_attributes.map().end()) {
+ continue;
+ }
+ const bool is_to_host_case =
+ it->second == host_memory_offload_annotations::kMemoryTargetHost;
+ const bool is_to_device_case =
+ (it->second ==
+ host_memory_offload_annotations::kMemoryTargetDeviceTpu ||
+ it->second ==
+ host_memory_offload_annotations::kMemoryTargetDeviceGpu);
+ if (!is_to_host_case && !is_to_device_case) {
+ continue;
+ }
+ if (is_to_host_case) {
+ VLOG(1) << "Process forward case: " << instruction->ToString();
+ if (instruction->users().size() != 1) {
+ VLOG(1) << "Skip because of too many users on instruction";
+ continue;
+ }
+ if (instruction->operand_count() != 1) {
+ return Internal(
+ "Custom calls with target %s must have exactly one operand. %s "
+ "has %d.",
+ host_memory_offload_annotations::kDevicePlacement,
+ instruction->name(), instruction->operand_count());
+ }
+ HloInstruction* input = instruction->mutable_operand(0);
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(
+ c->AddInstruction(HloInstruction::CreateCustomCall(
+ input->shape(), {input},
+ host_memory_offload_annotations::
+ kMoveToHostCustomCallTarget))));
+ TF_RETURN_IF_ERROR(
+ c->RemoveInstructionAndUnusedOperands(instruction));
+ changed = true;
+ } else if (is_to_device_case) {
+ VLOG(1) << "Process backward case: " << instruction->ToString();
+ HloInstruction* custom_call_operand = instruction->mutable_operand(0);
+ if (custom_call_operand->users().size() != 1) {
+ VLOG(1) << "Skip because operand is used by more than one user";
+ continue;
+ }
+ HloInstruction* new_result =
+ c->AddInstruction(HloInstruction::CreateCustomCall(
+ custom_call_operand->shape(), {custom_call_operand},
+ host_memory_offload_annotations::
+ kMoveToDeviceCustomCallTarget));
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result));
+ TF_RETURN_IF_ERROR(
+ c->RemoveInstructionAndUnusedOperands(instruction));
+ changed = true;
+ }
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h
new file mode 100644
index 00000000000..87fff9d715e
--- /dev/null
+++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+#ifndef XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_
+#define XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+class ConvertMemoryPlacementToInternalAnnotations : public HloModulePass {
+ public:
+ ConvertMemoryPlacementToInternalAnnotations() = default;
+
+ absl::string_view name() const override {
+ return "convert-memory-placement-to-internal-annotations";
+ }
+ using HloPassInterface::Run;
+ StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_
diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc
new file mode 100644
index 00000000000..4639738d66a
--- /dev/null
+++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+#include "xla/service/convert_memory_placement_to_internal_annotations.h"
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/statusor.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class ConvertMemoryPlacementToInternalAnnotationsTest : public HloTestBase {
+ public:
+ ConvertMemoryPlacementToInternalAnnotationsTest() = default;
+};
+
+TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, ConvertTest) {
+ const char* hlo_string = R"(
+HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}}
+
+region_0.9 {
+ arg_tuple.10 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0)
+ get-tuple-element.11 = s32[] get-tuple-element(arg_tuple.10), index=0
+ constant.15 = s32[] constant(1)
+ add.33 = s32[] add(get-tuple-element.11, constant.15)
+ get-tuple-element.12 = f32[16]{0} get-tuple-element(arg_tuple.10), index=1
+ sine.18 = f32[16]{0} sine(get-tuple-element.12)
+ sine.19 = f32[16]{0} sine(sine.18)
+ sine.20 = f32[16]{0} sine(sine.19)
+ get-tuple-element.13 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=2
+ custom-call.21 = f32[16]{0} custom-call(sine.19), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"}
+ reshape.23 = f32[1,16]{1,0} reshape(custom-call.21)
+ constant.17 = s32[] constant(0)
+ compare.24 = pred[] compare(get-tuple-element.11, constant.17), direction=LT
+ constant.16 = s32[] constant(16)
+ add.25 = s32[] add(get-tuple-element.11, constant.16)
+ select.26 = s32[] select(compare.24, add.25, get-tuple-element.11)
+ dynamic-update-slice.27 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.13, reshape.23, select.26, constant.17)
+ get-tuple-element.14 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=3
+ custom-call.22 = f32[16]{0} custom-call(sine.20), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"}
+ reshape.28 = f32[1,16]{1,0} reshape(custom-call.22)
+ compare.29 = pred[] compare(get-tuple-element.11, constant.17), direction=LT
+ add.30 = s32[] add(get-tuple-element.11, constant.16)
+ select.31 = s32[] select(compare.29, add.30, get-tuple-element.11)
+ dynamic-update-slice.32 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.14, reshape.28, select.31, constant.17)
+ ROOT tuple.34 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.33, sine.20, dynamic-update-slice.27, dynamic-update-slice.32)
+}
+
+region_1.35 {
+ arg_tuple.36 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0)
+ get-tuple-element.38 = f32[16]{0} get-tuple-element(arg_tuple.36), index=1
+ get-tuple-element.39 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=2
+ get-tuple-element.40 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=3
+ get-tuple-element.37 = s32[] get-tuple-element(arg_tuple.36), index=0
+ constant.41 = s32[] constant(16)
+ ROOT compare.42 = pred[] compare(get-tuple-element.37, constant.41), direction=LT
+}
+
+core_closed_call.43 {
+ constant.47 = s32[] constant(0)
+ Arg_0.44 = f32[16]{0} parameter(0)
+ constant.45 = f32[] constant(0)
+ broadcast.46 = f32[16,16]{1,0} broadcast(constant.45), dimensions={}
+ tuple.48 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.47, Arg_0.44, broadcast.46, broadcast.46)
+ while.49 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.48), condition=region_1.35, body=region_0.9
+ get-tuple-element.50 = s32[] get-tuple-element(while.49), index=0
+ get-tuple-element.51 = f32[16]{0} get-tuple-element(while.49), index=1
+ get-tuple-element.52 = f32[16,16]{1,0} get-tuple-element(while.49), index=2
+ get-tuple-element.53 = f32[16,16]{1,0} get-tuple-element(while.49), index=3
+ ROOT tuple.54 = (f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(get-tuple-element.52, get-tuple-element.53)
+}
+
+region_2.65 {
+ arg_tuple.66 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0)
+ get-tuple-element.67 = s32[] get-tuple-element(arg_tuple.66), index=0
+ constant.74 = s32[] constant(1)
+ add.108 = s32[] add(get-tuple-element.67, constant.74)
+ get-tuple-element.73 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=6
+ constant.76 = s32[] constant(0)
+ compare.82 = pred[] compare(get-tuple-element.67, constant.76), direction=LT
+ constant.75 = s32[] constant(16)
+ add.83 = s32[] add(get-tuple-element.67, constant.75)
+ select.84 = s32[] select(compare.82, add.83, get-tuple-element.67)
+ dynamic-slice.85 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.73, select.84, constant.76), dynamic_slice_sizes={1,16}
+ reshape.86 = f32[16]{0} reshape(dynamic-slice.85)
+ custom-call.87 = f32[16]{0} custom-call(reshape.86), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="tpu_hbm"}
+ get-tuple-element.69 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=2
+ get-tuple-element.68 = f32[16]{0} get-tuple-element(arg_tuple.66), index=1
+ cosine.88 = f32[16]{0} cosine(get-tuple-element.68)
+ reshape.93 = f32[1,16]{1,0} reshape(cosine.88)
+ compare.94 = pred[] compare(get-tuple-element.67, constant.76), direction=LT
+ add.95 = s32[] add(get-tuple-element.67, constant.75)
+ select.96 = s32[] select(compare.94, add.95, get-tuple-element.67)
+ dynamic-update-slice.97 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.69, reshape.93, select.96, constant.76)
+ get-tuple-element.70 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=3
+ sine.89 = f32[16]{0} sine(get-tuple-element.68)
+ cosine.90 = f32[16]{0} cosine(sine.89)
+ reshape.98 = f32[1,16]{1,0} reshape(cosine.90)
+ compare.99 = pred[] compare(get-tuple-element.67, constant.76), direction=LT
+ add.100 = s32[] add(get-tuple-element.67, constant.75)
+ select.101 = s32[] select(compare.99, add.100, get-tuple-element.67)
+ dynamic-update-slice.102 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.70, reshape.98, select.101, constant.76)
+ get-tuple-element.71 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=4
+ get-tuple-element.72 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=5
+ compare.77 = pred[] compare(get-tuple-element.67, constant.76), direction=LT
+ add.78 = s32[] add(get-tuple-element.67, constant.75)
+ select.79 = s32[] select(compare.77, add.78, get-tuple-element.67)
+ dynamic-slice.80 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.72, select.79, constant.76), dynamic_slice_sizes={1,16}
+ reshape.81 = f32[16]{0} reshape(dynamic-slice.80)
+ custom-call.91 = f32[16]{0} custom-call(reshape.81), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="tpu_hbm"}
+ cosine.92 = f32[16]{0} cosine(custom-call.91)
+ reshape.103 = f32[1,16]{1,0} reshape(cosine.92)
+ compare.104 = pred[] compare(get-tuple-element.67, constant.76), direction=LT
+ add.105 = s32[] add(get-tuple-element.67, constant.75)
+ select.106 = s32[] select(compare.104, add.105, get-tuple-element.67)
+ dynamic-update-slice.107 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.71, reshape.103, select.106, constant.76)
+ ROOT tuple.109 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.108, custom-call.87, dynamic-update-slice.97, dynamic-update-slice.102, dynamic-update-slice.107, get-tuple-element.72, get-tuple-element.73)
+}
+
+region_3.110 {
+ arg_tuple.111 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0)
+ get-tuple-element.113 = f32[16]{0} get-tuple-element(arg_tuple.111), index=1
+ get-tuple-element.114 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=2
+ get-tuple-element.115 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=3
+ get-tuple-element.116 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=4
+ get-tuple-element.117 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=5
+ get-tuple-element.118 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=6
+ get-tuple-element.112 = s32[] get-tuple-element(arg_tuple.111), index=0
+ constant.119 = s32[] constant(16)
+ ROOT compare.120 = pred[] compare(get-tuple-element.112, constant.119), direction=LT
+}
+
+region_4.130 {
+ arg_tuple.131 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0)
+ get-tuple-element.132 = s32[] get-tuple-element(arg_tuple.131), index=0
+ constant.140 = s32[] constant(1)
+ add.164 = s32[] add(get-tuple-element.132, constant.140)
+ get-tuple-element.133 = f32[16]{0} get-tuple-element(arg_tuple.131), index=1
+ get-tuple-element.134 = f32[] get-tuple-element(arg_tuple.131), index=2
+ broadcast.159 = f32[16]{0} broadcast(get-tuple-element.134), dimensions={}
+ add.160 = f32[16]{0} add(get-tuple-element.133, broadcast.159)
+ get-tuple-element.137 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=5
+ constant.141 = s32[] constant(16)
+ subtract.142 = s32[] subtract(constant.141, get-tuple-element.132)
+ subtract.143 = s32[] subtract(subtract.142, constant.140)
+ constant.139 = s32[] constant(0)
+ compare.154 = pred[] compare(subtract.143, constant.139), direction=LT
+ add.155 = s32[] add(subtract.143, constant.141)
+ select.156 = s32[] select(compare.154, add.155, subtract.143)
+ dynamic-slice.157 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.137, select.156, constant.139), dynamic_slice_sizes={1,16}
+ reshape.158 = f32[16]{0} reshape(dynamic-slice.157)
+ multiply.161 = f32[16]{0} multiply(add.160, reshape.158)
+ get-tuple-element.136 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=4
+ compare.149 = pred[] compare(subtract.143, constant.139), direction=LT
+ add.150 = s32[] add(subtract.143, constant.141)
+ select.151 = s32[] select(compare.149, add.150, subtract.143)
+ dynamic-slice.152 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.136, select.151, constant.139), dynamic_slice_sizes={1,16}
+ reshape.153 = f32[16]{0} reshape(dynamic-slice.152)
+ multiply.162 = f32[16]{0} multiply(multiply.161, reshape.153)
+ get-tuple-element.135 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=3
+ compare.144 = pred[] compare(subtract.143, constant.139), direction=LT
+ add.145 = s32[] add(subtract.143, constant.141)
+ select.146 = s32[] select(compare.144, add.145, subtract.143)
+ dynamic-slice.147 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.135, select.146, constant.139), dynamic_slice_sizes={1,16}
+ reshape.148 = f32[16]{0} reshape(dynamic-slice.147)
+ multiply.163 = f32[16]{0} multiply(multiply.162, reshape.148)
+ constant.138 = f32[] constant(0)
+ ROOT tuple.165 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(add.164, multiply.163, constant.138, get-tuple-element.135, get-tuple-element.136, get-tuple-element.137)
+}
+
+region_5.166 {
+ arg_tuple.167 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0)
+ get-tuple-element.169 = f32[16]{0} get-tuple-element(arg_tuple.167), index=1
+ get-tuple-element.170 = f32[] get-tuple-element(arg_tuple.167), index=2
+ get-tuple-element.171 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=3
+ get-tuple-element.172 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=4
+ get-tuple-element.173 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=5
+ get-tuple-element.168 = s32[] get-tuple-element(arg_tuple.167), index=0
+ constant.174 = s32[] constant(16)
+ ROOT compare.175 = pred[] compare(get-tuple-element.168, constant.174), direction=LT
+}
+
+ENTRY main.183 {
+ constant.6 = s32[] constant(0)
+ Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]}
+ call.55 = (f32[16,16]{1,0}, f32[16,16]{1,0}) call(Arg_0.1), to_apply=core_closed_call.43
+ get-tuple-element.56 = f32[16,16]{1,0} get-tuple-element(call.55), index=0
+ get-tuple-element.57 = f32[16,16]{1,0} get-tuple-element(call.55), index=1
+ constant.7 = f32[] constant(1)
+ tuple.58 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) tuple(get-tuple-element.56, get-tuple-element.57, Arg_0.1, constant.7)
+ opt-barrier.59 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) opt-barrier(tuple.58)
+ get-tuple-element.62 = f32[16]{0} get-tuple-element(opt-barrier.59), index=2
+ constant.4 = f32[] constant(0)
+ broadcast.5 = f32[16,16]{1,0} broadcast(constant.4), dimensions={}
+ get-tuple-element.60 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=0
+ get-tuple-element.61 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=1
+ tuple.64 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.6, get-tuple-element.62, broadcast.5, broadcast.5, broadcast.5, get-tuple-element.60, get-tuple-element.61)
+ while.121 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.64), condition=region_3.110, body=region_2.65
+ get-tuple-element.122 = s32[] get-tuple-element(while.121), index=0
+ get-tuple-element.123 = f32[16]{0} get-tuple-element(while.121), index=1
+ get-tuple-element.127 = f32[16,16]{1,0} get-tuple-element(while.121), index=5
+ get-tuple-element.128 = f32[16,16]{1,0} get-tuple-element(while.121), index=6
+ constant.2 = f32[] constant(0)
+ broadcast.3 = f32[16]{0} broadcast(constant.2), dimensions={}
+ get-tuple-element.63 = f32[] get-tuple-element(opt-barrier.59), index=3
+ get-tuple-element.124 = f32[16,16]{1,0} get-tuple-element(while.121), index=2
+ get-tuple-element.125 = f32[16,16]{1,0} get-tuple-element(while.121), index=3
+ get-tuple-element.126 = f32[16,16]{1,0} get-tuple-element(while.121), index=4
+ tuple.129 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(constant.6, broadcast.3, get-tuple-element.63, get-tuple-element.124, get-tuple-element.125, get-tuple-element.126)
+ while.176 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) while(tuple.129), condition=region_5.166, body=region_4.130
+ get-tuple-element.177 = s32[] get-tuple-element(while.176), index=0
+ ROOT get-tuple-element.178 = f32[16]{0} get-tuple-element(while.176), index=1
+ get-tuple-element.179 = f32[] get-tuple-element(while.176), index=2
+ get-tuple-element.180 = f32[16,16]{1,0} get-tuple-element(while.176), index=3
+ get-tuple-element.181 = f32[16,16]{1,0} get-tuple-element(while.176), index=4
+ get-tuple-element.182 = f32[16,16]{1,0} get-tuple-element(while.176), index=5
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ bool changed =
+ ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value();
+ EXPECT_TRUE(changed);
+ XLA_VLOG_LINES(1, module->ToString());
+ int64_t custom_calls_count = 0;
+ for (auto* c : module->computations()) {
+ for (auto* instr : c->instructions()) {
+ if (instr->IsCustomCall("PipelineForward") ||
+ instr->IsCustomCall("PipelineBackward")) {
+ ++custom_calls_count;
+ }
+ }
+ }
+ EXPECT_EQ(custom_calls_count, 4);
+}
+
+} // namespace
+} // namespace xla
diff --git a/third_party/xla/xla/service/host_memory_offload_annotations.h b/third_party/xla/xla/service/host_memory_offload_annotations.h
new file mode 100644
index 00000000000..4cdb5866b42
--- /dev/null
+++ b/third_party/xla/xla/service/host_memory_offload_annotations.h
@@ -0,0 +1,42 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+#ifndef XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_
+#define XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_
+
+#include "absl/strings/string_view.h"
+
+namespace xla {
+namespace host_memory_offload_annotations {
+
+// External annotations:
+inline const absl::string_view kDevicePlacement = "annotate_device_placement";
+inline const absl::string_view kMemoryTargetHost = "pinned_host";
+inline const absl::string_view kMemoryTargetDeviceTpu = "tpu_hbm";
+inline const absl::string_view kMemoryTargetDeviceGpu = "gpu_hbm";
+
+// Internal annotations:
+// This are currently called PipelineForward/PipelineBackward, because they were
+// originally meant as a hook point for the collective-pipeliner.
+// They do more than just that though (identify memory movement direction), so
+// should be renamed to something related to memory movement.
+inline const absl::string_view kMoveToHostCustomCallTarget = "PipelineForward";
+inline const absl::string_view kMoveToDeviceCustomCallTarget =
+ "PipelineBackward";
+
+} // namespace host_memory_offload_annotations
+} // namespace xla
+
+#endif // XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_
diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc
index a5ed02ad1c2..3275fc71c3f 100644
--- a/third_party/xla/xla/service/host_offloader.cc
+++ b/third_party/xla/xla/service/host_offloader.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/hlo_buffer.h"
#include "xla/service/hlo_value.h"
+#include "xla/service/host_memory_offload_annotations.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
@@ -148,7 +149,8 @@ StatusOr<HloInstruction*> BufferHasPositionWithUser(const HloBuffer& buffer,
}
HloInstruction* FindDSAnnotation(HloInstruction* hlo) {
- while (!hlo->IsCustomCall(HostOffloader::kPipelineBackwardTarget)) {
+ while (!hlo->IsCustomCall(
+ host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
if (hlo->opcode() != HloOpcode::kReshape &&
hlo->opcode() != HloOpcode::kBitcast) {
break;
@@ -301,7 +303,8 @@ Status HostOffloader::MemoryOnlyOffloadStartingWithDus(
if (consuming_ds_user->opcode() != HloOpcode::kCustomCall) {
return Internal("Dynamic-slice does not have a matching annotation.");
}
- if (consuming_ds_user->custom_call_target() != kPipelineBackwardTarget) {
+ if (consuming_ds_user->custom_call_target() !=
+ host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) {
return Internal(
"Found custom-call is not the expected matching host offload "
"annotation");
@@ -396,7 +399,8 @@ Status HostOffloader::MemoryOnlyOffloadStartingWithCopy(
if (consuming_copy_user->opcode() != HloOpcode::kCustomCall) {
return Internal("Copy does not have a matching annotation.");
}
- if (consuming_copy_user->custom_call_target() != kPipelineBackwardTarget) {
+ if (consuming_copy_user->custom_call_target() !=
+ host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) {
return Internal(
"Found custom-call is not the expected matching host offload "
"annotation");
@@ -426,8 +430,10 @@ Status HostOffloader::MemoryOnlyOffloadInsertCopies(
// Check that this buffer is finally an input to a load-from-host custom-call.
TF_ASSIGN_OR_RETURN(
HloInstruction * matching_annotation,
- BufferHasPositionWithUser(unique_buffer,
- match::CustomCall({kPipelineBackwardTarget})));
+ BufferHasPositionWithUser(
+ unique_buffer,
+ match::CustomCall({host_memory_offload_annotations::
+ kMoveToDeviceCustomCallTarget})));
if (matching_annotation == nullptr) {
return Internal(
"The offloaded data (from %s) never feeds into a matching \"load\" "
@@ -531,9 +537,12 @@ StatusOr<bool> HostOffloader::Run(
if (instruction->opcode() != HloOpcode::kCustomCall) {
continue;
}
- if (instruction->custom_call_target() == kPipelineForwardTarget) {
+ if (instruction->custom_call_target() ==
+ host_memory_offload_annotations::kMoveToHostCustomCallTarget) {
TF_RETURN_IF_ERROR(HandlePipelineForwardCustomCall(instruction));
- } else if (instruction->custom_call_target() == kPipelineBackwardTarget) {
+ } else if (instruction->custom_call_target() ==
+ host_memory_offload_annotations::
+ kMoveToDeviceCustomCallTarget) {
TF_RETURN_IF_ERROR(HandlePipelineBackwardCustomCall(instruction));
}
}
diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h
index cd5ea7e5d06..666fbc5d178 100644
--- a/third_party/xla/xla/service/host_offloader.h
+++ b/third_party/xla/xla/service/host_offloader.h
@@ -35,10 +35,6 @@ class HloCostAnalysis;
// an error will be returned.
class HostOffloader : public HloModulePass {
public:
- static constexpr absl::string_view kPipelineForwardTarget = "PipelineForward";
- static constexpr absl::string_view kPipelineBackwardTarget =
- "PipelineBackward";
-
explicit HostOffloader(int64_t host_memory_space_color)
: kHostMemorySpaceColor(host_memory_space_color) {}
~HostOffloader() override = default;
diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/service/host_offloader_test.cc
index b843063cc3f..55c0ab3f70b 100644
--- a/third_party/xla/xla/service/host_offloader_test.cc
+++ b/third_party/xla/xla/service/host_offloader_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/host_memory_offload_annotations.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape.h"
@@ -64,8 +65,9 @@ class HostOffloaderTest : public HloTestBase {
for (const HloComputation* computation : module->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->IsCustomCall(
- {HostOffloader::kPipelineForwardTarget,
- HostOffloader::kPipelineBackwardTarget})) {
+ {host_memory_offload_annotations::kMoveToHostCustomCallTarget,
+ host_memory_offload_annotations::
+ kMoveToDeviceCustomCallTarget})) {
return true;
}
}