about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2024-02-08 15:07:00 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2024-02-08 15:22:01 -0800
commit: 57b25471e0bdad7d86f569e8276b07b7e57e61b8 (patch)
tree: 58e7401219a24254b37b2ae01729311dd3b52074
parent: 88d57160fe986d1a32d6697b58fc085c3bc97cf1 (diff)
download: tensorflow-57b25471e0bdad7d86f569e8276b07b7e57e61b8.tar.gz
Add support for send/send-done ops in the auto-sharding pass.
PiperOrigin-RevId: 605439471
-rw-r--r-- third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc 47
-rw-r--r-- third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc 21
-rw-r--r-- third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc 3
3 files changed, 62 insertions, 9 deletions
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 88cbe2bd58d..a3d017909d2 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -1940,7 +1940,9 @@ void SetHloSharding(const HloInstructionSequence& sequence,
const std::vector<HloInstruction*>& instructions = sequence.instructions();
for (HloInstruction* inst : instructions) {
- if (inst->opcode() == HloOpcode::kOutfeed) {
+ if (inst->opcode() == HloOpcode::kOutfeed ||
+ inst->opcode() == HloOpcode::kSend ||
+ inst->opcode() == HloOpcode::kSendDone) {
continue;
}
auto iter = strategy_map.find(inst);
@@ -2085,11 +2087,17 @@ Status SetHloShardingPostProcessing(
device_mesh, resharding_cache);
}
}
- } else if (inst->opcode() == HloOpcode::kOutfeed) {
- // Outfeed operand shardings are handled in downstream passes and so we
- // ignore outfeed ops here. However, we need to ensure that outfeed ops
- // which have user shardings have their shardings restored at the end. If
- // not, this can lead to errors downstream in the spmd_partitioner pass.
+ } else if (inst->opcode() == HloOpcode::kOutfeed ||
+ inst->opcode() == HloOpcode::kSendDone) {
+ // Outfeed: Outfeed operand shardings are handled in downstream passes and
+ // so we ignore outfeed ops here. However, we need to ensure that outfeed
+ // ops which have user shardings have their shardings restored at the
+ // end. If not, this can lead to errors downstream in the spmd_partitioner
+ // pass.
+
+ // In the analysis itself, we use replicated strategies as a stand-in for
+ // the (expected) maximal sharding annotations that send-done ops usually
+ // have. Here we restore these maximal shardings if present.
auto preserved_sharding_iter = preserve_shardings->find(inst->name());
if (preserved_sharding_iter != preserve_shardings->end()) {
const auto& preserved_sharding = preserved_sharding_iter->second;
@@ -2111,7 +2119,22 @@ Status SetHloShardingPostProcessing(
inst->set_sharding(preserved_sharding.at(0));
}
}
-
+ continue;
+ } else if (inst->opcode() == HloOpcode::kSend) {
+ // In the analysis itself, we use replicated strategies as a stand-in for
+ // the (expected) maximal sharding annotations that send ops usually
+ // have. Here we restore these maximal shardings if present.
+ auto preserved_sharding_iter = preserve_shardings->find(inst->name());
+ if (preserved_sharding_iter != preserve_shardings->end()) {
+ const auto& preserved_sharding = preserved_sharding_iter->second;
+ if (preserved_sharding.size() > 1) {
+ inst->set_sharding(
+ HloSharding::Tuple(inst->shape(), preserved_sharding));
+ } else {
+ CHECK_EQ(preserved_sharding.size(), 1);
+ inst->set_sharding(preserved_sharding[0]);
+ }
+ }
continue;
} else {
if (inst->shape().IsTuple()) {
@@ -3211,7 +3234,9 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation(
for (const HloComputation* computation :
module->computations(execution_threads)) {
for (const auto inst : computation->instructions()) {
- if (inst->opcode() == HloOpcode::kOutfeed) {
+ if (inst->opcode() == HloOpcode::kOutfeed ||
+ inst->opcode() == HloOpcode::kSend ||
+ inst->opcode() == HloOpcode::kSendDone) {
spmd::SaveShardingForInstruction(inst,
/* save_for_copy_users */ false,
preserve_shardings);
@@ -3267,6 +3292,12 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation(
continue;
}
+ if (ins->opcode() == HloOpcode::kOutfeed ||
+ ins->opcode() == HloOpcode::kSend ||
+ ins->opcode() == HloOpcode::kSendDone) {
+ continue;
+ }
+
if (ins->has_sharding()) {
module_is_changed |= true;
ins->clear_sharding();
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
index 1147962c6d1..414cbff91ab 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -730,6 +730,27 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
strategy_group, replicated_penalty);
break;
}
+ case HloOpcode::kSend: {
+ strategy_group = CreateTupleStrategyGroup(instruction_id);
+ strategy_group->childs.reserve(ins->shape().tuple_shapes_size());
+ for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) {
+ std::unique_ptr<StrategyGroup> child_strategies =
+ CreateLeafStrategyGroup(instruction_id, ins, strategy_map,
+ strategy_groups);
+ AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env,
+ strategy_map, child_strategies, 0);
+ child_strategies->tuple_element_idx = i;
+ strategy_group->childs.push_back(std::move(child_strategies));
+ }
+ break;
+ }
+ case HloOpcode::kSendDone: {
+ strategy_group = CreateLeafStrategyGroup(instruction_id, ins,
+ strategy_map, strategy_groups);
+ AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map,
+ strategy_group, 0);
+ break;
+ }
case HloOpcode::kAfterAll: {
strategy_group = CreateLeafStrategyGroup(instruction_id, ins,
strategy_map, strategy_groups);
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
index 08aafe0300f..33bfcf92c47 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -1429,7 +1429,8 @@ void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num,
const Array<int64_t>& device_mesh,
ReshardingCache* resharding_cache) {
HloInstruction* operand = inst->mutable_operand(operand_num);
- if (operand->opcode() == HloOpcode::kOutfeed) {
+ if (operand->opcode() == HloOpcode::kOutfeed ||
+ operand->opcode() == HloOpcode::kSendDone) {
return;
}