diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2024-02-08 15:07:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2024-02-08 15:22:01 -0800 |
commit | 57b25471e0bdad7d86f569e8276b07b7e57e61b8 (patch) | |
tree | 58e7401219a24254b37b2ae01729311dd3b52074 | |
parent | 88d57160fe986d1a32d6697b58fc085c3bc97cf1 (diff) | |
download | tensorflow-57b25471e0bdad7d86f569e8276b07b7e57e61b8.tar.gz |
Add support for send/send-done ops in the auto-sharding pass.
PiperOrigin-RevId: 605439471
3 files changed, 62 insertions, 9 deletions
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 88cbe2bd58d..a3d017909d2 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1940,7 +1940,9 @@ void SetHloSharding(const HloInstructionSequence& sequence, const std::vector<HloInstruction*>& instructions = sequence.instructions(); for (HloInstruction* inst : instructions) { - if (inst->opcode() == HloOpcode::kOutfeed) { + if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kSendDone) { continue; } auto iter = strategy_map.find(inst); @@ -2085,11 +2087,17 @@ Status SetHloShardingPostProcessing( device_mesh, resharding_cache); } } - } else if (inst->opcode() == HloOpcode::kOutfeed) { - // Outfeed operand shardings are handled in downstream passes and so we - // ignore outfeed ops here. However, we need to ensure that outfeed ops - // which have user shardings have their shardings restored at the end. If - // not, this can lead to errors downstream in the spmd_partitioner pass. + } else if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSendDone) { + // Outfeed: Outfeed operand shardings are handled in downstream passes and + // so we ignore outfeed ops here. However, we need to ensure that outfeed + // ops which have user shardings have their shardings restored at the + // end. If not, this can lead to errors downstream in the spmd_partitioner + // pass. + + // In the analysis itself, we use replicated strategies as a stand-in for + // the (expected) maximal sharding annotations that send-done ops usually + // have. Here we restore these maximal shardings if present. auto preserved_sharding_iter = preserve_shardings->find(inst->name()); if (preserved_sharding_iter != preserve_shardings->end()) { const auto& preserved_sharding = preserved_sharding_iter->second; @@ -2111,7 +2119,22 @@ Status SetHloShardingPostProcessing( inst->set_sharding(preserved_sharding.at(0)); } } - + continue; + } else if (inst->opcode() == HloOpcode::kSend) { + // In the analysis itself, we use replicated strategies as a stand-in for + // the (expected) maximal sharding annotations that send ops usually + // have. Here we restore these maximal shardings if present. + auto preserved_sharding_iter = preserve_shardings->find(inst->name()); + if (preserved_sharding_iter != preserve_shardings->end()) { + const auto& preserved_sharding = preserved_sharding_iter->second; + if (preserved_sharding.size() > 1) { + inst->set_sharding( + HloSharding::Tuple(inst->shape(), preserved_sharding)); + } else { + CHECK_EQ(preserved_sharding.size(), 1); + inst->set_sharding(preserved_sharding[0]); + } + } continue; } else { if (inst->shape().IsTuple()) { @@ -3211,7 +3234,9 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( for (const HloComputation* computation : module->computations(execution_threads)) { for (const auto inst : computation->instructions()) { - if (inst->opcode() == HloOpcode::kOutfeed) { + if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kSendDone) { spmd::SaveShardingForInstruction(inst, /* save_for_copy_users */ false, preserve_shardings); @@ -3267,6 +3292,12 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( continue; } + if (ins->opcode() == HloOpcode::kOutfeed || + ins->opcode() == HloOpcode::kSend || + ins->opcode() == HloOpcode::kSendDone) { + continue; + } + if (ins->has_sharding()) { module_is_changed |= true; ins->clear_sharding(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 1147962c6d1..414cbff91ab 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -730,6 +730,27 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_group, replicated_penalty); break; } + case HloOpcode::kSend: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + std::unique_ptr<StrategyGroup> child_strategies = + CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, + strategy_map, child_strategies, 0); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + break; + } + case HloOpcode::kSendDone: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } case HloOpcode::kAfterAll: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 08aafe0300f..33bfcf92c47 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1429,7 +1429,8 @@ void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const Array<int64_t>& device_mesh, ReshardingCache* resharding_cache) { HloInstruction* operand = inst->mutable_operand(operand_num); - if (operand->opcode() == HloOpcode::kOutfeed) { + if (operand->opcode() == HloOpcode::kOutfeed || + operand->opcode() == HloOpcode::kSendDone) { return; } |