diff options
author | Kyle Lucke <klucke@google.com> | 2024-05-16 15:20:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2024-05-16 17:07:50 -0700 |
commit | 9c5e56cd91bb3206f9f9751bbfcc4426abca9cca (patch) | |
tree | 5debf674310cd60faec401d202764283e284db3d | |
parent | 6ec902447b28138b7fc45720cd9cba6c341a1bd5 (diff) | |
download | tensorflow-upstream-master.tar.gz |
Use absl::Status instead of xla::Status now that they're identical. [upstream-master]
PiperOrigin-RevId: 634545999
26 files changed, 262 insertions, 254 deletions
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index ac9fa032c88..6bed3bd4263 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -351,13 +351,13 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { : hlo_to_profile_idx_(hlo_to_profile_idx), assigned_indices_(assigned_indices) {} - Status DefaultAction(HloInstruction* hlo_instruction) override { + absl::Status DefaultAction(HloInstruction* hlo_instruction) override { hlo_to_profile_idx_->insert( {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)}); return absl::OkStatus(); } - Status HandleCall(HloInstruction* call) override { + absl::Status HandleCall(HloInstruction* call) override { TF_RETURN_IF_ERROR(DefaultAction(call)); CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_, assigned_indices_); @@ -365,7 +365,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { return absl::OkStatus(); } // Recurse into "conditional" so we can profile inside of it. - Status HandleConditional(HloInstruction* conditional) override { + absl::Status HandleConditional(HloInstruction* conditional) override { TF_RETURN_IF_ERROR(DefaultAction(conditional)); CollectProfileCandidates candidates_for_true(hlo_to_profile_idx_, @@ -382,12 +382,16 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { } // Skip constants, there is nothing to profile. - Status HandleConstant(HloInstruction*) override { return absl::OkStatus(); } + absl::Status HandleConstant(HloInstruction*) override { + return absl::OkStatus(); + } // Skip parameters, they are a simple load. - Status HandleParameter(HloInstruction*) override { return absl::OkStatus(); } + absl::Status HandleParameter(HloInstruction*) override { + return absl::OkStatus(); + } // It is important to recurse for "while" or else we risk overly coarse // profiling information. 
- Status HandleWhile(HloInstruction* xla_while) override { + absl::Status HandleWhile(HloInstruction* xla_while) override { TF_RETURN_IF_ERROR(DefaultAction(xla_while)); CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_, @@ -423,7 +427,7 @@ void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {}, } // namespace -Status CpuCompiler::RunHloPassesThroughLayoutAssn( +absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloModule* module, bool is_aot_compile, LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) { const DebugOptions& debug_options = module->config().debug_options(); @@ -696,7 +700,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( return pipeline.Run(module).status(); } -Status CpuCompiler::RunHloPassesAfterLayoutAssn( +absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( HloModule* module, bool is_aot_compile, LLVMTargetMachineFeatures* target_machine_features, const CompileOptions& compile_options, bool is_mlir_compile) { @@ -799,10 +803,10 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( return pipeline.Run(module).status(); } -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine, - const CompileOptions& compile_options, - bool is_mlir_compile) { +absl::Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine, + const CompileOptions& compile_options, + bool is_mlir_compile) { LLVMTargetMachineFeatures target_machine_features(target_machine); TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn( module, is_aot_compile, &target_machine_features, is_mlir_compile)); @@ -870,7 +874,7 @@ std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks( }}; } -Status VerifyLlvmModule(const llvm::Module& llvm_module) { +absl::Status VerifyLlvmModule(const llvm::Module& llvm_module) { XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier"); std::string err; @@ 
-885,7 +889,7 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { return absl::OkStatus(); } -Status CreateHloProfilingArtifacts( +absl::Status CreateHloProfilingArtifacts( const HloModule& module, absl::flat_hash_map<const HloInstruction*, int64_t>* instruction_to_profile_idx, @@ -1404,7 +1408,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run // the pre-optimization IR dump hook before returning. { - Status verify_status = VerifyLlvmModule(*llvm_module); + absl::Status verify_status = VerifyLlvmModule(*llvm_module); if (!verify_status.ok() && pre_optimization_ir_hook) { pre_optimization_ir_hook(*llvm_module); } diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.h b/third_party/xla/xla/service/cpu/cpu_compiler.h index 69ba7f498f3..3cc54c4a616 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.h +++ b/third_party/xla/xla/service/cpu/cpu_compiler.h @@ -195,19 +195,19 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine, - const CompileOptions& compile_options, - bool is_mlir_compile = false); + absl::Status RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine, + const CompileOptions& compile_options, + bool is_mlir_compile = false); // Runs HLO passes up to and including layout assignment. - Status RunHloPassesThroughLayoutAssn( + absl::Status RunHloPassesThroughLayoutAssn( HloModule* module, bool /*is_aot_compile*/, LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile = false); // Runs HLO passes after layout assignment. 
- Status RunHloPassesAfterLayoutAssn( + absl::Status RunHloPassesAfterLayoutAssn( HloModule* module, bool is_aot_compile, LLVMTargetMachineFeatures* target_machine_features, const CompileOptions& compile_options, bool is_mlir_compile); diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 01a5007a3de..02fb083554c 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -180,7 +180,7 @@ CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator, return std::move(buffers); } -Status CpuExecutable::ExecuteComputeFunction( +absl::Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, absl::Span<MaybeOwningDeviceMemory const> buffers, HloExecutionProfile* hlo_execution_profile) { @@ -399,7 +399,7 @@ absl::StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream( std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers; HloExecutionProfile* hlo_execution_profile; - Status operator()() { + absl::Status operator()() { return executable->ExecuteComputeFunction( &run_options.run_options(), *task_buffers, hlo_execution_profile); } diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index 2dc7e2d9700..5cf8aa83357 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -73,7 +73,7 @@ class CpuExecutable : public Executable { // Calls the generated function performing the computation with the given // arguments using the supplied buffers. 
- Status ExecuteComputeFunction( + absl::Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, absl::Span<MaybeOwningDeviceMemory const> buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index 35371ea9463..48f5ae59bfe 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -120,7 +120,7 @@ static bool OperandsAndResultMustHaveRowMajorLayout( return false; } -Status CpuLayoutAssignment::AddBackendConstraints( +absl::Status CpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { ShouldMakeOperandColMajorCache cache; diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.h b/third_party/xla/xla/service/cpu/cpu_layout_assignment.h index 35ecde418bc..26e155458f5 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.h +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.h @@ -37,7 +37,7 @@ class CpuLayoutAssignment : public LayoutAssignment { ~CpuLayoutAssignment() override {} protected: - Status AddBackendConstraints(LayoutConstraints* constraints) override; + absl::Status AddBackendConstraints(LayoutConstraints* constraints) override; const TargetMachineFeatures& target_machine_features_; }; diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc index 27e62fc1af7..fd68f299635 100644 --- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc +++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc @@ -43,19 +43,19 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed( +absl::Status CpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { 
return TransferLiteralToInfeedOnCpu(executor->device_ordinal(), literal); } -Status CpuTransferManager::TransferLiteralFromOutfeed( +absl::Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, MutableBorrowingLiteral literal) { return TransferLiteralFromOutfeedOnCpu(executor->device_ordinal(), literal); } -Status CpuTransferManager::ReadDynamicShapes(se::Stream* stream, - const ShapedBuffer* device_buffer, - Shape* device_shape) { +absl::Status CpuTransferManager::ReadDynamicShapes( + se::Stream* stream, const ShapedBuffer* device_buffer, + Shape* device_shape) { if (stream != nullptr) { // When a stream is presented, respect the stream dependency. return TransferManager::ReadDynamicShapes(stream, device_buffer, diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h index 9cdf0478a47..ed7f0fe3f7b 100644 --- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h +++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h @@ -37,10 +37,10 @@ class CpuTransferManager : public GenericTransferManager { CpuTransferManager(); ~CpuTransferManager() override {} - Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const LiteralSlice& literal) override; - Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - MutableBorrowingLiteral literal) override; + absl::Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + absl::Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, MutableBorrowingLiteral literal) override; bool CanShapedBufferBeAccessedNow( se::StreamExecutor* executor, @@ -54,9 +54,9 @@ class CpuTransferManager : public GenericTransferManager { return true; } - Status ReadDynamicShapes(se::Stream* stream, - const ShapedBuffer* device_buffer, - Shape* device_shape) override; + absl::Status ReadDynamicShapes(se::Stream* stream, + const ShapedBuffer* device_buffer, + Shape* 
device_shape) override; private: bool PackSubbyteTypes() const override { return true; } diff --git a/third_party/xla/xla/service/cpu/cpu_xfeed.cc b/third_party/xla/xla/service/cpu/cpu_xfeed.cc index b6fea6cb0d3..8b4ea238f51 100644 --- a/third_party/xla/xla/service/cpu/cpu_xfeed.cc +++ b/third_party/xla/xla/service/cpu/cpu_xfeed.cc @@ -102,8 +102,8 @@ absl::StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal( return queued_buffer; } -Status TransferBufferToInfeed(int device_ordinal, int64_t size, - const void* source) { +absl::Status TransferBufferToInfeed(int device_ordinal, int64_t size, + const void* source) { TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(size, source)); @@ -175,8 +175,8 @@ absl::StatusOr<Shape> TransferTupleBuffersFromOutfeed( } } // namespace -Status TransferLiteralToInfeedOnCpu(int device_ordinal, - const LiteralSlice& literal) { +absl::Status TransferLiteralToInfeedOnCpu(int device_ordinal, + const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); @@ -221,8 +221,8 @@ Status TransferLiteralToInfeedOnCpu(int device_ordinal, return OkStatus(); } -Status TransferLiteralFromOutfeedOnCpu(int device_ordinal, - MutableBorrowingLiteral literal) { +absl::Status TransferLiteralFromOutfeedOnCpu(int device_ordinal, + MutableBorrowingLiteral literal) { if (!literal.shape().IsTuple()) { int64_t size = cpu::runtime::GetByteSizeRequirement(literal.shape(), sizeof(void*)); @@ -275,7 +275,7 @@ Status TransferLiteralFromOutfeedOnCpu(int device_ordinal, return OkStatus(); } -Status ReadDynamicShapesOnCpu( +absl::Status ReadDynamicShapesOnCpu( const ShapedBuffer* device_buffer, Shape* device_shape, HloCostAnalysis::ShapeSizeFunction shape_size_fn) { TF_RET_CHECK(device_shape->is_dynamic()); diff --git a/third_party/xla/xla/service/cpu/cpu_xfeed.h b/third_party/xla/xla/service/cpu/cpu_xfeed.h index 
26512839d50..7d09ec862d8 100644 --- a/third_party/xla/xla/service/cpu/cpu_xfeed.h +++ b/third_party/xla/xla/service/cpu/cpu_xfeed.h @@ -30,17 +30,17 @@ limitations under the License. namespace xla { // Helper function to transfers to infeed on CPU. -Status TransferLiteralToInfeedOnCpu(int device_ordinal, - const LiteralSlice& literal); +absl::Status TransferLiteralToInfeedOnCpu(int device_ordinal, + const LiteralSlice& literal); // Helper function to transfers from outfeed on CPU. -Status TransferLiteralFromOutfeedOnCpu(int device_ordinal, - MutableBorrowingLiteral literal); +absl::Status TransferLiteralFromOutfeedOnCpu(int device_ordinal, + MutableBorrowingLiteral literal); // Helper function to retrieve dynamic shape on CPU. -Status ReadDynamicShapesOnCpu(const ShapedBuffer* device_buffer, - Shape* device_shape, - HloCostAnalysis::ShapeSizeFunction shape_size_fn); +absl::Status ReadDynamicShapesOnCpu( + const ShapedBuffer* device_buffer, Shape* device_shape, + HloCostAnalysis::ShapeSizeFunction shape_size_fn); } // namespace xla #endif // XLA_SERVICE_CPU_CPU_XFEED_H_ diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index c6bad1e54ec..9548b9919ac 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -134,21 +134,21 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - Status Emit(); + absl::Status Emit(); // Emits the IR to perform the batch dot operation. - Status EmitBatch(); + absl::Status EmitBatch(); private: // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - Status EmitScalarDot(); + absl::Status EmitScalarDot(); // Emits a call to the CPU runtime to perform the matrix multiply. 
- Status EmitCallToRuntime(); + absl::Status EmitCallToRuntime(); // Emits a call to the CPU runtime to perform the batch matrix multiply. - Status EmitCallToBatchRuntime(); + absl::Status EmitCallToBatchRuntime(); // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { @@ -192,7 +192,7 @@ class DotOpEmitter { void EmitTiledLlvmIrGemm(); // Lowers the dot operation through MLIR's linalg.matmul. - Status EmitLinalgMatmul(); + absl::Status EmitLinalgMatmul(); // Lowers the dot operation as a naive nested loop that computes the result // one element at a time. @@ -264,7 +264,7 @@ DotOpEmitter::DotOpEmitter( hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -Status DotOpEmitter::EmitLinalgMatmul() { +absl::Status DotOpEmitter::EmitLinalgMatmul() { Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(), rhs_array_.GetBasePointer()}; @@ -529,7 +529,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() { } } -Status DotOpEmitter::Emit() { +absl::Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -584,7 +584,7 @@ Status DotOpEmitter::Emit() { } } -Status DotOpEmitter::EmitBatch() { +absl::Status DotOpEmitter::EmitBatch() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -756,7 +756,7 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); } -Status DotOpEmitter::EmitScalarDot() { +absl::Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; // Use the same index_type for all tensor accesses in the same kernel. 
@@ -791,7 +791,7 @@ Status DotOpEmitter::EmitScalarDot() { return OkStatus(); } -Status DotOpEmitter::EmitCallToRuntime() { +absl::Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -902,7 +902,7 @@ Status DotOpEmitter::EmitCallToRuntime() { return OkStatus(); } -Status DotOpEmitter::EmitCallToBatchRuntime() { +absl::Status DotOpEmitter::EmitCallToBatchRuntime() { // The signature of the runtime batch matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -1218,7 +1218,7 @@ DotImplementationStrategy GetDotImplementationStrategy( return DotImplementationStrategy::kNaiveLlvmIr; } -Status EmitNonBatchDotOperation( +absl::Status EmitNonBatchDotOperation( DotInfo dot_info, std::string hlo_name, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -1270,7 +1270,8 @@ llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b, std::move(new_shape)); } -Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { +absl::Status ValidateDotDimensionNumbers( + const DotDimensionNumbers& dim_numbers) { // Checks some invariants that do not hold in general, but DotDecomposer // should have established for us. This is just a debugging aid. 
TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); @@ -1357,7 +1358,7 @@ bool PotentiallyImplementedAsEigenMatmul( return impl_strategy == DotImplementationStrategy::kEigen; } -Status EmitBatchDotOperation( +absl::Status EmitBatchDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, @@ -1481,15 +1482,13 @@ bool DotOperandsAndResultMustHaveRowMajorLayout( impl_strategy == DotImplementationStrategy::kEigen; } -Status EmitDotOperation(const HloInstruction& dot, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) { +absl::Status EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { // This routine assumes that the dot operation is not in a parallelized // enclosing computation. CHECK(dot.parent() diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.h b/third_party/xla/xla/service/cpu/dot_op_emitter.h index 9d4f5617960..58e3afee737 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.h +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.h @@ -57,15 +57,13 @@ std::optional<int64_t> ProfitableToMakeDotOperandColumnMajor( // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). 
A non-null `addend_array` is only supported // for Matrix-vector products. -Status EmitDotOperation(const HloInstruction& dot, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); +absl::Status EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index dff2ee5e7cd..d6a5a832d03 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -126,7 +126,7 @@ IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context, target_machine_features_(*target_machine_features), emit_code_for_msan_(emit_code_for_msan) { b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_)); - Status s = GatherComputationsByAllocationType( + absl::Status s = GatherComputationsByAllocationType( &hlo_module, &thread_local_computations_, &global_computations_); absl::c_sort(thread_local_computations_); absl::c_sort(global_computations_); @@ -255,7 +255,7 @@ void IrEmitter::InitializeIrFunction(const std::string& function_name) { module_, &b_, num_dynamic_loop_bounds_); } -Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { +absl::Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); 
emitted_value_[bitcast] = GetEmittedValueFor(bitcast->operand(0)); return OkStatus(); @@ -277,7 +277,7 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { return result_global; } -Status IrEmitter::EmitConstantGlobals() { +absl::Status IrEmitter::EmitConstantGlobals() { for (const BufferAllocation& allocation : assignment_.Allocations()) { if (!allocation.is_constant()) { continue; @@ -300,14 +300,14 @@ Status IrEmitter::EmitConstantGlobals() { return OkStatus(); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { +absl::Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); // IrEmitter::EmitConstantGlobals has already taken care of emitting the body // of the constant. return EmitTargetAddressForOp(constant); } -Status IrEmitter::HandleCopy(HloInstruction* copy) { +absl::Status IrEmitter::HandleCopy(HloInstruction* copy) { if (copy->shape().IsTuple() || (copy->shape().IsArray() && LayoutUtil::Equal(copy->operand(0)->shape().layout(), @@ -382,7 +382,8 @@ void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, } } -Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { +absl::Status IrEmitter::HandleGetTupleElement( + HloInstruction* get_tuple_element) { // A tuple is an array of pointers, one for each operand. Each pointer points // to the output buffer of its corresponding operand. 
A GetTupleElement // instruction forwards a pointer to the tuple element buffer at the given @@ -395,13 +396,13 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { return OkStatus(); } -Status IrEmitter::HandleSelect(HloInstruction* select) { +absl::Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); TF_RET_CHECK(pred->shape().element_type() == PRED); return DefaultAction(select); } -Status IrEmitter::HandleInfeed(HloInstruction* instruction) { +absl::Status IrEmitter::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); VLOG(2) << "HandleInfeed: " << infeed->ToString(); @@ -461,8 +462,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { return OkStatus(); } -Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, - llvm::Value* program_buffer_address) { +absl::Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, + llvm::Value* program_buffer_address) { int64_t length = ByteSizeOf(shape); if (length < 0 || length > std::numeric_limits<int32_t>::max()) { return InvalidArgument( @@ -526,7 +527,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, return OkStatus(); } -Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { +absl::Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { // Outfeed produces no useful result, but it does return a token[] that can be // threaded through to other side effecting operations to ensure ordering. 
In // the IR emitter we treat this token as a normal u8[] and thus need to insert @@ -556,7 +557,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { return OkStatus(); } -Status IrEmitter::HandleSort(HloInstruction* hlo) { +absl::Status IrEmitter::HandleSort(HloInstruction* hlo) { const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); @@ -651,7 +652,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { return OkStatus(); } -Status IrEmitter::HandleTuple(HloInstruction* tuple) { +absl::Status IrEmitter::HandleTuple(HloInstruction* tuple) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple)); llvm::SmallVector<llvm::Value*> base_ptrs; for (auto operand : tuple->operands()) { @@ -661,7 +662,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return OkStatus(); } -Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { +absl::Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -677,12 +678,13 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // that works. 
bool saved_allow_reassociation = allow_reassociation_; allow_reassociation_ = true; - Status status = DefaultAction(reduce_window); + absl::Status status = DefaultAction(reduce_window); allow_reassociation_ = saved_allow_reassociation; return status; } -Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { +absl::Status IrEmitter::HandleSelectAndScatter( + HloInstruction* select_and_scatter) { CHECK_EQ(select_and_scatter->operand_count(), 3); const auto operand = select_and_scatter->operand(0); const auto source = select_and_scatter->operand(1); @@ -864,7 +866,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { return OkStatus(); } -Status IrEmitter::HandleDot(HloInstruction* dot) { +absl::Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs = dot->operand(0); auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( @@ -900,7 +902,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { hlo_module_config_, target_machine_features_); } -Status IrEmitter::HandleConvolution(HloInstruction* convolution) { +absl::Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( @@ -1085,7 +1087,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return DefaultAction(convolution); } -Status IrEmitter::HandleFft(HloInstruction* fft) { +absl::Status IrEmitter::HandleFft(HloInstruction* fft) { auto operand = fft->operand(0); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*fft, /*operands=*/{operand}, @@ -1137,7 +1139,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return OkStatus(); } -Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) { +absl::Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) { // When there is a single replica, a cross replica sum is the identity // function, and the buffer 
assignment expects a copy. // @@ -1196,7 +1198,7 @@ static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) { } } -Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { +absl::Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { CHECK_GE(crs->operand_count(), 1); PrimitiveType datatype = crs->operand(0)->shape().element_type(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); @@ -1280,7 +1282,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { return OkStatus(); } -Status IrEmitter::HandleAllReduce(HloInstruction* crs) { +absl::Status IrEmitter::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() == 1 && hlo_module_config_.num_partitions() == 1) { return HandleAllReduceSingleReplica(crs); @@ -1288,7 +1290,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { return HandleAllReduceMultipleReplica(crs); } -Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { +absl::Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { CHECK_EQ(rs->operand_count(), 1); PrimitiveType datatype = rs->operand(0)->shape().element_type(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs)); @@ -1344,7 +1346,7 @@ Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { return OkStatus(); } -Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { +absl::Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { auto* instr = Cast<HloAllToAllInstruction>(instruction); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); CHECK(!instr->split_dimension() && instr->shape().IsTuple()) @@ -1400,7 +1402,7 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { return OkStatus(); } -Status IrEmitter::HandleAllGather(HloInstruction* instruction) { +absl::Status IrEmitter::HandleAllGather(HloInstruction* instruction) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); std::string replica_groups = @@ -1451,7 +1453,7 @@ Status 
IrEmitter::HandleAllGather(HloInstruction* instruction) { return OkStatus(); } -Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { +absl::Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast<HloCollectivePermuteInstruction>(crs); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr)); std::string source_target_pairs = absl::StrJoin( @@ -1488,7 +1490,7 @@ Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { return OkStatus(); } -Status IrEmitter::HandlePartitionId(HloInstruction* hlo) { +absl::Status IrEmitter::HandlePartitionId(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, assignment_.GetUniqueSlice(hlo, {})); @@ -1500,7 +1502,7 @@ Status IrEmitter::HandlePartitionId(HloInstruction* hlo) { return OkStatus(); } -Status IrEmitter::HandleReplicaId(HloInstruction* hlo) { +absl::Status IrEmitter::HandleReplicaId(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, assignment_.GetUniqueSlice(hlo, {})); @@ -1512,7 +1514,7 @@ Status IrEmitter::HandleReplicaId(HloInstruction* hlo) { return OkStatus(); } -Status IrEmitter::HandleParameter(HloInstruction* parameter) { +absl::Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); return EmitTargetAddressForOp(parameter); } @@ -1998,7 +2000,7 @@ absl::StatusOr<bool> IrEmitter::EmitVectorizedReduce( return true; } -Status IrEmitter::HandleReduce(HloInstruction* reduce) { +absl::Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); absl::Span<const int64_t> dimensions(reduce->dimensions()); @@ -2027,21 +2029,21 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { return DefaultAction(reduce); } -Status IrEmitter::HandleSend(HloInstruction* send) { +absl::Status 
IrEmitter::HandleSend(HloInstruction* send) { // TODO(b/33942983): Support Send/Recv on CPU. return Unimplemented("Send is not implemented on CPU."); } -Status IrEmitter::HandleSendDone(HloInstruction* send_done) { +absl::Status IrEmitter::HandleSendDone(HloInstruction* send_done) { // TODO(b/33942983): Support Send/Recv on CPU. return Unimplemented("Send-done is not implemented on CPU."); } -Status IrEmitter::HandleScatter(HloInstruction*) { +absl::Status IrEmitter::HandleScatter(HloInstruction*) { return Unimplemented("Scatter is not implemented on CPUs."); } -Status IrEmitter::HandleSlice(HloInstruction* slice) { +absl::Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use @@ -2178,7 +2180,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { return OkStatus(); } -Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) { +absl::Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) { if (ShapeUtil::IsScalar(dynamic_slice->shape())) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice); @@ -2186,7 +2188,7 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) { return DefaultAction(dynamic_slice); } -Status IrEmitter::HandleDynamicUpdateSlice( +absl::Status IrEmitter::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { auto update = dynamic_update_slice->operand(1); if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { @@ -2203,17 +2205,17 @@ Status IrEmitter::HandleDynamicUpdateSlice( return DefaultAction(dynamic_update_slice); } -Status IrEmitter::HandleRecv(HloInstruction* recv) { +absl::Status IrEmitter::HandleRecv(HloInstruction* recv) { // TODO(b/33942983): Support Send/Recv on CPU. 
return Unimplemented("Recv is not implemented on CPU."); } -Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { +absl::Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { // TODO(b/33942983): Support Send/Recv on CPU. return Unimplemented("Recv-done is not implemented on CPU."); } -Status IrEmitter::HandlePad(HloInstruction* pad) { +absl::Status IrEmitter::HandlePad(HloInstruction* pad) { // CPU backend does not properly handle negative padding but this is ok // because negative padding should be removed by the algebraic simplifier. for (auto& padding_dimension : pad->padding_config().dimensions()) { @@ -2272,7 +2274,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { return OkStatus(); } -Status IrEmitter::HandleFusion(HloInstruction* fusion) { +absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; @@ -2327,7 +2329,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { } } -Status IrEmitter::HandleCall(HloInstruction* call) { +absl::Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie( emitted_functions_, ComputationToEmit{computation, allow_reassociation_}); @@ -2372,7 +2374,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { return OkStatus(); } -Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { +absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); std::vector<llvm::Value*> dynamic_dims; int32_t raw_data_size = @@ -2400,7 +2402,7 @@ Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { // dest_index = delinearize(linearize(i, dynamic_dim), static_dim) // dest[dest_index] = source[i] auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& array_index) -> Status { 
+ [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status { llvm::Value* source_element = GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_); llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_); @@ -2415,7 +2417,7 @@ Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { .EmitLoop(IrName(hlo)); } -Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { +absl::Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, @@ -2463,7 +2465,7 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { // source_index = delinearize(inearize(i, dynamic_dim), static_dim) // dest[i] = source[source_index] auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& array_index) -> Status { + [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status { llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_); llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_); llvm::Value* source_element = @@ -2480,7 +2482,7 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { return OkStatus(); } -Status IrEmitter::HandleTopK(HloInstruction* hlo) { +absl::Status IrEmitter::HandleTopK(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); const HloInstruction* input = hlo->operand(0); const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back(); @@ -2520,8 +2522,8 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) -Status IrEmitter::HandleOneDnnMatMulCalls(HloInstruction* custom_call, - std::string runtime_symbol_name) { +absl::Status IrEmitter::HandleOneDnnMatMulCalls( + HloInstruction* custom_call, std::string runtime_symbol_name) { // We would like to emit LLVM IR for the following function call // custom_call_target(void* result, void** args) // args can be thought of an array of pointers allocated on the 
stack, @@ -2650,7 +2652,7 @@ Status IrEmitter::HandleOneDnnMatMulCalls(HloInstruction* custom_call, return OkStatus(); } -Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { +absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { // args[0]: ptr to nargs // args[1]: ptr to ExecutableRunOptions // args[2]: ptr to OneDnnLayerNormConfig @@ -2727,7 +2729,7 @@ Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { return OkStatus(); } -Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { +absl::Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { auto input = custom_call->operand(0); llvm_ir::IrArray input_array(GetIrArrayFor(input)); auto input_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, input_array); @@ -2751,7 +2753,7 @@ Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 -Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { +absl::Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "PadToStatic") { return HandlePadToStatic(custom_call); } @@ -2856,7 +2858,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { return OkStatus(); } -Status IrEmitter::HandleWhile(HloInstruction* xla_while) { +absl::Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Precondition: Condition computation must return a scalar bool. 
HloComputation* condition = xla_while->while_condition(); TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && @@ -2867,7 +2869,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( xla_while->shape(), [this, &xla_while](const Shape& /*subshape*/, - const ShapeIndex& index) -> Status { + const ShapeIndex& index) -> absl::Status { auto check = [this](const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index) -> absl::Status { const BufferAllocation::Slice slice_a = @@ -3236,7 +3238,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } } -Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { +absl::Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { absl::Span<HloInstruction* const> operands(concatenate->operands()); std::string failure_reason; TF_ASSIGN_OR_RETURN( @@ -3253,7 +3255,7 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { return DefaultAction(concatenate); } -Status IrEmitter::HandleConditional(HloInstruction* conditional) { +absl::Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto branch_index = conditional->operand(0); int num_branches = conditional->branch_count(); TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) && @@ -3362,25 +3364,25 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return OkStatus(); } -Status IrEmitter::HandleAfterAll(HloInstruction* after_all) { +absl::Status IrEmitter::HandleAfterAll(HloInstruction* after_all) { TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0); // No code to generate, but we need to emit an address for book-keeping. 
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all)); return OkStatus(); } -Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { +absl::Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { // AddDedendency just forwards its zero-th operand. emitted_value_[add_dependency] = GetEmittedValueFor(add_dependency->operand(0)); return OkStatus(); } -Status IrEmitter::HandleRng(HloInstruction* rng) { +absl::Status IrEmitter::HandleRng(HloInstruction* rng) { return Unimplemented("Rng should be expanded for CPU."); } -Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { +absl::Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString(); llvm::Value* old_state = llvm_ir::RngGetAndUpdateState( Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_, @@ -3398,7 +3400,7 @@ Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { return OkStatus(); } -Status IrEmitter::FinishVisit(HloInstruction* root) { +absl::Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding // the value. If the root is a constant or parameter, we perform a memcpy from @@ -3591,7 +3593,7 @@ bool IsHloVeryCheap(const HloInstruction* hlo) { } } // namespace -Status IrEmitter::Preprocess(HloInstruction* hlo) { +absl::Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); // When profiling is enabled, trace the same HLOs that the profiler does. 
if (instruction_to_profile_idx_.count(hlo) || @@ -3604,7 +3606,7 @@ Status IrEmitter::Preprocess(HloInstruction* hlo) { return OkStatus(); } -Status IrEmitter::Postprocess(HloInstruction* hlo) { +absl::Status IrEmitter::Postprocess(HloInstruction* hlo) { if (auto* prof_counter = GetProfileCounterFor(*hlo)) { profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter); } @@ -3764,7 +3766,7 @@ llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, } } -Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { +absl::Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { const Shape& target_shape = op->shape(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); @@ -3774,13 +3776,13 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { return OkStatus(); } -Status IrEmitter::EmitTargetElementLoop( +absl::Status IrEmitter::EmitTargetElementLoop( HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator) { return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator); } -Status IrEmitter::EmitTargetElementLoop( +absl::Status IrEmitter::EmitTargetElementLoop( HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); @@ -3833,8 +3835,8 @@ Status IrEmitter::EmitTargetElementLoop( return OkStatus(); } -Status IrEmitter::EmitMemcpy(const HloInstruction& source, - const HloInstruction& destination) { +absl::Status IrEmitter::EmitMemcpy(const HloInstruction& source, + const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); llvm::Value* destination_value = GetEmittedValueFor(&destination); int64_t source_size = ByteSizeOf(source.shape()); @@ -3844,7 +3846,7 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, return OkStatus(); } -Status IrEmitter::ElementTypesSameAndSupported( 
+absl::Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, absl::Span<const HloInstruction* const> operands, absl::Span<const PrimitiveType> supported_types) { @@ -3863,7 +3865,7 @@ Status IrEmitter::ElementTypesSameAndSupported( return OkStatus(); } -Status IrEmitter::DefaultAction(HloInstruction* hlo) { +absl::Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index bc99acd4f2b..21a37181011 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -124,7 +124,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilder<>* builder() { return &b_; } // Emit an LLVM global variable for every constant buffer allocation. - Status EmitConstantGlobals(); + absl::Status EmitConstantGlobals(); protected: // @@ -132,55 +132,57 @@ class IrEmitter : public DfsHloVisitorWithDefault, // // Default action which emits code for most operations. Operations which are // special in some way are handled explicitly in HandleFoo methods. 
- Status DefaultAction(HloInstruction* hlo) override; - - Status HandleAllGather(HloInstruction* instruction) override; - Status HandleAllToAll(HloInstruction* instruction) override; - Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleConstant(HloInstruction* constant) override; - Status HandleCopy(HloInstruction* copy) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleSelect(HloInstruction* select) override; - Status HandleDot(HloInstruction* dot) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleFft(HloInstruction* fft) override; - Status HandleAllReduce(HloInstruction* crs) override; - Status HandleReduceScatter(HloInstruction* crs) override; - Status HandleCollectivePermute(HloInstruction* crs) override; - Status HandleInfeed(HloInstruction* instruction) override; - Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleSort(HloInstruction* hlo) override; - Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleReduceWindow(HloInstruction* reduce_window) override; - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; - Status HandleSend(HloInstruction* send) override; - Status HandleSendDone(HloInstruction* send_done) override; - Status HandleSlice(HloInstruction* slice) override; - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; - Status HandleDynamicUpdateSlice( + absl::Status DefaultAction(HloInstruction* hlo) override; + + absl::Status HandleAllGather(HloInstruction* instruction) override; + absl::Status HandleAllToAll(HloInstruction* instruction) override; + absl::Status HandleBitcast(HloInstruction* bitcast) override; + absl::Status HandleConstant(HloInstruction* constant) override; + absl::Status HandleCopy(HloInstruction* copy) override; + absl::Status HandleGetTupleElement( + HloInstruction* 
get_tuple_element) override; + absl::Status HandleSelect(HloInstruction* select) override; + absl::Status HandleDot(HloInstruction* dot) override; + absl::Status HandleConvolution(HloInstruction* convolution) override; + absl::Status HandleFft(HloInstruction* fft) override; + absl::Status HandleAllReduce(HloInstruction* crs) override; + absl::Status HandleReduceScatter(HloInstruction* crs) override; + absl::Status HandleCollectivePermute(HloInstruction* crs) override; + absl::Status HandleInfeed(HloInstruction* instruction) override; + absl::Status HandleOutfeed(HloInstruction* outfeed) override; + absl::Status HandleSort(HloInstruction* hlo) override; + absl::Status HandleParameter(HloInstruction* parameter) override; + absl::Status HandleReduce(HloInstruction* reduce) override; + absl::Status HandleReduceWindow(HloInstruction* reduce_window) override; + absl::Status HandleSelectAndScatter( + HloInstruction* select_and_scatter) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleSendDone(HloInstruction* send_done) override; + absl::Status HandleSlice(HloInstruction* slice) override; + absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + absl::Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; - Status HandleRecv(HloInstruction* recv) override; - Status HandleRecvDone(HloInstruction* recv_done) override; - Status HandlePad(HloInstruction* pad) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status HandleConcatenate(HloInstruction* concatenate) override; - Status HandleConditional(HloInstruction* conditional) override; - Status HandleScatter(HloInstruction* scatter) override; - Status HandleAfterAll(HloInstruction* after_all) 
override; - Status HandleAddDependency(HloInstruction* add_dependency) override; - Status HandlePartitionId(HloInstruction* hlo) override; - Status HandleReplicaId(HloInstruction* hlo) override; - Status HandleRng(HloInstruction* rng) override; - Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override; - Status FinishVisit(HloInstruction* root) override; - - Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* hlo) override; + absl::Status HandleRecv(HloInstruction* recv) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + absl::Status HandlePad(HloInstruction* pad) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + absl::Status HandleCall(HloInstruction* call) override; + absl::Status HandleCustomCall(HloInstruction* custom_call) override; + absl::Status HandleWhile(HloInstruction* xla_while) override; + absl::Status HandleConcatenate(HloInstruction* concatenate) override; + absl::Status HandleConditional(HloInstruction* conditional) override; + absl::Status HandleScatter(HloInstruction* scatter) override; + absl::Status HandleAfterAll(HloInstruction* after_all) override; + absl::Status HandleAddDependency(HloInstruction* add_dependency) override; + absl::Status HandlePartitionId(HloInstruction* hlo) override; + absl::Status HandleReplicaId(HloInstruction* hlo) override; + absl::Status HandleRng(HloInstruction* rng) override; + absl::Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override; + absl::Status FinishVisit(HloInstruction* root) override; + + absl::Status Preprocess(HloInstruction* hlo) override; + absl::Status Postprocess(HloInstruction* hlo) override; // A convenient helper for calling BufferAssignment::GetUniqueSlice. 
BufferAllocation::Slice GetAllocationSlice( @@ -189,16 +191,16 @@ class IrEmitter : public DfsHloVisitorWithDefault, } private: - Status HandleSliceToDynamic(HloInstruction* hlo); - Status HandlePadToStatic(HloInstruction* hlo); - Status HandleTopK(HloInstruction* hlo); - Status HandleAllReduceSingleReplica(HloInstruction* crs); - Status HandleAllReduceMultipleReplica(HloInstruction* crs); + absl::Status HandleSliceToDynamic(HloInstruction* hlo); + absl::Status HandlePadToStatic(HloInstruction* hlo); + absl::Status HandleTopK(HloInstruction* hlo); + absl::Status HandleAllReduceSingleReplica(HloInstruction* crs); + absl::Status HandleAllReduceMultipleReplica(HloInstruction* crs); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - Status HandleOneDnnMatMulCalls(HloInstruction* hlo, - std::string runtime_symbol_name); - Status HandleOneDnnSoftmax(HloInstruction* hlo); - Status HandleOneDnnLayerNorm(HloInstruction* hlo); + absl::Status HandleOneDnnMatMulCalls(HloInstruction* hlo, + std::string runtime_symbol_name); + absl::Status HandleOneDnnSoftmax(HloInstruction* hlo); + absl::Status HandleOneDnnLayerNorm(HloInstruction* hlo); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const std::string& function_name); @@ -316,7 +318,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. - Status ElementTypesSameAndSupported( + absl::Status ElementTypesSameAndSupported( const HloInstruction& instruction, absl::Span<const HloInstruction* const> operands, absl::Span<const PrimitiveType> supported_types); @@ -331,23 +333,23 @@ class IrEmitter : public DfsHloVisitorWithDefault, // in the loop name. // // TODO(jingyue): target_op should be a `const HloInstruction*`. 
- Status EmitTargetElementLoop( + absl::Status EmitTargetElementLoop( HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); - Status EmitTargetElementLoop( + absl::Status EmitTargetElementLoop( HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator); // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. - Status EmitMemcpy(const HloInstruction& source, - const HloInstruction& destination); + absl::Status EmitMemcpy(const HloInstruction& source, + const HloInstruction& destination); // Emits IR to compute the target address of the buffer for the given op. // After calling this function, you can get a pointer to this buffer by // calling GetIrArrayForOp or GetEmittedValueFor. - Status EmitTargetAddressForOp(const HloInstruction* op); + absl::Status EmitTargetAddressForOp(const HloInstruction* op); // Structurizes "array_elements" into an MD array that represents "shape". // This is a recursive function, and "dimension_index" indicates the index of @@ -652,8 +654,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program // address. - Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, - llvm::Value* program_buffer_address); + absl::Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, + llvm::Value* program_buffer_address); // Returns a ConstExpr bitcast. 
llvm::Constant* EmitGlobalForLiteral(const Literal& literal); diff --git a/third_party/xla/xla/service/cpu/ir_function.cc b/third_party/xla/xla/service/cpu/ir_function.cc index 69961501a37..cf66fe382e7 100644 --- a/third_party/xla/xla/service/cpu/ir_function.cc +++ b/third_party/xla/xla/service/cpu/ir_function.cc @@ -234,7 +234,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( // Emits a call to a runtime fork/join function which dispatches parallel // calls to 'parallel_function' (and joins threads before returning). -Status EmitCallToParallelForkJoin( +absl::Status EmitCallToParallelForkJoin( const std::vector<llvm::Value*>& arguments, const Shape& shape, absl::Span<const int64_t> dimension_partition_counts, llvm::IRBuilder<>* b, llvm::Function* parallel_function, absl::string_view name) { diff --git a/third_party/xla/xla/service/cpu/ir_function.h b/third_party/xla/xla/service/cpu/ir_function.h index 47034675bdf..73a85867438 100644 --- a/third_party/xla/xla/service/cpu/ir_function.h +++ b/third_party/xla/xla/service/cpu/ir_function.h @@ -145,7 +145,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( // Emits a call to a runtime fork/join function which dispatches parallel // calls to 'parallel_function' (and joins threads before returning). 
-Status EmitCallToParallelForkJoin( +absl::Status EmitCallToParallelForkJoin( const std::vector<llvm::Value*>& arguments, const Shape& shape, absl::Span<const int64_t> dimension_partition_counts, llvm::IRBuilder<>* b, llvm::Function* parallel_function, absl::string_view name); diff --git a/third_party/xla/xla/service/cpu/mlir_emitter.cc b/third_party/xla/xla/service/cpu/mlir_emitter.cc index 8d3a28815ee..396f567cf2d 100644 --- a/third_party/xla/xla/service/cpu/mlir_emitter.cc +++ b/third_party/xla/xla/service/cpu/mlir_emitter.cc @@ -82,7 +82,7 @@ void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args, } } // namespace -Status EmitMlirFuncAndCall( +absl::Status EmitMlirFuncAndCall( mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr, llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name, diff --git a/third_party/xla/xla/service/cpu/mlir_emitter.h b/third_party/xla/xla/service/cpu/mlir_emitter.h index c7a8480e4e9..af1b1626af2 100644 --- a/third_party/xla/xla/service/cpu/mlir_emitter.h +++ b/third_party/xla/xla/service/cpu/mlir_emitter.h @@ -32,7 +32,7 @@ namespace cpu { // `emitter` and create a call, passing it the buffers defined by // resultShape/resultPtr and operandShapes/operandPtrs. The function is added to // the LLVM module at `b`s insertion point. 
-Status EmitMlirFuncAndCall( +absl::Status EmitMlirFuncAndCall( mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr, llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name, diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index a792a5ba521..1fc542088d1 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -43,7 +43,7 @@ namespace { namespace m = match; namespace pu = ::xla::cpu::onednn_pattern_utils_internal; -inline Status ValidateDotDimensionNumbers( +inline absl::Status ValidateDotDimensionNumbers( const DotDimensionNumbers& dim_numbers) { // Checks some invariants that do not hold in general, but DotDecomposer // should have established for us. @@ -276,8 +276,8 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) { // OneDNN matmul can fuse add operation with automatic broadcasting along the // addend's dimensions that are 1s. When compatible, Broadcast can be replaced // by Bitcast, which is much cheaper. Compute new shape for the Bitcast. -StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr, - const Shape& dot_shape) { +absl::StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr, + const Shape& dot_shape) { if (broadcast_instr->opcode() != HloOpcode::kBroadcast) { return absl::InvalidArgumentError( "Hlo instruction is not a Broadcast insruction."); @@ -428,7 +428,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { public: // Matches patterns for possible MatMul fusions that are supported by oneDNN // library. Matched HLO instruction(s) are replaced by custom call. 
- Status HandleDot(HloInstruction* instr) override { + absl::Status HandleDot(HloInstruction* instr) override { HloInstruction* dot_instr; auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot); if (!Match(instr, pattern)) return OkStatus(); @@ -463,7 +463,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - Status HandleAdd(HloInstruction* instr) override { + absl::Status HandleAdd(HloInstruction* instr) override { // Try to do a fusion for Dot(onednn-matmul) + Add. However, // HLO Add instruction might receive the addends after additional // processing like Broadcast, Bitcast, Convert, etc. is applied to the raw @@ -616,7 +616,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - Status HandleMaximum(HloInstruction* instr) override { + absl::Status HandleMaximum(HloInstruction* instr) override { HloInstruction* matmul_call; HloInstruction* intermediate_instr = nullptr; HloInstruction* optional_bitcast = nullptr; @@ -699,7 +699,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - Status HandleMultiply(HloInstruction* instr) override { + absl::Status HandleMultiply(HloInstruction* instr) override { HloInstruction* matmul_call; HloInstruction* intermediate_instr = nullptr; HloInstruction* src; @@ -760,10 +760,11 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - Status FuseActivation(OneDnnMatMulConfig_FusionKind kind, - HloInstruction* activation, HloInstruction* matmul, - HloInstruction* intermediate_instr = nullptr, - HloInstruction* optional_bitcast = nullptr) { + absl::Status FuseActivation(OneDnnMatMulConfig_FusionKind kind, + HloInstruction* activation, + HloInstruction* matmul, + HloInstruction* intermediate_instr = nullptr, + HloInstruction* optional_bitcast = nullptr) { TF_ASSIGN_OR_RETURN(auto backend_config, matmul->backend_config<BackendConfig>()); auto* matmul_config = 
backend_config.mutable_onednn_matmul_config(); @@ -811,7 +812,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim] // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1] // result: [batch_dims] to [batch_dims,1,1] - StatusOr<HloInstruction*> ReconfigureDotDimensions( + absl::StatusOr<HloInstruction*> ReconfigureDotDimensions( HloInstruction* dot_instr) { HloInstruction* lhs = dot_instr->mutable_operand(0); HloInstruction* rhs = dot_instr->mutable_operand(1); @@ -900,7 +901,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { #endif } - Status HandleCustomCall(HloInstruction* custom_call) override { + absl::Status HandleCustomCall(HloInstruction* custom_call) override { HloInstruction* matmul; if (Match(custom_call, OneDnnMatmulInstr(&matmul))) { return HandleCustomCallInternal<dnnl::matmul::primitive_desc>( @@ -911,7 +912,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { } template <typename PrimDesc> - Status HandleCustomCallInternal(HloInstruction* custom_call) { + absl::Status HandleCustomCallInternal(HloInstruction* custom_call) { auto scratch_add = AddScratch<PrimDesc>(custom_call); if (scratch_add.ok()) { custom_call = *scratch_add; @@ -926,10 +927,10 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { } template <typename> - Status SetWeightsPrepack(HloInstruction*, bool); + absl::Status SetWeightsPrepack(HloInstruction*, bool); template <typename> - Status SetUserScratch(HloInstruction*, bool); + absl::Status SetUserScratch(HloInstruction*, bool); template <typename> bool GetWeightsPrepack(HloInstruction*); @@ -940,7 +941,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { // Add scratch for matmul by changing the result of custom-call to // tuple(result, scratch) template <typename PrimDesc> - StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) { + 
absl::StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) { if (GetUserScratch<PrimDesc>(custom_call)) { return custom_call; } @@ -964,7 +965,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { } template <typename PrimDesc> - StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) { + absl::StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) { if (GetWeightsPrepack<PrimDesc>(custom_call)) { return custom_call; } @@ -1042,7 +1043,7 @@ EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack, #define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \ CONFIG, FIELD) \ template <> \ - inline Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>( \ + inline absl::Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>( \ HloInstruction * custom_call, bool value) { \ TF_ASSIGN_OR_RETURN(auto backend_config, \ custom_call->backend_config<BackendConfig>()); \ @@ -1060,7 +1061,7 @@ EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch, OneDnnMatMulConfig, onednn_matmul_config, user_scratchpad); -StatusOr<bool> OneDnnMatMulRewriter::Run( +absl::StatusOr<bool> OneDnnMatMulRewriter::Run( HloModule* module, const absl::flat_hash_set<absl::string_view>& execution_threads) { OneDnnMatMulRewriteVisitor visitor; diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h index 36cab7ee949..eaeea210021 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h @@ -41,7 +41,7 @@ class OneDnnMatMulRewriter : public HloModulePass { absl::string_view name() const override { return "onednn-matmul-rewriter"; } using HloPassInterface::Run; - StatusOr<bool> Run( + absl::StatusOr<bool> Run( HloModule* module, const absl::flat_hash_set<absl::string_view>& execution_threads) override; diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.cc 
b/third_party/xla/xla/service/cpu/onednn_memory_util.cc index fd0b6f92795..6ab913161a7 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.cc @@ -169,7 +169,7 @@ int64_t MemrefInfo::GetChannels() const { return pod_->dims[pod_->rank - 1]; } int64_t MemrefInfo::GetRank() const { return pod_->rank; } -StatusOr<dnnl::memory::desc> TransposeLastTwoDims( +absl::StatusOr<dnnl::memory::desc> TransposeLastTwoDims( const dnnl::memory::desc& md) { int64_t ndims = md.get_ndims(); if (ndims < 2) { diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.h b/third_party/xla/xla/service/cpu/onednn_memory_util.h index 5793c1bdf4f..c0c956a32dc 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.h +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.h @@ -119,7 +119,8 @@ class MemrefInfo { MemrefInfoPOD* pod_; }; -StatusOr<dnnl::memory::desc> TransposeLastTwoDims(const dnnl::memory::desc& md); +absl::StatusOr<dnnl::memory::desc> TransposeLastTwoDims( + const dnnl::memory::desc& md); #define TRANSPOSE_LAST_TWO_DIMS_IF(pred, mem_desc) \ if (pred) { \ auto trans_mem_desc = TransposeLastTwoDims(mem_desc); \ diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc index 06c2137a782..25530aa59c8 100644 --- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc @@ -385,7 +385,7 @@ bool MatchFlaxLayerNorm(HloInstruction* instr, HloInstruction** src, class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { public: - Status HandleAdd(HloInstruction* instr) override { + absl::Status HandleAdd(HloInstruction* instr) override { HloInstruction *src, *scale, *bias; float eps; bool is_bf16orfp16_convert = false; @@ -445,7 +445,7 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - Status HandleConvert(HloInstruction* instr) override { 
+ absl::Status HandleConvert(HloInstruction* instr) override { HloInstruction* custom_call; HloInstruction* convert_instr; auto pattern = @@ -478,7 +478,7 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - Status HandleDivide(HloInstruction* divide_instr) override { + absl::Status HandleDivide(HloInstruction* divide_instr) override { if (divide_instr->HasControlDependencies()) return OkStatus(); if (!IsSupportedType(divide_instr->shape().element_type())) return OkStatus(); @@ -495,7 +495,7 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { } }; -StatusOr<bool> OneDnnOpsRewriter::Run( +absl::StatusOr<bool> OneDnnOpsRewriter::Run( HloModule* module, const absl::flat_hash_set<absl::string_view>& execution_threads) { OneDnnOpsRewriterVisitor visitor; diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h index ea62f33ebcf..8e777d8889a 100644 --- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h @@ -32,7 +32,7 @@ class OneDnnOpsRewriter : public HloModulePass { absl::string_view name() const override { return "onednn-ops-rewriter"; } using HloPassInterface::Run; - StatusOr<bool> Run( + absl::StatusOr<bool> Run( HloModule* module, const absl::flat_hash_set<absl::string_view>& execution_threads) override; }; diff --git a/third_party/xla/xla/service/cpu/onednn_rewriter.h b/third_party/xla/xla/service/cpu/onednn_rewriter.h index a1ba3205c96..53fd5c0f977 100644 --- a/third_party/xla/xla/service/cpu/onednn_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_rewriter.h @@ -33,7 +33,7 @@ class OneDnnRewriter : public HloModulePass { absl::string_view name() const override { return "onednn-rewriter"; } using HloPassInterface::Run; - StatusOr<bool> Run( + absl::StatusOr<bool> Run( HloModule* module, const absl::flat_hash_set<absl::string_view>& execution_threads) override; }; diff --git 
a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc index 6d103759585..b82707a367e 100644 --- a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc +++ b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc @@ -133,7 +133,8 @@ ParallelTaskAssignment::ParallelTaskAssignment( // Run cost analysis on 'module'. auto cost_analysis = std::make_unique<HloCostAnalysis>(shape_size); HloComputation* computation = module->entry_computation(); - Status status = computation->root_instruction()->Accept(cost_analysis.get()); + absl::Status status = + computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { // Set default cost model based on 'cost_analysis'. cost_model_ = std::make_unique<DefaultCostModel>( |