author    Kyle Lucke <klucke@google.com>  2024-05-16 15:20:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2024-05-16 17:07:50 -0700
commit    9c5e56cd91bb3206f9f9751bbfcc4426abca9cca (patch)
tree      5debf674310cd60faec401d202764283e284db3d
parent    6ec902447b28138b7fc45720cd9cba6c341a1bd5 (diff)
Use absl::Status instead of xla::Status now that they're identical.
PiperOrigin-RevId: 634545999
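
For context: this rename is purely mechanical, since xla::Status had already been reduced to an alias of absl::Status. A minimal sketch of that relationship (the exact contents of xla/status.h are an assumption here, not quoted from this commit):

    // Sketch, assuming xla/status.h at this revision is just an alias:
    #include "absl/status/status.h"

    namespace xla {
    using Status = absl::Status;  // both names denote the same type
    }  // namespace xla

Because the alias makes the two spellings interchangeable, every hunk below changes only how the type is written, never the function types themselves.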
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_compiler.cc              |  34
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_compiler.h               |  12
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_executable.cc            |   4
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_executable.h             |   2
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_layout_assignment.cc     |   2
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_layout_assignment.h      |   2
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_transfer_manager.cc      |  10
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_transfer_manager.h       |  14
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_xfeed.cc                 |  14
-rw-r--r--  third_party/xla/xla/service/cpu/cpu_xfeed.h                  |  14
-rw-r--r--  third_party/xla/xla/service/cpu/dot_op_emitter.cc            |  47
-rw-r--r--  third_party/xla/xla/service/cpu/dot_op_emitter.h             |  16
-rw-r--r--  third_party/xla/xla/service/cpu/ir_emitter.cc                | 140
-rw-r--r--  third_party/xla/xla/service/cpu/ir_emitter.h                 | 134
-rw-r--r--  third_party/xla/xla/service/cpu/ir_function.cc               |   2
-rw-r--r--  third_party/xla/xla/service/cpu/ir_function.h                |   2
-rw-r--r--  third_party/xla/xla/service/cpu/mlir_emitter.cc              |   2
-rw-r--r--  third_party/xla/xla/service/cpu/mlir_emitter.h               |   2
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc    |  41
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h     |   2
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_memory_util.cc        |   2
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_memory_util.h         |   3
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc       |   8
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_ops_rewriter.h        |   2
-rw-r--r--  third_party/xla/xla/service/cpu/onednn_rewriter.h            |   2
-rw-r--r--  third_party/xla/xla/service/cpu/parallel_task_assignment.cc  |   3
26 files changed, 262 insertions(+), 254 deletions(-)
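
The premise that the two types are identical can be made explicit with a compile-time check; a minimal sketch (the xla/status.h include path is an assumption):

    #include <type_traits>
    #include "absl/status/status.h"
    #include "xla/status.h"  // assumed include path for xla::Status

    // Fails to compile if the rename were ever more than a spelling change:
    static_assert(std::is_same_v<xla::Status, absl::Status>,
                  "xla::Status must alias absl::Status");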
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc
index ac9fa032c88..6bed3bd4263 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.cc
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc
@@ -351,13 +351,13 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
: hlo_to_profile_idx_(hlo_to_profile_idx),
assigned_indices_(assigned_indices) {}
- Status DefaultAction(HloInstruction* hlo_instruction) override {
+ absl::Status DefaultAction(HloInstruction* hlo_instruction) override {
hlo_to_profile_idx_->insert(
{hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
return absl::OkStatus();
}
- Status HandleCall(HloInstruction* call) override {
+ absl::Status HandleCall(HloInstruction* call) override {
TF_RETURN_IF_ERROR(DefaultAction(call));
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
assigned_indices_);
@@ -365,7 +365,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
return absl::OkStatus();
}
// Recurse into "conditional" so we can profile inside of it.
- Status HandleConditional(HloInstruction* conditional) override {
+ absl::Status HandleConditional(HloInstruction* conditional) override {
TF_RETURN_IF_ERROR(DefaultAction(conditional));
CollectProfileCandidates candidates_for_true(hlo_to_profile_idx_,
@@ -382,12 +382,16 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
}
// Skip constants, there is nothing to profile.
- Status HandleConstant(HloInstruction*) override { return absl::OkStatus(); }
+ absl::Status HandleConstant(HloInstruction*) override {
+ return absl::OkStatus();
+ }
// Skip parameters, they are a simple load.
- Status HandleParameter(HloInstruction*) override { return absl::OkStatus(); }
+ absl::Status HandleParameter(HloInstruction*) override {
+ return absl::OkStatus();
+ }
// It is important to recurse for "while" or else we risk overly coarse
// profiling information.
- Status HandleWhile(HloInstruction* xla_while) override {
+ absl::Status HandleWhile(HloInstruction* xla_while) override {
TF_RETURN_IF_ERROR(DefaultAction(xla_while));
CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
@@ -423,7 +427,7 @@ void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {},
} // namespace
-Status CpuCompiler::RunHloPassesThroughLayoutAssn(
+absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
HloModule* module, bool is_aot_compile,
LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) {
const DebugOptions& debug_options = module->config().debug_options();
@@ -696,7 +700,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
return pipeline.Run(module).status();
}
-Status CpuCompiler::RunHloPassesAfterLayoutAssn(
+absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
HloModule* module, bool is_aot_compile,
LLVMTargetMachineFeatures* target_machine_features,
const CompileOptions& compile_options, bool is_mlir_compile) {
@@ -799,10 +803,10 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
return pipeline.Run(module).status();
}
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine,
- const CompileOptions& compile_options,
- bool is_mlir_compile) {
+absl::Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine,
+ const CompileOptions& compile_options,
+ bool is_mlir_compile) {
LLVMTargetMachineFeatures target_machine_features(target_machine);
TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(
module, is_aot_compile, &target_machine_features, is_mlir_compile));
@@ -870,7 +874,7 @@ std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
}};
}
-Status VerifyLlvmModule(const llvm::Module& llvm_module) {
+absl::Status VerifyLlvmModule(const llvm::Module& llvm_module) {
XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");
std::string err;
@@ -885,7 +889,7 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) {
return absl::OkStatus();
}
-Status CreateHloProfilingArtifacts(
+absl::Status CreateHloProfilingArtifacts(
const HloModule& module,
absl::flat_hash_map<const HloInstruction*, int64_t>*
instruction_to_profile_idx,
@@ -1404,7 +1408,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
// Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run
// the pre-optimization IR dump hook before returning.
{
- Status verify_status = VerifyLlvmModule(*llvm_module);
+ absl::Status verify_status = VerifyLlvmModule(*llvm_module);
if (!verify_status.ok() && pre_optimization_ir_hook) {
pre_optimization_ir_hook(*llvm_module);
}
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.h b/third_party/xla/xla/service/cpu/cpu_compiler.h
index 69ba7f498f3..3cc54c4a616 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.h
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.h
@@ -195,19 +195,19 @@ class CpuCompiler : public LLVMCompiler {
// Runs the HLO passes which are necessary for both optimizations and
// correctness.
- Status RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine,
- const CompileOptions& compile_options,
- bool is_mlir_compile = false);
+ absl::Status RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine,
+ const CompileOptions& compile_options,
+ bool is_mlir_compile = false);
// Runs HLO passes up to and including layout assignment.
- Status RunHloPassesThroughLayoutAssn(
+ absl::Status RunHloPassesThroughLayoutAssn(
HloModule* module, bool /*is_aot_compile*/,
LLVMTargetMachineFeatures* target_machine_features,
bool is_mlir_compile = false);
// Runs HLO passes after layout assignment.
- Status RunHloPassesAfterLayoutAssn(
+ absl::Status RunHloPassesAfterLayoutAssn(
HloModule* module, bool is_aot_compile,
LLVMTargetMachineFeatures* target_machine_features,
const CompileOptions& compile_options, bool is_mlir_compile);
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc
index 01a5007a3de..02fb083554c 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.cc
+++ b/third_party/xla/xla/service/cpu/cpu_executable.cc
@@ -180,7 +180,7 @@ CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
return std::move(buffers);
}
-Status CpuExecutable::ExecuteComputeFunction(
+absl::Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
absl::Span<MaybeOwningDeviceMemory const> buffers,
HloExecutionProfile* hlo_execution_profile) {
@@ -399,7 +399,7 @@ absl::StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
HloExecutionProfile* hlo_execution_profile;
- Status operator()() {
+ absl::Status operator()() {
return executable->ExecuteComputeFunction(
&run_options.run_options(), *task_buffers, hlo_execution_profile);
}
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h
index 2dc7e2d9700..5cf8aa83357 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.h
+++ b/third_party/xla/xla/service/cpu/cpu_executable.h
@@ -73,7 +73,7 @@ class CpuExecutable : public Executable {
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
- Status ExecuteComputeFunction(
+ absl::Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
absl::Span<MaybeOwningDeviceMemory const> buffers,
HloExecutionProfile* hlo_execution_profile);
diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
index 35371ea9463..48f5ae59bfe 100644
--- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
+++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
@@ -120,7 +120,7 @@ static bool OperandsAndResultMustHaveRowMajorLayout(
return false;
}
-Status CpuLayoutAssignment::AddBackendConstraints(
+absl::Status CpuLayoutAssignment::AddBackendConstraints(
LayoutConstraints* constraints) {
ShouldMakeOperandColMajorCache cache;
diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.h b/third_party/xla/xla/service/cpu/cpu_layout_assignment.h
index 35ecde418bc..26e155458f5 100644
--- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.h
+++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.h
@@ -37,7 +37,7 @@ class CpuLayoutAssignment : public LayoutAssignment {
~CpuLayoutAssignment() override {}
protected:
- Status AddBackendConstraints(LayoutConstraints* constraints) override;
+ absl::Status AddBackendConstraints(LayoutConstraints* constraints) override;
const TargetMachineFeatures& target_machine_features_;
};
diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc
index 27e62fc1af7..fd68f299635 100644
--- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc
+++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.cc
@@ -43,19 +43,19 @@ CpuTransferManager::CpuTransferManager()
: GenericTransferManager(se::host::kHostPlatformId,
/*pointer_size=*/sizeof(void*)) {}
-Status CpuTransferManager::TransferLiteralToInfeed(
+absl::Status CpuTransferManager::TransferLiteralToInfeed(
se::StreamExecutor* executor, const LiteralSlice& literal) {
return TransferLiteralToInfeedOnCpu(executor->device_ordinal(), literal);
}
-Status CpuTransferManager::TransferLiteralFromOutfeed(
+absl::Status CpuTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
return TransferLiteralFromOutfeedOnCpu(executor->device_ordinal(), literal);
}
-Status CpuTransferManager::ReadDynamicShapes(se::Stream* stream,
- const ShapedBuffer* device_buffer,
- Shape* device_shape) {
+absl::Status CpuTransferManager::ReadDynamicShapes(
+ se::Stream* stream, const ShapedBuffer* device_buffer,
+ Shape* device_shape) {
if (stream != nullptr) {
// When a stream is presented, respect the stream dependency.
return TransferManager::ReadDynamicShapes(stream, device_buffer,
diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h
index 9cdf0478a47..ed7f0fe3f7b 100644
--- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h
+++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h
@@ -37,10 +37,10 @@ class CpuTransferManager : public GenericTransferManager {
CpuTransferManager();
~CpuTransferManager() override {}
- Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const LiteralSlice& literal) override;
- Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
- MutableBorrowingLiteral literal) override;
+ absl::Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+ const LiteralSlice& literal) override;
+ absl::Status TransferLiteralFromOutfeed(
+ se::StreamExecutor* executor, MutableBorrowingLiteral literal) override;
bool CanShapedBufferBeAccessedNow(
se::StreamExecutor* executor,
@@ -54,9 +54,9 @@ class CpuTransferManager : public GenericTransferManager {
return true;
}
- Status ReadDynamicShapes(se::Stream* stream,
- const ShapedBuffer* device_buffer,
- Shape* device_shape) override;
+ absl::Status ReadDynamicShapes(se::Stream* stream,
+ const ShapedBuffer* device_buffer,
+ Shape* device_shape) override;
private:
bool PackSubbyteTypes() const override { return true; }
diff --git a/third_party/xla/xla/service/cpu/cpu_xfeed.cc b/third_party/xla/xla/service/cpu/cpu_xfeed.cc
index b6fea6cb0d3..8b4ea238f51 100644
--- a/third_party/xla/xla/service/cpu/cpu_xfeed.cc
+++ b/third_party/xla/xla/service/cpu/cpu_xfeed.cc
@@ -102,8 +102,8 @@ absl::StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal(
return queued_buffer;
}
-Status TransferBufferToInfeed(int device_ordinal, int64_t size,
- const void* source) {
+absl::Status TransferBufferToInfeed(int device_ordinal, int64_t size,
+ const void* source) {
TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
TransferBufferToInfeedInternal(size, source));
@@ -175,8 +175,8 @@ absl::StatusOr<Shape> TransferTupleBuffersFromOutfeed(
}
} // namespace
-Status TransferLiteralToInfeedOnCpu(int device_ordinal,
- const LiteralSlice& literal) {
+absl::Status TransferLiteralToInfeedOnCpu(int device_ordinal,
+ const LiteralSlice& literal) {
const Shape& shape = literal.shape();
VLOG(2) << "Transferring literal to infeed with shape: "
<< ShapeUtil::HumanString(shape);
@@ -221,8 +221,8 @@ Status TransferLiteralToInfeedOnCpu(int device_ordinal,
return OkStatus();
}
-Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
- MutableBorrowingLiteral literal) {
+absl::Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
+ MutableBorrowingLiteral literal) {
if (!literal.shape().IsTuple()) {
int64_t size =
cpu::runtime::GetByteSizeRequirement(literal.shape(), sizeof(void*));
@@ -275,7 +275,7 @@ Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
return OkStatus();
}
-Status ReadDynamicShapesOnCpu(
+absl::Status ReadDynamicShapesOnCpu(
const ShapedBuffer* device_buffer, Shape* device_shape,
HloCostAnalysis::ShapeSizeFunction shape_size_fn) {
TF_RET_CHECK(device_shape->is_dynamic());
diff --git a/third_party/xla/xla/service/cpu/cpu_xfeed.h b/third_party/xla/xla/service/cpu/cpu_xfeed.h
index 26512839d50..7d09ec862d8 100644
--- a/third_party/xla/xla/service/cpu/cpu_xfeed.h
+++ b/third_party/xla/xla/service/cpu/cpu_xfeed.h
@@ -30,17 +30,17 @@ limitations under the License.
namespace xla {
// Helper function to transfer to infeed on CPU.
-Status TransferLiteralToInfeedOnCpu(int device_ordinal,
- const LiteralSlice& literal);
+absl::Status TransferLiteralToInfeedOnCpu(int device_ordinal,
+ const LiteralSlice& literal);
// Helper function to transfer from outfeed on CPU.
-Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
- MutableBorrowingLiteral literal);
+absl::Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
+ MutableBorrowingLiteral literal);
// Helper function to retrieve dynamic shape on CPU.
-Status ReadDynamicShapesOnCpu(const ShapedBuffer* device_buffer,
- Shape* device_shape,
- HloCostAnalysis::ShapeSizeFunction shape_size_fn);
+absl::Status ReadDynamicShapesOnCpu(
+ const ShapedBuffer* device_buffer, Shape* device_shape,
+ HloCostAnalysis::ShapeSizeFunction shape_size_fn);
} // namespace xla
#endif // XLA_SERVICE_CPU_CPU_XFEED_H_
diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc
index c6bad1e54ec..9548b9919ac 100644
--- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc
+++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc
@@ -134,21 +134,21 @@ class DotOpEmitter {
const TargetMachineFeatures& target_machine_features);
// Emits the IR to perform the dot operation.
- Status Emit();
+ absl::Status Emit();
// Emits the IR to perform the batch dot operation.
- Status EmitBatch();
+ absl::Status EmitBatch();
private:
// Emits instructions to perform a scalar dot product (a multiply of the
// LHS and RHS) and store the results in the target.
- Status EmitScalarDot();
+ absl::Status EmitScalarDot();
// Emits a call to the CPU runtime to perform the matrix multiply.
- Status EmitCallToRuntime();
+ absl::Status EmitCallToRuntime();
// Emits a call to the CPU runtime to perform the batch matrix multiply.
- Status EmitCallToBatchRuntime();
+ absl::Status EmitCallToBatchRuntime();
// Represents the dimensions of a matrix-matrix multiply operation.
struct MatMultDims {
@@ -192,7 +192,7 @@ class DotOpEmitter {
void EmitTiledLlvmIrGemm();
// Lowers the dot operation through MLIR's linalg.matmul.
- Status EmitLinalgMatmul();
+ absl::Status EmitLinalgMatmul();
// Lowers the dot operation as a naive nested loop that computes the result
// one element at a time.
@@ -264,7 +264,7 @@ DotOpEmitter::DotOpEmitter(
hlo_module_config_(hlo_module_config),
target_machine_features_(target_machine_features) {}
-Status DotOpEmitter::EmitLinalgMatmul() {
+absl::Status DotOpEmitter::EmitLinalgMatmul() {
Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
rhs_array_.GetBasePointer()};
@@ -529,7 +529,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
}
}
-Status DotOpEmitter::Emit() {
+absl::Status DotOpEmitter::Emit() {
// The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand.
//
@@ -584,7 +584,7 @@ Status DotOpEmitter::Emit() {
}
}
-Status DotOpEmitter::EmitBatch() {
+absl::Status DotOpEmitter::EmitBatch() {
// The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand.
//
@@ -756,7 +756,7 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() {
b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}
-Status DotOpEmitter::EmitScalarDot() {
+absl::Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
// Use the same index_type for all tensor accesses in the same kernel.
@@ -791,7 +791,7 @@ Status DotOpEmitter::EmitScalarDot() {
return OkStatus();
}
-Status DotOpEmitter::EmitCallToRuntime() {
+absl::Status DotOpEmitter::EmitCallToRuntime() {
// The signature of the Eigen runtime matmul function is:
//
// (void)(void* run_options, float* out, float* lhs, float* rhs,
@@ -902,7 +902,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
return OkStatus();
}
-Status DotOpEmitter::EmitCallToBatchRuntime() {
+absl::Status DotOpEmitter::EmitCallToBatchRuntime() {
// The signature of the runtime batch matmul function is:
//
// (void)(void* run_options, float* out, float* lhs, float* rhs,
@@ -1218,7 +1218,7 @@ DotImplementationStrategy GetDotImplementationStrategy(
return DotImplementationStrategy::kNaiveLlvmIr;
}
-Status EmitNonBatchDotOperation(
+absl::Status EmitNonBatchDotOperation(
DotInfo dot_info, std::string hlo_name,
const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
@@ -1270,7 +1270,8 @@ llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
std::move(new_shape));
}
-Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
+absl::Status ValidateDotDimensionNumbers(
+ const DotDimensionNumbers& dim_numbers) {
// Checks some invariants that do not hold in general, but DotDecomposer
// should have established for us. This is just a debugging aid.
TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
@@ -1357,7 +1358,7 @@ bool PotentiallyImplementedAsEigenMatmul(
return impl_strategy == DotImplementationStrategy::kEigen;
}
-Status EmitBatchDotOperation(
+absl::Status EmitBatchDotOperation(
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
@@ -1481,15 +1482,13 @@ bool DotOperandsAndResultMustHaveRowMajorLayout(
impl_strategy == DotImplementationStrategy::kEigen;
}
-Status EmitDotOperation(const HloInstruction& dot,
- const llvm_ir::IrArray& target_array,
- const llvm_ir::IrArray& lhs_array,
- const llvm_ir::IrArray& rhs_array,
- const llvm_ir::IrArray* addend_array,
- llvm::Value* executable_run_options_value,
- llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
- const HloModuleConfig& hlo_module_config,
- const TargetMachineFeatures& target_machine_features) {
+absl::Status EmitDotOperation(
+ const HloInstruction& dot, const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
+ const llvm_ir::IrArray* addend_array,
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
+ mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
+ const TargetMachineFeatures& target_machine_features) {
// This routine assumes that the dot operation is not in a parallelized
// enclosing computation.
CHECK(dot.parent()
diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.h b/third_party/xla/xla/service/cpu/dot_op_emitter.h
index 9d4f5617960..58e3afee737 100644
--- a/third_party/xla/xla/service/cpu/dot_op_emitter.h
+++ b/third_party/xla/xla/service/cpu/dot_op_emitter.h
@@ -57,15 +57,13 @@ std::optional<int64_t> ProfitableToMakeDotOperandColumnMajor(
// dimensions as the result, and the result is computed as `addend_array` +
// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported
// for Matrix-vector products.
-Status EmitDotOperation(const HloInstruction& dot,
- const llvm_ir::IrArray& target_array,
- const llvm_ir::IrArray& lhs_array,
- const llvm_ir::IrArray& rhs_array,
- const llvm_ir::IrArray* addend_array,
- llvm::Value* executable_run_options_value,
- llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
- const HloModuleConfig& hlo_module_config,
- const TargetMachineFeatures& target_machine_features);
+absl::Status EmitDotOperation(
+ const HloInstruction& dot, const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
+ const llvm_ir::IrArray* addend_array,
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
+ mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
+ const TargetMachineFeatures& target_machine_features);
} // namespace cpu
} // namespace xla
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc
index dff2ee5e7cd..d6a5a832d03 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter.cc
@@ -126,7 +126,7 @@ IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context,
target_machine_features_(*target_machine_features),
emit_code_for_msan_(emit_code_for_msan) {
b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
- Status s = GatherComputationsByAllocationType(
+ absl::Status s = GatherComputationsByAllocationType(
&hlo_module, &thread_local_computations_, &global_computations_);
absl::c_sort(thread_local_computations_);
absl::c_sort(global_computations_);
@@ -255,7 +255,7 @@ void IrEmitter::InitializeIrFunction(const std::string& function_name) {
module_, &b_, num_dynamic_loop_bounds_);
}
-Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
+absl::Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
emitted_value_[bitcast] = GetEmittedValueFor(bitcast->operand(0));
return OkStatus();
@@ -277,7 +277,7 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
return result_global;
}
-Status IrEmitter::EmitConstantGlobals() {
+absl::Status IrEmitter::EmitConstantGlobals() {
for (const BufferAllocation& allocation : assignment_.Allocations()) {
if (!allocation.is_constant()) {
continue;
@@ -300,14 +300,14 @@ Status IrEmitter::EmitConstantGlobals() {
return OkStatus();
}
-Status IrEmitter::HandleConstant(HloInstruction* constant) {
+absl::Status IrEmitter::HandleConstant(HloInstruction* constant) {
VLOG(2) << "HandleConstant: " << constant->ToString();
// IrEmitter::EmitConstantGlobals has already taken care of emitting the body
// of the constant.
return EmitTargetAddressForOp(constant);
}
-Status IrEmitter::HandleCopy(HloInstruction* copy) {
+absl::Status IrEmitter::HandleCopy(HloInstruction* copy) {
if (copy->shape().IsTuple() ||
(copy->shape().IsArray() &&
LayoutUtil::Equal(copy->operand(0)->shape().layout(),
@@ -382,7 +382,8 @@ void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
}
}
-Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
+absl::Status IrEmitter::HandleGetTupleElement(
+ HloInstruction* get_tuple_element) {
// A tuple is an array of pointers, one for each operand. Each pointer points
// to the output buffer of its corresponding operand. A GetTupleElement
// instruction forwards a pointer to the tuple element buffer at the given
@@ -395,13 +396,13 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
return OkStatus();
}
-Status IrEmitter::HandleSelect(HloInstruction* select) {
+absl::Status IrEmitter::HandleSelect(HloInstruction* select) {
auto pred = select->operand(0);
TF_RET_CHECK(pred->shape().element_type() == PRED);
return DefaultAction(select);
}
-Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
+absl::Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
VLOG(2) << "HandleInfeed: " << infeed->ToString();
@@ -461,8 +462,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
return OkStatus();
}
-Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
- llvm::Value* program_buffer_address) {
+absl::Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
+ llvm::Value* program_buffer_address) {
int64_t length = ByteSizeOf(shape);
if (length < 0 || length > std::numeric_limits<int32_t>::max()) {
return InvalidArgument(
@@ -526,7 +527,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
return OkStatus();
}
-Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
+absl::Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
// Outfeed produces no useful result, but it does return a token[] that can be
// threaded through to other side effecting operations to ensure ordering. In
// the IR emitter we treat this token as a normal u8[] and thus need to insert
@@ -556,7 +557,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
return OkStatus();
}
-Status IrEmitter::HandleSort(HloInstruction* hlo) {
+absl::Status IrEmitter::HandleSort(HloInstruction* hlo) {
const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
Shape keys_shape = sort->keys()->shape();
@@ -651,7 +652,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
return OkStatus();
}
-Status IrEmitter::HandleTuple(HloInstruction* tuple) {
+absl::Status IrEmitter::HandleTuple(HloInstruction* tuple) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
llvm::SmallVector<llvm::Value*> base_ptrs;
for (auto operand : tuple->operands()) {
@@ -661,7 +662,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return OkStatus();
}
-Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
+absl::Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
@@ -677,12 +678,13 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// that works.
bool saved_allow_reassociation = allow_reassociation_;
allow_reassociation_ = true;
- Status status = DefaultAction(reduce_window);
+ absl::Status status = DefaultAction(reduce_window);
allow_reassociation_ = saved_allow_reassociation;
return status;
}
-Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
+absl::Status IrEmitter::HandleSelectAndScatter(
+ HloInstruction* select_and_scatter) {
CHECK_EQ(select_and_scatter->operand_count(), 3);
const auto operand = select_and_scatter->operand(0);
const auto source = select_and_scatter->operand(1);
@@ -864,7 +866,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
return OkStatus();
}
-Status IrEmitter::HandleDot(HloInstruction* dot) {
+absl::Status IrEmitter::HandleDot(HloInstruction* dot) {
auto lhs = dot->operand(0);
auto rhs = dot->operand(1);
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
@@ -900,7 +902,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
hlo_module_config_, target_machine_features_);
}
-Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
+absl::Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
@@ -1085,7 +1087,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
return DefaultAction(convolution);
}
-Status IrEmitter::HandleFft(HloInstruction* fft) {
+absl::Status IrEmitter::HandleFft(HloInstruction* fft) {
auto operand = fft->operand(0);
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*fft, /*operands=*/{operand},
@@ -1137,7 +1139,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
return OkStatus();
}
-Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
+absl::Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
// When there is a single replica, a cross replica sum is the identity
// function, and the buffer assignment expects a copy.
//
@@ -1196,7 +1198,7 @@ static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) {
}
}
-Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
+absl::Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
CHECK_GE(crs->operand_count(), 1);
PrimitiveType datatype = crs->operand(0)->shape().element_type();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
@@ -1280,7 +1282,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
return OkStatus();
}
-Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
+absl::Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
if (hlo_module_config_.replica_count() == 1 &&
hlo_module_config_.num_partitions() == 1) {
return HandleAllReduceSingleReplica(crs);
@@ -1288,7 +1290,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
return HandleAllReduceMultipleReplica(crs);
}
-Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
+absl::Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
CHECK_EQ(rs->operand_count(), 1);
PrimitiveType datatype = rs->operand(0)->shape().element_type();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs));
@@ -1344,7 +1346,7 @@ Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
return OkStatus();
}
-Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
+absl::Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
auto* instr = Cast<HloAllToAllInstruction>(instruction);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
CHECK(!instr->split_dimension() && instr->shape().IsTuple())
@@ -1400,7 +1402,7 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
return OkStatus();
}
-Status IrEmitter::HandleAllGather(HloInstruction* instruction) {
+absl::Status IrEmitter::HandleAllGather(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
std::string replica_groups =
@@ -1451,7 +1453,7 @@ Status IrEmitter::HandleAllGather(HloInstruction* instruction) {
return OkStatus();
}
-Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
+absl::Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
std::string source_target_pairs = absl::StrJoin(
@@ -1488,7 +1490,7 @@ Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
return OkStatus();
}
-Status IrEmitter::HandlePartitionId(HloInstruction* hlo) {
+absl::Status IrEmitter::HandlePartitionId(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
assignment_.GetUniqueSlice(hlo, {}));
@@ -1500,7 +1502,7 @@ Status IrEmitter::HandlePartitionId(HloInstruction* hlo) {
return OkStatus();
}
-Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
+absl::Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
assignment_.GetUniqueSlice(hlo, {}));
@@ -1512,7 +1514,7 @@ Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
return OkStatus();
}
-Status IrEmitter::HandleParameter(HloInstruction* parameter) {
+absl::Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
return EmitTargetAddressForOp(parameter);
}
@@ -1998,7 +2000,7 @@ absl::StatusOr<bool> IrEmitter::EmitVectorizedReduce(
return true;
}
-Status IrEmitter::HandleReduce(HloInstruction* reduce) {
+absl::Status IrEmitter::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
absl::Span<const int64_t> dimensions(reduce->dimensions());
@@ -2027,21 +2029,21 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
return DefaultAction(reduce);
}
-Status IrEmitter::HandleSend(HloInstruction* send) {
+absl::Status IrEmitter::HandleSend(HloInstruction* send) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send is not implemented on CPU.");
}
-Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
+absl::Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send-done is not implemented on CPU.");
}
-Status IrEmitter::HandleScatter(HloInstruction*) {
+absl::Status IrEmitter::HandleScatter(HloInstruction*) {
return Unimplemented("Scatter is not implemented on CPUs.");
}
-Status IrEmitter::HandleSlice(HloInstruction* slice) {
+absl::Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
// The code below emits a sequential loop nest. For the parallel backend, use
@@ -2178,7 +2180,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
return OkStatus();
}
-Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
+absl::Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
@@ -2186,7 +2188,7 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
return DefaultAction(dynamic_slice);
}
-Status IrEmitter::HandleDynamicUpdateSlice(
+absl::Status IrEmitter::HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) {
auto update = dynamic_update_slice->operand(1);
if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
@@ -2203,17 +2205,17 @@ Status IrEmitter::HandleDynamicUpdateSlice(
return DefaultAction(dynamic_update_slice);
}
-Status IrEmitter::HandleRecv(HloInstruction* recv) {
+absl::Status IrEmitter::HandleRecv(HloInstruction* recv) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Recv is not implemented on CPU.");
}
-Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
+absl::Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Recv-done is not implemented on CPU.");
}
-Status IrEmitter::HandlePad(HloInstruction* pad) {
+absl::Status IrEmitter::HandlePad(HloInstruction* pad) {
// CPU backend does not properly handle negative padding but this is ok
// because negative padding should be removed by the algebraic simplifier.
for (auto& padding_dimension : pad->padding_config().dimensions()) {
@@ -2272,7 +2274,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
return OkStatus();
}
-Status IrEmitter::HandleFusion(HloInstruction* fusion) {
+absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) {
auto* root = fusion->fused_expression_root();
if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
@@ -2327,7 +2329,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
}
}
-Status IrEmitter::HandleCall(HloInstruction* call) {
+absl::Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(
emitted_functions_, ComputationToEmit{computation, allow_reassociation_});
@@ -2372,7 +2374,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
return OkStatus();
}
-Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
+absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
std::vector<llvm::Value*> dynamic_dims;
int32_t raw_data_size =
@@ -2400,7 +2402,7 @@ Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
// dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
// dest[dest_index] = source[i]
auto loop_body_emitter =
- [&](const llvm_ir::IrArray::Index& array_index) -> Status {
+ [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status {
llvm::Value* source_element =
GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
@@ -2415,7 +2417,7 @@ Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
.EmitLoop(IrName(hlo));
}
-Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
+absl::Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
@@ -2463,7 +2465,7 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
// source_index = delinearize(linearize(i, dynamic_dim), static_dim)
// dest[i] = source[source_index]
auto loop_body_emitter =
- [&](const llvm_ir::IrArray::Index& array_index) -> Status {
+ [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status {
llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
llvm::Value* source_element =
@@ -2480,7 +2482,7 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
return OkStatus();
}
-Status IrEmitter::HandleTopK(HloInstruction* hlo) {
+absl::Status IrEmitter::HandleTopK(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
const HloInstruction* input = hlo->operand(0);
const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
@@ -2520,8 +2522,8 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) {
}
#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-Status IrEmitter::HandleOneDnnMatMulCalls(HloInstruction* custom_call,
- std::string runtime_symbol_name) {
+absl::Status IrEmitter::HandleOneDnnMatMulCalls(
+ HloInstruction* custom_call, std::string runtime_symbol_name) {
// We would like to emit LLVM IR for the following function call
// custom_call_target(void* result, void** args)
// args can be thought of as an array of pointers allocated on the stack,
@@ -2650,7 +2652,7 @@ Status IrEmitter::HandleOneDnnMatMulCalls(HloInstruction* custom_call,
return OkStatus();
}
-Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) {
+absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) {
// args[0]: ptr to nargs
// args[1]: ptr to ExecutableRunOptions
// args[2]: ptr to OneDnnLayerNormConfig
@@ -2727,7 +2729,7 @@ Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) {
return OkStatus();
}
-Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) {
+absl::Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) {
auto input = custom_call->operand(0);
llvm_ir::IrArray input_array(GetIrArrayFor(input));
auto input_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, input_array);
@@ -2751,7 +2753,7 @@ Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) {
}
#endif // INTEL_MKL && ENABLE_ONEDNN_V3
-Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
+absl::Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (custom_call->custom_call_target() == "PadToStatic") {
return HandlePadToStatic(custom_call);
}
@@ -2856,7 +2858,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
return OkStatus();
}
-Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
+absl::Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Precondition: Condition computation must return a scalar bool.
HloComputation* condition = xla_while->while_condition();
TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
@@ -2867,7 +2869,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
xla_while->shape(),
[this, &xla_while](const Shape& /*subshape*/,
- const ShapeIndex& index) -> Status {
+ const ShapeIndex& index) -> absl::Status {
auto check = [this](const HloInstruction* a, const HloInstruction* b,
const ShapeIndex& index) -> absl::Status {
const BufferAllocation::Slice slice_a =
@@ -3236,7 +3238,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
}
}
-Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
+absl::Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
absl::Span<HloInstruction* const> operands(concatenate->operands());
std::string failure_reason;
TF_ASSIGN_OR_RETURN(
@@ -3253,7 +3255,7 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
return DefaultAction(concatenate);
}
-Status IrEmitter::HandleConditional(HloInstruction* conditional) {
+absl::Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto branch_index = conditional->operand(0);
int num_branches = conditional->branch_count();
TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
@@ -3362,25 +3364,25 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
return OkStatus();
}
-Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
+absl::Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
// No code to generate, but we need to emit an address for book-keeping.
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
return OkStatus();
}
-Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
+absl::Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
// AddDependency just forwards its zero-th operand.
emitted_value_[add_dependency] =
GetEmittedValueFor(add_dependency->operand(0));
return OkStatus();
}
-Status IrEmitter::HandleRng(HloInstruction* rng) {
+absl::Status IrEmitter::HandleRng(HloInstruction* rng) {
return Unimplemented("Rng should be expanded for CPU.");
}
-Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
+absl::Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
@@ -3398,7 +3400,7 @@ Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
return OkStatus();
}
-Status IrEmitter::FinishVisit(HloInstruction* root) {
+absl::Status IrEmitter::FinishVisit(HloInstruction* root) {
// When this method is called, we should have already emitted an IR value for
// the root (return) op. The IR value holds the address of the buffer holding
// the value. If the root is a constant or parameter, we perform a memcpy from
@@ -3591,7 +3593,7 @@ bool IsHloVeryCheap(const HloInstruction* hlo) {
}
} // namespace
-Status IrEmitter::Preprocess(HloInstruction* hlo) {
+absl::Status IrEmitter::Preprocess(HloInstruction* hlo) {
VLOG(3) << "Visiting: " << hlo->ToString();
// When profiling is enabled, trace the same HLOs that the profiler does.
if (instruction_to_profile_idx_.count(hlo) ||
@@ -3604,7 +3606,7 @@ Status IrEmitter::Preprocess(HloInstruction* hlo) {
return OkStatus();
}
-Status IrEmitter::Postprocess(HloInstruction* hlo) {
+absl::Status IrEmitter::Postprocess(HloInstruction* hlo) {
if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
}
@@ -3764,7 +3766,7 @@ llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
}
}
-Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
+absl::Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
const Shape& target_shape = op->shape();
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
@@ -3774,13 +3776,13 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
return OkStatus();
}
-Status IrEmitter::EmitTargetElementLoop(
+absl::Status IrEmitter::EmitTargetElementLoop(
HloInstruction* target_op,
const llvm_ir::ElementGenerator& element_generator) {
return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
}
-Status IrEmitter::EmitTargetElementLoop(
+absl::Status IrEmitter::EmitTargetElementLoop(
HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator) {
VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
@@ -3833,8 +3835,8 @@ Status IrEmitter::EmitTargetElementLoop(
return OkStatus();
}
-Status IrEmitter::EmitMemcpy(const HloInstruction& source,
- const HloInstruction& destination) {
+absl::Status IrEmitter::EmitMemcpy(const HloInstruction& source,
+ const HloInstruction& destination) {
llvm::Value* source_value = GetEmittedValueFor(&source);
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64_t source_size = ByteSizeOf(source.shape());
@@ -3844,7 +3846,7 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
return OkStatus();
}
-Status IrEmitter::ElementTypesSameAndSupported(
+absl::Status IrEmitter::ElementTypesSameAndSupported(
const HloInstruction& instruction,
absl::Span<const HloInstruction* const> operands,
absl::Span<const PrimitiveType> supported_types) {
@@ -3863,7 +3865,7 @@ Status IrEmitter::ElementTypesSameAndSupported(
return OkStatus();
}
-Status IrEmitter::DefaultAction(HloInstruction* hlo) {
+absl::Status IrEmitter::DefaultAction(HloInstruction* hlo) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h
index bc99acd4f2b..21a37181011 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.h
+++ b/third_party/xla/xla/service/cpu/ir_emitter.h
@@ -124,7 +124,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
llvm::IRBuilder<>* builder() { return &b_; }
// Emit an LLVM global variable for every constant buffer allocation.
- Status EmitConstantGlobals();
+ absl::Status EmitConstantGlobals();
protected:
//
@@ -132,55 +132,57 @@ class IrEmitter : public DfsHloVisitorWithDefault,
//
// Default action which emits code for most operations. Operations which are
// special in some way are handled explicitly in HandleFoo methods.
- Status DefaultAction(HloInstruction* hlo) override;
-
- Status HandleAllGather(HloInstruction* instruction) override;
- Status HandleAllToAll(HloInstruction* instruction) override;
- Status HandleBitcast(HloInstruction* bitcast) override;
- Status HandleConstant(HloInstruction* constant) override;
- Status HandleCopy(HloInstruction* copy) override;
- Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
- Status HandleSelect(HloInstruction* select) override;
- Status HandleDot(HloInstruction* dot) override;
- Status HandleConvolution(HloInstruction* convolution) override;
- Status HandleFft(HloInstruction* fft) override;
- Status HandleAllReduce(HloInstruction* crs) override;
- Status HandleReduceScatter(HloInstruction* crs) override;
- Status HandleCollectivePermute(HloInstruction* crs) override;
- Status HandleInfeed(HloInstruction* instruction) override;
- Status HandleOutfeed(HloInstruction* outfeed) override;
- Status HandleSort(HloInstruction* hlo) override;
- Status HandleParameter(HloInstruction* parameter) override;
- Status HandleReduce(HloInstruction* reduce) override;
- Status HandleReduceWindow(HloInstruction* reduce_window) override;
- Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
- Status HandleSend(HloInstruction* send) override;
- Status HandleSendDone(HloInstruction* send_done) override;
- Status HandleSlice(HloInstruction* slice) override;
- Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
- Status HandleDynamicUpdateSlice(
+ absl::Status DefaultAction(HloInstruction* hlo) override;
+
+ absl::Status HandleAllGather(HloInstruction* instruction) override;
+ absl::Status HandleAllToAll(HloInstruction* instruction) override;
+ absl::Status HandleBitcast(HloInstruction* bitcast) override;
+ absl::Status HandleConstant(HloInstruction* constant) override;
+ absl::Status HandleCopy(HloInstruction* copy) override;
+ absl::Status HandleGetTupleElement(
+ HloInstruction* get_tuple_element) override;
+ absl::Status HandleSelect(HloInstruction* select) override;
+ absl::Status HandleDot(HloInstruction* dot) override;
+ absl::Status HandleConvolution(HloInstruction* convolution) override;
+ absl::Status HandleFft(HloInstruction* fft) override;
+ absl::Status HandleAllReduce(HloInstruction* crs) override;
+ absl::Status HandleReduceScatter(HloInstruction* crs) override;
+ absl::Status HandleCollectivePermute(HloInstruction* crs) override;
+ absl::Status HandleInfeed(HloInstruction* instruction) override;
+ absl::Status HandleOutfeed(HloInstruction* outfeed) override;
+ absl::Status HandleSort(HloInstruction* hlo) override;
+ absl::Status HandleParameter(HloInstruction* parameter) override;
+ absl::Status HandleReduce(HloInstruction* reduce) override;
+ absl::Status HandleReduceWindow(HloInstruction* reduce_window) override;
+ absl::Status HandleSelectAndScatter(
+ HloInstruction* select_and_scatter) override;
+ absl::Status HandleSend(HloInstruction* send) override;
+ absl::Status HandleSendDone(HloInstruction* send_done) override;
+ absl::Status HandleSlice(HloInstruction* slice) override;
+ absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
+ absl::Status HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) override;
- Status HandleRecv(HloInstruction* recv) override;
- Status HandleRecvDone(HloInstruction* recv_done) override;
- Status HandlePad(HloInstruction* pad) override;
- Status HandleTuple(HloInstruction* tuple) override;
- Status HandleFusion(HloInstruction* fusion) override;
- Status HandleCall(HloInstruction* call) override;
- Status HandleCustomCall(HloInstruction* custom_call) override;
- Status HandleWhile(HloInstruction* xla_while) override;
- Status HandleConcatenate(HloInstruction* concatenate) override;
- Status HandleConditional(HloInstruction* conditional) override;
- Status HandleScatter(HloInstruction* scatter) override;
- Status HandleAfterAll(HloInstruction* after_all) override;
- Status HandleAddDependency(HloInstruction* add_dependency) override;
- Status HandlePartitionId(HloInstruction* hlo) override;
- Status HandleReplicaId(HloInstruction* hlo) override;
- Status HandleRng(HloInstruction* rng) override;
- Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
- Status FinishVisit(HloInstruction* root) override;
-
- Status Preprocess(HloInstruction* hlo) override;
- Status Postprocess(HloInstruction* hlo) override;
+ absl::Status HandleRecv(HloInstruction* recv) override;
+ absl::Status HandleRecvDone(HloInstruction* recv_done) override;
+ absl::Status HandlePad(HloInstruction* pad) override;
+ absl::Status HandleTuple(HloInstruction* tuple) override;
+ absl::Status HandleFusion(HloInstruction* fusion) override;
+ absl::Status HandleCall(HloInstruction* call) override;
+ absl::Status HandleCustomCall(HloInstruction* custom_call) override;
+ absl::Status HandleWhile(HloInstruction* xla_while) override;
+ absl::Status HandleConcatenate(HloInstruction* concatenate) override;
+ absl::Status HandleConditional(HloInstruction* conditional) override;
+ absl::Status HandleScatter(HloInstruction* scatter) override;
+ absl::Status HandleAfterAll(HloInstruction* after_all) override;
+ absl::Status HandleAddDependency(HloInstruction* add_dependency) override;
+ absl::Status HandlePartitionId(HloInstruction* hlo) override;
+ absl::Status HandleReplicaId(HloInstruction* hlo) override;
+ absl::Status HandleRng(HloInstruction* rng) override;
+ absl::Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
+ absl::Status FinishVisit(HloInstruction* root) override;
+
+ absl::Status Preprocess(HloInstruction* hlo) override;
+ absl::Status Postprocess(HloInstruction* hlo) override;
// A convenient helper for calling BufferAssignment::GetUniqueSlice.
BufferAllocation::Slice GetAllocationSlice(
@@ -189,16 +191,16 @@ class IrEmitter : public DfsHloVisitorWithDefault,
}
private:
- Status HandleSliceToDynamic(HloInstruction* hlo);
- Status HandlePadToStatic(HloInstruction* hlo);
- Status HandleTopK(HloInstruction* hlo);
- Status HandleAllReduceSingleReplica(HloInstruction* crs);
- Status HandleAllReduceMultipleReplica(HloInstruction* crs);
+ absl::Status HandleSliceToDynamic(HloInstruction* hlo);
+ absl::Status HandlePadToStatic(HloInstruction* hlo);
+ absl::Status HandleTopK(HloInstruction* hlo);
+ absl::Status HandleAllReduceSingleReplica(HloInstruction* crs);
+ absl::Status HandleAllReduceMultipleReplica(HloInstruction* crs);
#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
- Status HandleOneDnnMatMulCalls(HloInstruction* hlo,
- std::string runtime_symbol_name);
- Status HandleOneDnnSoftmax(HloInstruction* hlo);
- Status HandleOneDnnLayerNorm(HloInstruction* hlo);
+ absl::Status HandleOneDnnMatMulCalls(HloInstruction* hlo,
+ std::string runtime_symbol_name);
+ absl::Status HandleOneDnnSoftmax(HloInstruction* hlo);
+ absl::Status HandleOneDnnLayerNorm(HloInstruction* hlo);
#endif // INTEL_MKL && ENABLE_ONEDNN_V3
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const std::string& function_name);
@@ -316,7 +318,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
- Status ElementTypesSameAndSupported(
+ absl::Status ElementTypesSameAndSupported(
const HloInstruction& instruction,
absl::Span<const HloInstruction* const> operands,
absl::Span<const PrimitiveType> supported_types);
@@ -331,23 +333,23 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// in the loop name.
//
// TODO(jingyue): target_op should be a `const HloInstruction*`.
- Status EmitTargetElementLoop(
+ absl::Status EmitTargetElementLoop(
HloInstruction* target_op,
const llvm_ir::ElementGenerator& element_generator);
- Status EmitTargetElementLoop(
+ absl::Status EmitTargetElementLoop(
HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator);
// Emits a memcpy from the source instruction's result value to the
// destination's. Both source and destination must have an entry in the
// emitted_value_ table.
- Status EmitMemcpy(const HloInstruction& source,
- const HloInstruction& destination);
+ absl::Status EmitMemcpy(const HloInstruction& source,
+ const HloInstruction& destination);
// Emits IR to compute the target address of the buffer for the given op.
// After calling this function, you can get a pointer to this buffer by
// calling GetIrArrayForOp or GetEmittedValueFor.
- Status EmitTargetAddressForOp(const HloInstruction* op);
+ absl::Status EmitTargetAddressForOp(const HloInstruction* op);
// Structurizes "array_elements" into an MD array that represents "shape".
// This is a recursive function, and "dimension_index" indicates the index of
@@ -652,8 +654,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
// address.
- Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
- llvm::Value* program_buffer_address);
+ absl::Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
+ llvm::Value* program_buffer_address);
// Returns a ConstExpr bitcast.
llvm::Constant* EmitGlobalForLiteral(const Literal& literal);
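Every hunk in this header follows the same mechanical pattern: unqualified `Status`/`StatusOr<T>` return types are respelled as `absl::Status`/`absl::StatusOr<T>` with no behavioral change. A minimal sketch of a migrated handler, using a hypothetical visitor rather than any class in this diff:

    #include "absl/status/status.h"

    // Hypothetical visitor; only the return-type spelling differs from the
    // pre-migration code. (Other required overrides elided.)
    class ExampleEmitter : public DfsHloVisitorWithDefault {
     public:
      absl::Status HandleRng(HloInstruction* rng) override {
        // ... emit IR for `rng` ...
        return absl::OkStatus();
      }
    };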
diff --git a/third_party/xla/xla/service/cpu/ir_function.cc b/third_party/xla/xla/service/cpu/ir_function.cc
index 69961501a37..cf66fe382e7 100644
--- a/third_party/xla/xla/service/cpu/ir_function.cc
+++ b/third_party/xla/xla/service/cpu/ir_function.cc
@@ -234,7 +234,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
// Emits a call to a runtime fork/join function which dispatches parallel
// calls to 'parallel_function' (and joins threads before returning).
-Status EmitCallToParallelForkJoin(
+absl::Status EmitCallToParallelForkJoin(
const std::vector<llvm::Value*>& arguments, const Shape& shape,
absl::Span<const int64_t> dimension_partition_counts, llvm::IRBuilder<>* b,
llvm::Function* parallel_function, absl::string_view name) {
diff --git a/third_party/xla/xla/service/cpu/ir_function.h b/third_party/xla/xla/service/cpu/ir_function.h
index 47034675bdf..73a85867438 100644
--- a/third_party/xla/xla/service/cpu/ir_function.h
+++ b/third_party/xla/xla/service/cpu/ir_function.h
@@ -145,7 +145,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
// Emits a call to a runtime fork/join function which dispatches parallel
// calls to 'parallel_function' (and joins threads before returning).
-Status EmitCallToParallelForkJoin(
+absl::Status EmitCallToParallelForkJoin(
const std::vector<llvm::Value*>& arguments, const Shape& shape,
absl::Span<const int64_t> dimension_partition_counts, llvm::IRBuilder<>* b,
llvm::Function* parallel_function, absl::string_view name);
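Call sites are unchanged by the respelling; the returned absl::Status still composes with TF_RETURN_IF_ERROR. A hedged sketch of a call, where `call_args`, `result_shape`, `partition_counts`, `b`, and `parallel_fn` are hypothetical caller state:

    // Hypothetical call site inside an emitter method that itself returns
    // absl::Status; errors from the fork/join emission propagate upward.
    TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
        call_args, result_shape, partition_counts, &b, parallel_fn,
        /*name=*/"parallel_loop"));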
diff --git a/third_party/xla/xla/service/cpu/mlir_emitter.cc b/third_party/xla/xla/service/cpu/mlir_emitter.cc
index 8d3a28815ee..396f567cf2d 100644
--- a/third_party/xla/xla/service/cpu/mlir_emitter.cc
+++ b/third_party/xla/xla/service/cpu/mlir_emitter.cc
@@ -82,7 +82,7 @@ void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
}
} // namespace
-Status EmitMlirFuncAndCall(
+absl::Status EmitMlirFuncAndCall(
mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
diff --git a/third_party/xla/xla/service/cpu/mlir_emitter.h b/third_party/xla/xla/service/cpu/mlir_emitter.h
index c7a8480e4e9..af1b1626af2 100644
--- a/third_party/xla/xla/service/cpu/mlir_emitter.h
+++ b/third_party/xla/xla/service/cpu/mlir_emitter.h
@@ -32,7 +32,7 @@ namespace cpu {
// `emitter` and create a call, passing it the buffers defined by
// result_shape/result_ptr and operand_shapes/operand_ptrs. The function is
// added to the LLVM module at `b`'s insertion point.
-Status EmitMlirFuncAndCall(
+absl::Status EmitMlirFuncAndCall(
mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
index a792a5ba521..1fc542088d1 100644
--- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
+++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
@@ -43,7 +43,7 @@ namespace {
namespace m = match;
namespace pu = ::xla::cpu::onednn_pattern_utils_internal;
-inline Status ValidateDotDimensionNumbers(
+inline absl::Status ValidateDotDimensionNumbers(
const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but that
  // DotDecomposer should have established for us.
@@ -276,8 +276,8 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
// OneDNN matmul can fuse an add operation with automatic broadcasting along the
// addend's dimensions that are 1s. When compatible, Broadcast can be replaced
// by Bitcast, which is much cheaper. Compute new shape for the Bitcast.
-StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr,
- const Shape& dot_shape) {
+absl::StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr,
+ const Shape& dot_shape) {
if (broadcast_instr->opcode() != HloOpcode::kBroadcast) {
return absl::InvalidArgumentError(
"Hlo instruction is not a Broadcast insruction.");
@@ -428,7 +428,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
public:
  // Matches patterns for possible MatMul fusions that are supported by the
  // oneDNN library. Matched HLO instruction(s) are replaced by a custom call.
- Status HandleDot(HloInstruction* instr) override {
+ absl::Status HandleDot(HloInstruction* instr) override {
HloInstruction* dot_instr;
auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot);
if (!Match(instr, pattern)) return OkStatus();
@@ -463,7 +463,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
return OkStatus();
}
- Status HandleAdd(HloInstruction* instr) override {
+ absl::Status HandleAdd(HloInstruction* instr) override {
// Try to do a fusion for Dot(onednn-matmul) + Add. However,
  // the HLO Add instruction might receive the addends after additional
// processing like Broadcast, Bitcast, Convert, etc. is applied to the raw
@@ -616,7 +616,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
return OkStatus();
}
- Status HandleMaximum(HloInstruction* instr) override {
+ absl::Status HandleMaximum(HloInstruction* instr) override {
HloInstruction* matmul_call;
HloInstruction* intermediate_instr = nullptr;
HloInstruction* optional_bitcast = nullptr;
@@ -699,7 +699,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
return OkStatus();
}
- Status HandleMultiply(HloInstruction* instr) override {
+ absl::Status HandleMultiply(HloInstruction* instr) override {
HloInstruction* matmul_call;
HloInstruction* intermediate_instr = nullptr;
HloInstruction* src;
@@ -760,10 +760,11 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
return OkStatus();
}
- Status FuseActivation(OneDnnMatMulConfig_FusionKind kind,
- HloInstruction* activation, HloInstruction* matmul,
- HloInstruction* intermediate_instr = nullptr,
- HloInstruction* optional_bitcast = nullptr) {
+ absl::Status FuseActivation(OneDnnMatMulConfig_FusionKind kind,
+ HloInstruction* activation,
+ HloInstruction* matmul,
+ HloInstruction* intermediate_instr = nullptr,
+ HloInstruction* optional_bitcast = nullptr) {
TF_ASSIGN_OR_RETURN(auto backend_config,
matmul->backend_config<BackendConfig>());
auto* matmul_config = backend_config.mutable_onednn_matmul_config();
@@ -811,7 +812,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
// lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
// rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
// result: [batch_dims] to [batch_dims,1,1]
- StatusOr<HloInstruction*> ReconfigureDotDimensions(
+ absl::StatusOr<HloInstruction*> ReconfigureDotDimensions(
HloInstruction* dot_instr) {
HloInstruction* lhs = dot_instr->mutable_operand(0);
HloInstruction* rhs = dot_instr->mutable_operand(1);
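A worked instance of the shape comment above (an illustrative case, not from this diff): for lhs f32[16,32] and rhs f32[16,32] with batch dimension 0 and contracting dimension 1,

    // lhs:    f32[16,32] -> f32[16,1,32]
    // rhs:    f32[16,32] -> f32[16,32,1]
    // result: the batched matmul yields f32[16,1,1], which maps back to
    //         the original f32[16] result shape.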
@@ -900,7 +901,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
#endif
}
- Status HandleCustomCall(HloInstruction* custom_call) override {
+ absl::Status HandleCustomCall(HloInstruction* custom_call) override {
HloInstruction* matmul;
if (Match(custom_call, OneDnnMatmulInstr(&matmul))) {
return HandleCustomCallInternal<dnnl::matmul::primitive_desc>(
@@ -911,7 +912,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
}
template <typename PrimDesc>
- Status HandleCustomCallInternal(HloInstruction* custom_call) {
+ absl::Status HandleCustomCallInternal(HloInstruction* custom_call) {
auto scratch_add = AddScratch<PrimDesc>(custom_call);
if (scratch_add.ok()) {
custom_call = *scratch_add;
@@ -926,10 +927,10 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
}
template <typename>
- Status SetWeightsPrepack(HloInstruction*, bool);
+ absl::Status SetWeightsPrepack(HloInstruction*, bool);
template <typename>
- Status SetUserScratch(HloInstruction*, bool);
+ absl::Status SetUserScratch(HloInstruction*, bool);
template <typename>
bool GetWeightsPrepack(HloInstruction*);
@@ -940,7 +941,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
// Add scratch for matmul by changing the result of custom-call to
// tuple(result, scratch)
template <typename PrimDesc>
- StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) {
+ absl::StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) {
if (GetUserScratch<PrimDesc>(custom_call)) {
return custom_call;
}
@@ -964,7 +965,7 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
}
template <typename PrimDesc>
- StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) {
+ absl::StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) {
if (GetWeightsPrepack<PrimDesc>(custom_call)) {
return custom_call;
}
@@ -1042,7 +1043,7 @@ EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack,
#define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \
CONFIG, FIELD) \
template <> \
- inline Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>( \
+ inline absl::Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>( \
HloInstruction * custom_call, bool value) { \
TF_ASSIGN_OR_RETURN(auto backend_config, \
custom_call->backend_config<BackendConfig>()); \
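The macro body is cut off by the hunk, but the visible lines plus the SetUserScratch invocation below suggest roughly the following expansion. Everything after the TF_ASSIGN_OR_RETURN line is an assumption sketched for illustration, not text from this diff:

    // Hypothetical expansion of EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(
    //     SetUserScratch, PrimDesc, OneDnnMatMulConfig,
    //     onednn_matmul_config, user_scratchpad):
    template <>
    inline absl::Status OneDnnPostRewriteVisitor::SetUserScratch<PrimDesc>(
        HloInstruction* custom_call, bool value) {
      TF_ASSIGN_OR_RETURN(auto backend_config,
                          custom_call->backend_config<BackendConfig>());
      // Assumed tail: flip the proto field and write the config back.
      backend_config.mutable_onednn_matmul_config()->set_user_scratchpad(
          value);
      return custom_call->set_backend_config(backend_config);
    }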
@@ -1060,7 +1061,7 @@ EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch,
OneDnnMatMulConfig, onednn_matmul_config,
user_scratchpad);
-StatusOr<bool> OneDnnMatMulRewriter::Run(
+absl::StatusOr<bool> OneDnnMatMulRewriter::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
OneDnnMatMulRewriteVisitor visitor;
diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
index 36cab7ee949..eaeea210021 100644
--- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
+++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
@@ -41,7 +41,7 @@ class OneDnnMatMulRewriter : public HloModulePass {
absl::string_view name() const override { return "onednn-matmul-rewriter"; }
using HloPassInterface::Run;
- StatusOr<bool> Run(
+ absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
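Callers of Run are unaffected by the respelling, since absl::StatusOr<bool> exposes the same ok()/value() protocol. A hypothetical registration sketch, assuming the standard HloPassPipeline helpers and eliding any constructor arguments the pass may take:

    // Hypothetical pipeline setup; `module` is an HloModule*.
    HloPassPipeline pipeline("onednn-rewrites");
    pipeline.AddPass<OneDnnMatMulRewriter>();
    TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));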
diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.cc b/third_party/xla/xla/service/cpu/onednn_memory_util.cc
index fd0b6f92795..6ab913161a7 100644
--- a/third_party/xla/xla/service/cpu/onednn_memory_util.cc
+++ b/third_party/xla/xla/service/cpu/onednn_memory_util.cc
@@ -169,7 +169,7 @@ int64_t MemrefInfo::GetChannels() const { return pod_->dims[pod_->rank - 1]; }
int64_t MemrefInfo::GetRank() const { return pod_->rank; }
-StatusOr<dnnl::memory::desc> TransposeLastTwoDims(
+absl::StatusOr<dnnl::memory::desc> TransposeLastTwoDims(
const dnnl::memory::desc& md) {
int64_t ndims = md.get_ndims();
if (ndims < 2) {
diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.h b/third_party/xla/xla/service/cpu/onednn_memory_util.h
index 5793c1bdf4f..c0c956a32dc 100644
--- a/third_party/xla/xla/service/cpu/onednn_memory_util.h
+++ b/third_party/xla/xla/service/cpu/onednn_memory_util.h
@@ -119,7 +119,8 @@ class MemrefInfo {
MemrefInfoPOD* pod_;
};
-StatusOr<dnnl::memory::desc> TransposeLastTwoDims(const dnnl::memory::desc& md);
+absl::StatusOr<dnnl::memory::desc> TransposeLastTwoDims(
+ const dnnl::memory::desc& md);
#define TRANSPOSE_LAST_TWO_DIMS_IF(pred, mem_desc) \
if (pred) { \
auto trans_mem_desc = TransposeLastTwoDims(mem_desc); \
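Written out without the macro, the consumption pattern for the StatusOr return is roughly the following sketch (variable names hypothetical; the macro's truncated tail is not reproduced):

    // Transpose the last two dims of `md` only when `needs_transpose` holds.
    if (needs_transpose) {
      absl::StatusOr<dnnl::memory::desc> trans = TransposeLastTwoDims(md);
      if (!trans.ok()) return trans.status();  // propagate the error
      md = *trans;
    }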
diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc
index 06c2137a782..25530aa59c8 100644
--- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc
+++ b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc
@@ -385,7 +385,7 @@ bool MatchFlaxLayerNorm(HloInstruction* instr, HloInstruction** src,
class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor {
public:
- Status HandleAdd(HloInstruction* instr) override {
+ absl::Status HandleAdd(HloInstruction* instr) override {
HloInstruction *src, *scale, *bias;
float eps;
bool is_bf16orfp16_convert = false;
@@ -445,7 +445,7 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor {
return OkStatus();
}
- Status HandleConvert(HloInstruction* instr) override {
+ absl::Status HandleConvert(HloInstruction* instr) override {
HloInstruction* custom_call;
HloInstruction* convert_instr;
auto pattern =
@@ -478,7 +478,7 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor {
return OkStatus();
}
- Status HandleDivide(HloInstruction* divide_instr) override {
+ absl::Status HandleDivide(HloInstruction* divide_instr) override {
if (divide_instr->HasControlDependencies()) return OkStatus();
if (!IsSupportedType(divide_instr->shape().element_type()))
return OkStatus();
@@ -495,7 +495,7 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor {
}
};
-StatusOr<bool> OneDnnOpsRewriter::Run(
+absl::StatusOr<bool> OneDnnOpsRewriter::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
OneDnnOpsRewriterVisitor visitor;
diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h
index ea62f33ebcf..8e777d8889a 100644
--- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h
+++ b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.h
@@ -32,7 +32,7 @@ class OneDnnOpsRewriter : public HloModulePass {
absl::string_view name() const override { return "onednn-ops-rewriter"; }
using HloPassInterface::Run;
- StatusOr<bool> Run(
+ absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};
diff --git a/third_party/xla/xla/service/cpu/onednn_rewriter.h b/third_party/xla/xla/service/cpu/onednn_rewriter.h
index a1ba3205c96..53fd5c0f977 100644
--- a/third_party/xla/xla/service/cpu/onednn_rewriter.h
+++ b/third_party/xla/xla/service/cpu/onednn_rewriter.h
@@ -33,7 +33,7 @@ class OneDnnRewriter : public HloModulePass {
absl::string_view name() const override { return "onednn-rewriter"; }
using HloPassInterface::Run;
- StatusOr<bool> Run(
+ absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};
diff --git a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc
index 6d103759585..b82707a367e 100644
--- a/third_party/xla/xla/service/cpu/parallel_task_assignment.cc
+++ b/third_party/xla/xla/service/cpu/parallel_task_assignment.cc
@@ -133,7 +133,8 @@ ParallelTaskAssignment::ParallelTaskAssignment(
// Run cost analysis on 'module'.
auto cost_analysis = std::make_unique<HloCostAnalysis>(shape_size);
HloComputation* computation = module->entry_computation();
- Status status = computation->root_instruction()->Accept(cost_analysis.get());
+ absl::Status status =
+ computation->root_instruction()->Accept(cost_analysis.get());
if (status.ok()) {
// Set default cost model based on 'cost_analysis'.
cost_model_ = std::make_unique<DefaultCostModel>(