diff options
author | Victor Stone <victorstone@google.com> | 2024-02-08 16:51:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2024-02-08 16:57:01 -0800 |
commit | a4cb4209c8f3f9c9497c421fb2926a8feb63efe6 (patch) | |
tree | bc22cb8a960ac3ba870fc386d5caea595923f1b1 | |
parent | fa0f62b268dbd9687432e00d827db59cbe6c83f7 (diff) | |
download | tensorflow-a4cb4209c8f3f9c9497c421fb2926a8feb63efe6.tar.gz |
[XLA] Open source pass which converts external host memory offload annotations to internal annotations.
PiperOrigin-RevId: 605465947
8 files changed, 508 insertions, 13 deletions
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 3bb989be129..8960d7e2f03 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5853,6 +5853,48 @@ xla_cc_test( ) cc_library( + name = "host_memory_offload_annotations_hdr", + hdrs = ["host_memory_offload_annotations.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "convert_memory_placement_to_internal_annotations", + srcs = ["convert_memory_placement_to_internal_annotations.cc"], + hdrs = ["convert_memory_placement_to_internal_annotations.h"], + visibility = ["//visibility:public"], + deps = [ + ":host_memory_offload_annotations_hdr", + "//xla:side_effect_util", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "convert_memory_placement_to_internal_annotations_test", + srcs = ["convert_memory_placement_to_internal_annotations_test.cc"], + deps = [ + ":convert_memory_placement_to_internal_annotations", + "//xla:statusor", + "//xla:xla_data_proto_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + +cc_library( name = "host_memory_transfer_asyncifier", srcs = ["host_memory_transfer_asyncifier.cc"], hdrs = ["host_memory_transfer_asyncifier.h"], @@ -5902,6 +5944,7 @@ cc_library( ":hlo_buffer", ":hlo_pass", ":hlo_value", + ":host_memory_offload_annotations_hdr", ":pattern_matcher", "//xla:literal_util", "//xla:shape_util", @@ -5924,6 +5967,7 @@ xla_cc_test( name = "host_offloader_test", srcs = ["host_offloader_test.cc"], deps = [ + ":host_memory_offload_annotations_hdr", ":host_offloader", ":pattern_matcher", 
":pattern_matcher_gmock", diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc new file mode 100644 index 00000000000..05b35a99b89 --- /dev/null +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc @@ -0,0 +1,99 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "xla/service/convert_memory_placement_to_internal_annotations.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/side_effect_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { + +StatusOr<bool> ConvertMemoryPlacementToInternalAnnotations::Run( + HloModule* module, + const absl::flat_hash_set<absl::string_view>& execution_threads) { + bool changed = false; + for (HloComputation* c : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : c->MakeInstructionPostOrder()) { + if (instruction->IsCustomCall( + host_memory_offload_annotations::kDevicePlacement)) { + const auto& frontend_attributes = instruction->frontend_attributes(); + const auto it = 
frontend_attributes.map().find(kXlaBufferPlacementAttr); + if (it == frontend_attributes.map().end()) { + continue; + } + const bool is_to_host_case = + it->second == host_memory_offload_annotations::kMemoryTargetHost; + const bool is_to_device_case = + (it->second == + host_memory_offload_annotations::kMemoryTargetDeviceTpu || + it->second == + host_memory_offload_annotations::kMemoryTargetDeviceGpu); + if (!is_to_host_case && !is_to_device_case) { + continue; + } + if (is_to_host_case) { + VLOG(1) << "Process forward case: " << instruction->ToString(); + if (instruction->users().size() != 1) { + VLOG(1) << "Skip because of too many users on instruction"; + continue; + } + if (instruction->operand_count() != 1) { + return Internal( + "Custom calls with target %s must have exactly one operand. %s " + "has %d.", + host_memory_offload_annotations::kDevicePlacement, + instruction->name(), instruction->operand_count()); + } + HloInstruction* input = instruction->mutable_operand(0); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith( + c->AddInstruction(HloInstruction::CreateCustomCall( + input->shape(), {input}, + host_memory_offload_annotations:: + kMoveToHostCustomCallTarget)))); + TF_RETURN_IF_ERROR( + c->RemoveInstructionAndUnusedOperands(instruction)); + changed = true; + } else if (is_to_device_case) { + VLOG(1) << "Process backward case: " << instruction->ToString(); + HloInstruction* custom_call_operand = instruction->mutable_operand(0); + if (custom_call_operand->users().size() != 1) { + VLOG(1) << "Skip because operand is used by more than one user"; + continue; + } + HloInstruction* new_result = + c->AddInstruction(HloInstruction::CreateCustomCall( + custom_call_operand->shape(), {custom_call_operand}, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result)); + TF_RETURN_IF_ERROR( + c->RemoveInstructionAndUnusedOperands(instruction)); + changed = true; + } + } + } + } + return 
changed; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h new file mode 100644 index 00000000000..87fff9d715e --- /dev/null +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ +#define XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ + +#include <optional> +#include <string> +#include <vector> + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +class ConvertMemoryPlacementToInternalAnnotations : public HloModulePass { + public: + ConvertMemoryPlacementToInternalAnnotations() = default; + + absl::string_view name() const override { + return "convert-memory-placement-to-internal-annotations"; + } + using HloPassInterface::Run; + StatusOr<bool> Run( + HloModule* module, + const absl::flat_hash_set<absl::string_view>& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ diff --git 
a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc new file mode 100644 index 00000000000..4639738d66a --- /dev/null +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "xla/service/convert_memory_placement_to_internal_annotations.h" + +#include <memory> +#include <optional> +#include <string> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ConvertMemoryPlacementToInternalAnnotationsTest : public HloTestBase { + public: + ConvertMemoryPlacementToInternalAnnotationsTest() = default; +}; + +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, ConvertTest) { + const char* hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +region_0.9 { + arg_tuple.10 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.11 = s32[] get-tuple-element(arg_tuple.10), index=0 + constant.15 = s32[] constant(1) + add.33 = s32[] add(get-tuple-element.11, constant.15) + get-tuple-element.12 = f32[16]{0} 
get-tuple-element(arg_tuple.10), index=1 + sine.18 = f32[16]{0} sine(get-tuple-element.12) + sine.19 = f32[16]{0} sine(sine.18) + sine.20 = f32[16]{0} sine(sine.19) + get-tuple-element.13 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=2 + custom-call.21 = f32[16]{0} custom-call(sine.19), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.23 = f32[1,16]{1,0} reshape(custom-call.21) + constant.17 = s32[] constant(0) + compare.24 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + constant.16 = s32[] constant(16) + add.25 = s32[] add(get-tuple-element.11, constant.16) + select.26 = s32[] select(compare.24, add.25, get-tuple-element.11) + dynamic-update-slice.27 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.13, reshape.23, select.26, constant.17) + get-tuple-element.14 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=3 + custom-call.22 = f32[16]{0} custom-call(sine.20), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.28 = f32[1,16]{1,0} reshape(custom-call.22) + compare.29 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + add.30 = s32[] add(get-tuple-element.11, constant.16) + select.31 = s32[] select(compare.29, add.30, get-tuple-element.11) + dynamic-update-slice.32 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.14, reshape.28, select.31, constant.17) + ROOT tuple.34 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.33, sine.20, dynamic-update-slice.27, dynamic-update-slice.32) +} + +region_1.35 { + arg_tuple.36 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.38 = f32[16]{0} get-tuple-element(arg_tuple.36), index=1 + get-tuple-element.39 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=2 + get-tuple-element.40 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=3 + get-tuple-element.37 = s32[] 
get-tuple-element(arg_tuple.36), index=0 + constant.41 = s32[] constant(16) + ROOT compare.42 = pred[] compare(get-tuple-element.37, constant.41), direction=LT +} + +core_closed_call.43 { + constant.47 = s32[] constant(0) + Arg_0.44 = f32[16]{0} parameter(0) + constant.45 = f32[] constant(0) + broadcast.46 = f32[16,16]{1,0} broadcast(constant.45), dimensions={} + tuple.48 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.47, Arg_0.44, broadcast.46, broadcast.46) + while.49 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.48), condition=region_1.35, body=region_0.9 + get-tuple-element.50 = s32[] get-tuple-element(while.49), index=0 + get-tuple-element.51 = f32[16]{0} get-tuple-element(while.49), index=1 + get-tuple-element.52 = f32[16,16]{1,0} get-tuple-element(while.49), index=2 + get-tuple-element.53 = f32[16,16]{1,0} get-tuple-element(while.49), index=3 + ROOT tuple.54 = (f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(get-tuple-element.52, get-tuple-element.53) +} + +region_2.65 { + arg_tuple.66 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(arg_tuple.66), index=0 + constant.74 = s32[] constant(1) + add.108 = s32[] add(get-tuple-element.67, constant.74) + get-tuple-element.73 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=6 + constant.76 = s32[] constant(0) + compare.82 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + constant.75 = s32[] constant(16) + add.83 = s32[] add(get-tuple-element.67, constant.75) + select.84 = s32[] select(compare.82, add.83, get-tuple-element.67) + dynamic-slice.85 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.73, select.84, constant.76), dynamic_slice_sizes={1,16} + reshape.86 = f32[16]{0} reshape(dynamic-slice.85) + custom-call.87 = f32[16]{0} custom-call(reshape.86), custom_call_target="annotate_device_placement", 
frontend_attributes={_xla_buffer_placement="tpu_hbm"} + get-tuple-element.69 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=2 + get-tuple-element.68 = f32[16]{0} get-tuple-element(arg_tuple.66), index=1 + cosine.88 = f32[16]{0} cosine(get-tuple-element.68) + reshape.93 = f32[1,16]{1,0} reshape(cosine.88) + compare.94 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.95 = s32[] add(get-tuple-element.67, constant.75) + select.96 = s32[] select(compare.94, add.95, get-tuple-element.67) + dynamic-update-slice.97 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.69, reshape.93, select.96, constant.76) + get-tuple-element.70 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=3 + sine.89 = f32[16]{0} sine(get-tuple-element.68) + cosine.90 = f32[16]{0} cosine(sine.89) + reshape.98 = f32[1,16]{1,0} reshape(cosine.90) + compare.99 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.100 = s32[] add(get-tuple-element.67, constant.75) + select.101 = s32[] select(compare.99, add.100, get-tuple-element.67) + dynamic-update-slice.102 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.70, reshape.98, select.101, constant.76) + get-tuple-element.71 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=4 + get-tuple-element.72 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=5 + compare.77 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.78 = s32[] add(get-tuple-element.67, constant.75) + select.79 = s32[] select(compare.77, add.78, get-tuple-element.67) + dynamic-slice.80 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.72, select.79, constant.76), dynamic_slice_sizes={1,16} + reshape.81 = f32[16]{0} reshape(dynamic-slice.80) + custom-call.91 = f32[16]{0} custom-call(reshape.81), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="tpu_hbm"} + cosine.92 = f32[16]{0} cosine(custom-call.91) + reshape.103 = f32[1,16]{1,0} reshape(cosine.92) + 
compare.104 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.105 = s32[] add(get-tuple-element.67, constant.75) + select.106 = s32[] select(compare.104, add.105, get-tuple-element.67) + dynamic-update-slice.107 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.71, reshape.103, select.106, constant.76) + ROOT tuple.109 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.108, custom-call.87, dynamic-update-slice.97, dynamic-update-slice.102, dynamic-update-slice.107, get-tuple-element.72, get-tuple-element.73) +} + +region_3.110 { + arg_tuple.111 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.113 = f32[16]{0} get-tuple-element(arg_tuple.111), index=1 + get-tuple-element.114 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=2 + get-tuple-element.115 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=3 + get-tuple-element.116 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=4 + get-tuple-element.117 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=5 + get-tuple-element.118 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=6 + get-tuple-element.112 = s32[] get-tuple-element(arg_tuple.111), index=0 + constant.119 = s32[] constant(16) + ROOT compare.120 = pred[] compare(get-tuple-element.112, constant.119), direction=LT +} + +region_4.130 { + arg_tuple.131 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.132 = s32[] get-tuple-element(arg_tuple.131), index=0 + constant.140 = s32[] constant(1) + add.164 = s32[] add(get-tuple-element.132, constant.140) + get-tuple-element.133 = f32[16]{0} get-tuple-element(arg_tuple.131), index=1 + get-tuple-element.134 = f32[] get-tuple-element(arg_tuple.131), index=2 + broadcast.159 = f32[16]{0} 
broadcast(get-tuple-element.134), dimensions={} + add.160 = f32[16]{0} add(get-tuple-element.133, broadcast.159) + get-tuple-element.137 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=5 + constant.141 = s32[] constant(16) + subtract.142 = s32[] subtract(constant.141, get-tuple-element.132) + subtract.143 = s32[] subtract(subtract.142, constant.140) + constant.139 = s32[] constant(0) + compare.154 = pred[] compare(subtract.143, constant.139), direction=LT + add.155 = s32[] add(subtract.143, constant.141) + select.156 = s32[] select(compare.154, add.155, subtract.143) + dynamic-slice.157 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.137, select.156, constant.139), dynamic_slice_sizes={1,16} + reshape.158 = f32[16]{0} reshape(dynamic-slice.157) + multiply.161 = f32[16]{0} multiply(add.160, reshape.158) + get-tuple-element.136 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=4 + compare.149 = pred[] compare(subtract.143, constant.139), direction=LT + add.150 = s32[] add(subtract.143, constant.141) + select.151 = s32[] select(compare.149, add.150, subtract.143) + dynamic-slice.152 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.136, select.151, constant.139), dynamic_slice_sizes={1,16} + reshape.153 = f32[16]{0} reshape(dynamic-slice.152) + multiply.162 = f32[16]{0} multiply(multiply.161, reshape.153) + get-tuple-element.135 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=3 + compare.144 = pred[] compare(subtract.143, constant.139), direction=LT + add.145 = s32[] add(subtract.143, constant.141) + select.146 = s32[] select(compare.144, add.145, subtract.143) + dynamic-slice.147 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.135, select.146, constant.139), dynamic_slice_sizes={1,16} + reshape.148 = f32[16]{0} reshape(dynamic-slice.147) + multiply.163 = f32[16]{0} multiply(multiply.162, reshape.148) + constant.138 = f32[] constant(0) + ROOT tuple.165 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, 
/*index=5*/f32[16,16]{1,0}) tuple(add.164, multiply.163, constant.138, get-tuple-element.135, get-tuple-element.136, get-tuple-element.137) +} + +region_5.166 { + arg_tuple.167 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.169 = f32[16]{0} get-tuple-element(arg_tuple.167), index=1 + get-tuple-element.170 = f32[] get-tuple-element(arg_tuple.167), index=2 + get-tuple-element.171 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=3 + get-tuple-element.172 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=4 + get-tuple-element.173 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=5 + get-tuple-element.168 = s32[] get-tuple-element(arg_tuple.167), index=0 + constant.174 = s32[] constant(16) + ROOT compare.175 = pred[] compare(get-tuple-element.168, constant.174), direction=LT +} + +ENTRY main.183 { + constant.6 = s32[] constant(0) + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + call.55 = (f32[16,16]{1,0}, f32[16,16]{1,0}) call(Arg_0.1), to_apply=core_closed_call.43 + get-tuple-element.56 = f32[16,16]{1,0} get-tuple-element(call.55), index=0 + get-tuple-element.57 = f32[16,16]{1,0} get-tuple-element(call.55), index=1 + constant.7 = f32[] constant(1) + tuple.58 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) tuple(get-tuple-element.56, get-tuple-element.57, Arg_0.1, constant.7) + opt-barrier.59 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) opt-barrier(tuple.58) + get-tuple-element.62 = f32[16]{0} get-tuple-element(opt-barrier.59), index=2 + constant.4 = f32[] constant(0) + broadcast.5 = f32[16,16]{1,0} broadcast(constant.4), dimensions={} + get-tuple-element.60 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=0 + get-tuple-element.61 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=1 + tuple.64 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) 
tuple(constant.6, get-tuple-element.62, broadcast.5, broadcast.5, broadcast.5, get-tuple-element.60, get-tuple-element.61) + while.121 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.64), condition=region_3.110, body=region_2.65 + get-tuple-element.122 = s32[] get-tuple-element(while.121), index=0 + get-tuple-element.123 = f32[16]{0} get-tuple-element(while.121), index=1 + get-tuple-element.127 = f32[16,16]{1,0} get-tuple-element(while.121), index=5 + get-tuple-element.128 = f32[16,16]{1,0} get-tuple-element(while.121), index=6 + constant.2 = f32[] constant(0) + broadcast.3 = f32[16]{0} broadcast(constant.2), dimensions={} + get-tuple-element.63 = f32[] get-tuple-element(opt-barrier.59), index=3 + get-tuple-element.124 = f32[16,16]{1,0} get-tuple-element(while.121), index=2 + get-tuple-element.125 = f32[16,16]{1,0} get-tuple-element(while.121), index=3 + get-tuple-element.126 = f32[16,16]{1,0} get-tuple-element(while.121), index=4 + tuple.129 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(constant.6, broadcast.3, get-tuple-element.63, get-tuple-element.124, get-tuple-element.125, get-tuple-element.126) + while.176 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) while(tuple.129), condition=region_5.166, body=region_4.130 + get-tuple-element.177 = s32[] get-tuple-element(while.176), index=0 + ROOT get-tuple-element.178 = f32[16]{0} get-tuple-element(while.176), index=1 + get-tuple-element.179 = f32[] get-tuple-element(while.176), index=2 + get-tuple-element.180 = f32[16,16]{1,0} get-tuple-element(while.176), index=3 + get-tuple-element.181 = f32[16,16]{1,0} get-tuple-element(while.176), index=4 + get-tuple-element.182 = f32[16,16]{1,0} get-tuple-element(while.176), index=5 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + bool changed = + 
ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t custom_calls_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->IsCustomCall("PipelineForward") || + instr->IsCustomCall("PipelineBackward")) { + ++custom_calls_count; + } + } + } + EXPECT_EQ(custom_calls_count, 4); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/host_memory_offload_annotations.h b/third_party/xla/xla/service/host_memory_offload_annotations.h new file mode 100644 index 00000000000..4cdb5866b42 --- /dev/null +++ b/third_party/xla/xla/service/host_memory_offload_annotations.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ + +#ifndef XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ +#define XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ + +#include "absl/strings/string_view.h" + +namespace xla { +namespace host_memory_offload_annotations { + +// External annotations: +inline const absl::string_view kDevicePlacement = "annotate_device_placement"; +inline const absl::string_view kMemoryTargetHost = "pinned_host"; +inline const absl::string_view kMemoryTargetDeviceTpu = "tpu_hbm"; +inline const absl::string_view kMemoryTargetDeviceGpu = "gpu_hbm"; + +// Internal annotations: +// These are currently called PipelineForward/PipelineBackward, because they were +// originally meant as a hook point for the collective-pipeliner. +// They do more than just that though (identify memory movement direction), so +// should be renamed to something related to memory movement. +inline const absl::string_view kMoveToHostCustomCallTarget = "PipelineForward"; +inline const absl::string_view kMoveToDeviceCustomCallTarget = + "PipelineBackward"; + +} // namespace host_memory_offload_annotations +} // namespace xla + +#endif // XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc index a5ed02ad1c2..3275fc71c3f 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/service/host_offloader.cc @@ -32,6 +32,7 @@ limitations under the License.
#include "xla/literal_util.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" +#include "xla/service/host_memory_offload_annotations.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -148,7 +149,8 @@ StatusOr<HloInstruction*> BufferHasPositionWithUser(const HloBuffer& buffer, } HloInstruction* FindDSAnnotation(HloInstruction* hlo) { - while (!hlo->IsCustomCall(HostOffloader::kPipelineBackwardTarget)) { + while (!hlo->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { if (hlo->opcode() != HloOpcode::kReshape && hlo->opcode() != HloOpcode::kBitcast) { break; @@ -301,7 +303,8 @@ Status HostOffloader::MemoryOnlyOffloadStartingWithDus( if (consuming_ds_user->opcode() != HloOpcode::kCustomCall) { return Internal("Dynamic-slice does not have a matching annotation."); } - if (consuming_ds_user->custom_call_target() != kPipelineBackwardTarget) { + if (consuming_ds_user->custom_call_target() != + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) { return Internal( "Found custom-call is not the expected matching host offload " "annotation"); @@ -396,7 +399,8 @@ Status HostOffloader::MemoryOnlyOffloadStartingWithCopy( if (consuming_copy_user->opcode() != HloOpcode::kCustomCall) { return Internal("Copy does not have a matching annotation."); } - if (consuming_copy_user->custom_call_target() != kPipelineBackwardTarget) { + if (consuming_copy_user->custom_call_target() != + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) { return Internal( "Found custom-call is not the expected matching host offload " "annotation"); @@ -426,8 +430,10 @@ Status HostOffloader::MemoryOnlyOffloadInsertCopies( // Check that this buffer is finally an input to a load-from-host custom-call. 
TF_ASSIGN_OR_RETURN( HloInstruction * matching_annotation, - BufferHasPositionWithUser(unique_buffer, - match::CustomCall({kPipelineBackwardTarget}))); + BufferHasPositionWithUser( + unique_buffer, + match::CustomCall({host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget}))); if (matching_annotation == nullptr) { return Internal( "The offloaded data (from %s) never feeds into a matching \"load\" " @@ -531,9 +537,12 @@ StatusOr<bool> HostOffloader::Run( if (instruction->opcode() != HloOpcode::kCustomCall) { continue; } - if (instruction->custom_call_target() == kPipelineForwardTarget) { + if (instruction->custom_call_target() == + host_memory_offload_annotations::kMoveToHostCustomCallTarget) { TF_RETURN_IF_ERROR(HandlePipelineForwardCustomCall(instruction)); - } else if (instruction->custom_call_target() == kPipelineBackwardTarget) { + } else if (instruction->custom_call_target() == + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget) { TF_RETURN_IF_ERROR(HandlePipelineBackwardCustomCall(instruction)); } } diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h index cd5ea7e5d06..666fbc5d178 100644 --- a/third_party/xla/xla/service/host_offloader.h +++ b/third_party/xla/xla/service/host_offloader.h @@ -35,10 +35,6 @@ class HloCostAnalysis; // an error will be returned. 
class HostOffloader : public HloModulePass { public: - static constexpr absl::string_view kPipelineForwardTarget = "PipelineForward"; - static constexpr absl::string_view kPipelineBackwardTarget = - "PipelineBackward"; - explicit HostOffloader(int64_t host_memory_space_color) : kHostMemorySpaceColor(host_memory_space_color) {} ~HostOffloader() override = default; diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/service/host_offloader_test.cc index b843063cc3f..55c0ab3f70b 100644 --- a/third_party/xla/xla/service/host_offloader_test.cc +++ b/third_party/xla/xla/service/host_offloader_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/host_memory_offload_annotations.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -64,8 +65,9 @@ class HostOffloaderTest : public HloTestBase { for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { if (instruction->IsCustomCall( - {HostOffloader::kPipelineForwardTarget, - HostOffloader::kPipelineBackwardTarget})) { + {host_memory_offload_annotations::kMoveToHostCustomCallTarget, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget})) { return true; } } |