aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorA Googler <no-reply@google.com>2024-02-14 19:02:00 -0800
committerBlaze Rules Copybara <blaze-rules@google.com>2024-02-14 19:02:34 -0800
commitda76c59ad0e8be403e2982e7232047736343574b (patch)
tree1eeb7142ffece1ba9e26e518a5c2490c7faf0368
parentd86056eb6622cabea8ff668e76afb03d52ac3775 (diff)
downloadbazelbuild-rules_testing-da76c59ad0e8be403e2982e7232047736343574b.tar.gz
Add support for provider maps to rules_testing
Currently, If you are testing a custom provider, the API requires you to write: env.expect.that_target(targets.foo).provider(FooInfo, factory=FooFactory)... env.expect.that_target(targets.foo).provider(BarInfo, factory=BarFactory)... This can get very tedious, and more importantly, is not very safe, since you can write: env.expect.that_target(targets.foo).provider(FooInfo, factory=FooFactory)... env.expect.that_target(targets.foo).provider(BarInfo, factory=FooFactory)... Additionally, custom types are always rendered as "<provider>". To solve this, we add the ability to directly specify a list of factories for custom types in your test. analysis_test( ..., provider_factories = [struct(type = FooInfo, name = "FooInfo", factory = FooFactory)] ) PiperOrigin-RevId: 607174706
-rw-r--r--CHANGELOG.md5
-rw-r--r--lib/private/analysis_test.bzl21
-rw-r--r--lib/private/target_subject.bzl66
-rw-r--r--tests/truth_tests.bzl3
4 files changed, 71 insertions, 24 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 525a958..13adde9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,11 @@
they have the usual target under test aspects applied. This allows
testing multiple targets in one test with a mixture of configurations.
([#67](https://github.com/bazelbuild/rules_testing/issues/67))
+ * `analysis_test` now takes the parameter `provider_subject_factories`.
+ If you want to perform assertions on custom providers, you no longer need
+ to use the factory parameter each time you want to retrieve the provider.
+ instead, you now write `analysis_test(..., provider_subject_factories = [
+ type = FooInfo, name = "FooInfo", factory = FooSubjectFactory])`.
## [0.5.0] - 2023-10-04
diff --git a/lib/private/analysis_test.bzl b/lib/private/analysis_test.bzl
index c491ebb..df2b647 100644
--- a/lib/private/analysis_test.bzl
+++ b/lib/private/analysis_test.bzl
@@ -21,6 +21,7 @@ load("@bazel_skylib//lib:dicts.bzl", "dicts")
load("@bazel_skylib//lib:types.bzl", "types")
load("//lib:truth.bzl", "truth")
load("//lib:util.bzl", "recursive_testing_aspect", "testing_aspect")
+load("//lib/private:target_subject.bzl", "PROVIDER_SUBJECT_FACTORIES")
load("//lib/private:util.bzl", "get_test_name_from_function")
def _fail(env, msg):
@@ -37,7 +38,7 @@ def _fail(env, msg):
print(full_msg)
env.failures.append(full_msg)
-def _begin_analysis_test(ctx):
+def _begin_analysis_test(ctx, provider_subject_factories):
"""Begins a unit test.
This should be the first function called in a unit test implementation
@@ -48,6 +49,10 @@ def _begin_analysis_test(ctx):
Args:
ctx: The Starlark context. Pass the implementation function's `ctx` argument
in verbatim.
+ provider_subject_factories: list of ProviderSubjectFactory structs, these are
+ additional provider factories on top of built in ones.
+ See analysis_test's provider_subject_factory arg for more details on
+ the type.
Returns:
An analysis_test "environment" struct. The following fields are public:
@@ -86,6 +91,7 @@ def _begin_analysis_test(ctx):
truth_env = struct(
ctx = ctx,
fail = lambda msg: _fail(failures_env, msg),
+ provider_subject_factories = PROVIDER_SUBJECT_FACTORIES + provider_subject_factories,
)
analysis_test_env = struct(
ctx = ctx,
@@ -126,7 +132,8 @@ def analysis_test(
fragments = [],
config_settings = {},
extra_target_under_test_aspects = [],
- collect_actions_recursively = False):
+ collect_actions_recursively = False,
+ provider_subject_factories = []):
"""Creates an analysis test from its implementation function.
An analysis test verifies the behavior of a "real" rule target by examining
@@ -189,6 +196,7 @@ def analysis_test(
analysis test target itself (e.g. common attributes like `tags`,
`target_compatible_with`, or attributes from `attrs`). Note that these
are for the analysis test target itself, not the target under test.
+
fragments: An optional list of fragment names that can be used to give rules access to
language-specific parts of configuration.
config_settings: A dictionary of configuration settings to change for the target under
@@ -202,6 +210,13 @@ def analysis_test(
in addition to those set up by default for the test harness itself.
collect_actions_recursively: If true, runs testing_aspect over all attributes, otherwise
it is only applied to the target under test.
+ provider_subject_factories: Optional list of ProviderSubjectFactory structs,
+ these are additional provider factories on top of built in ones.
+ A ProviderSubjectFactory is a struct with the following fields:
+ * type: A provider object, e.g. the callable FooInfo object
+ * name: A human-friendly name of the provider (eg. "FooInfo")
+ * factory: A callable to convert an instance of the provider to a
+ subject; see TargetSubject.provider()'s factory arg for the signature.
Returns:
(None)
@@ -290,7 +305,7 @@ def analysis_test(
)
def wrapped_impl(ctx):
- env, target = _begin_analysis_test(ctx)
+ env, target = _begin_analysis_test(ctx, provider_subject_factories)
impl(env, target)
return _end_analysis_test(env)
diff --git a/lib/private/target_subject.bzl b/lib/private/target_subject.bzl
index 47d8b94..0644ce1 100644
--- a/lib/private/target_subject.bzl
+++ b/lib/private/target_subject.bzl
@@ -187,7 +187,7 @@ def _target_subject_has_provider(self, provider):
if self.meta.has_provider(self.target, provider):
return
self.meta.add_failure(
- "expected to have provider: {}".format(_provider_name(provider)),
+ "expected to have provider: {}".format(_provider_subject_factory(self, provider).name),
"but provider was not found",
)
@@ -233,23 +233,30 @@ def _target_subject_provider(self, provider_key, factory = None):
the subject for the found provider. Required if the provider key is
not an inherently supported provider. It must have the following
signature: `def factory(value, /, *, meta)`.
+ Additional types of providers can be pre-registered by using the
+ `provider_subject_factories` arg of `analysis_test`.
Returns:
A subject wrapper of the provider value.
"""
- if not factory:
- for key, value in _PROVIDER_SUBJECT_FACTORIES:
- if key == provider_key:
- factory = value
- break
+ if factory:
+ provider_subject_factory = struct(
+ type = provider_key,
+ # str(provider_key) just returns "<provider>", which isn't helpful.
+ # For lack of a better option, just call it unknown
+ provider_name = "<Unknown provider>",
+ factory = factory,
+ )
+ else:
+ provider_subject_factory = _provider_subject_factory(self, provider_key)
- if not factory:
- fail("Unsupported provider: {}".format(provider_key))
+ if not provider_subject_factory.factory:
+ fail("Unsupported provider: {}".format(provider_subject_factory.name))
info = self.target[provider_key]
- return factory(
+ return provider_subject_factory.factory(
info,
- meta = self.meta.derive("provider({})".format(provider_key)),
+ meta = self.meta.derive("provider({})".format(provider_subject_factory.name)),
)
def _target_subject_action_generating(self, short_path):
@@ -385,18 +392,35 @@ def _target_subject_attr(self, name, *, factory = None):
meta = self.meta.derive("attr({})".format(name)),
)
-# Providers aren't hashable, so we have to use a list of (key, value)
-_PROVIDER_SUBJECT_FACTORIES = [
- (InstrumentedFilesInfo, InstrumentedFilesInfoSubject.new),
- (RunEnvironmentInfo, RunEnvironmentInfoSubject.new),
- (testing.ExecutionInfo, ExecutionInfoSubject.new),
-]
+def _provider_subject_factory(self, provider):
+ for provider_subject_factory in self.meta.env.provider_subject_factories:
+ if provider_subject_factory.type == provider:
+ return provider_subject_factory
-def _provider_name(provider):
- # This relies on implementation details of how Starlark represents
- # providers, and isn't entirely accurate, but works well enough
- # for error messages.
- return str(provider).split("<function ")[1].split(">")[0]
+ return struct(
+ type = provider,
+ name = "<Unknown provider>",
+ factory = None,
+ )
+
+# Providers aren't hashable, so we have to use a list of structs.
+PROVIDER_SUBJECT_FACTORIES = [
+ struct(
+ type = InstrumentedFilesInfo,
+ name = "InstrumentedFilesInfo",
+ factory = InstrumentedFilesInfoSubject.new,
+ ),
+ struct(
+ type = RunEnvironmentInfo,
+ name = "RunEnvironmentInfo",
+ factory = RunEnvironmentInfoSubject.new,
+ ),
+ struct(
+ type = testing.ExecutionInfo,
+ name = "testing.ExecutionInfo",
+ factory = ExecutionInfoSubject.new,
+ ),
+]
# We use this name so it shows up nice in docs.
# buildifier: disable=name-conventions
diff --git a/tests/truth_tests.bzl b/tests/truth_tests.bzl
index 322c528..a544034 100644
--- a/tests/truth_tests.bzl
+++ b/tests/truth_tests.bzl
@@ -25,11 +25,13 @@ _IS_BAZEL_6_OR_HIGHER = (testing.ExecutionInfo == testing.ExecutionInfo)
_suite = []
def _fake_env(env):
+ provider_subject_factories = env.expect.meta.env.provider_subject_factories
failures = []
env1 = struct(
ctx = env.ctx,
failures = failures,
fail = lambda msg: failures.append(msg), # Silent fail
+ provider_subject_factories = provider_subject_factories,
)
env2 = struct(
ctx = env.ctx,
@@ -37,6 +39,7 @@ def _fake_env(env):
fail = lambda msg: failures.append(msg), # Silent fail
expect = truth.expect(env1),
reset = lambda: failures.clear(),
+ provider_subject_factories = provider_subject_factories,
)
return env2