From 45ec96500337a467a8eb94b661340d032602d336 Mon Sep 17 00:00:00 2001 From: tammy-baylis-swi Date: Mon, 4 May 2026 12:28:11 -0700 Subject: [PATCH] Update unit tests for traceflags 02,03 --- tests/unit/test_oboe/test_oboe_sampler.py | 63 +++++++++++++++++++++++ tests/unit/test_propagator.py | 58 ++++++++++++++++++++- tests/unit/test_w3c_transformer.py | 14 ++++- 3 files changed, 131 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_oboe/test_oboe_sampler.py b/tests/unit/test_oboe/test_oboe_sampler.py index 880fe9009..93de268fb 100644 --- a/tests/unit/test_oboe/test_oboe_sampler.py +++ b/tests/unit/test_oboe/test_oboe_sampler.py @@ -492,6 +492,69 @@ def test_respects_sw_not_sampled_over_w3c_sampled(self, sample_through_always_se get_current_span(ctxt).get_span_context().span_id, "016x") check_counters(sample_through_always_set, ["trace.service.request_count"]) + def test_respects_sw_random_trace_id_not_sampled_02( + self, sample_through_always_set + ): + generator = RandomIdGenerator() + trace_id = generator.generate_trace_id() + span_id = generator.generate_span_id() + trace_state = TraceState([("sw", f"{span_id:016x}-02")]) + span_context = trace.SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=True, + trace_flags=TraceFlags.SAMPLED, + trace_state=trace_state, + ) + ctxt = trace.set_span_in_context(trace.NonRecordingSpan(span_context)) + sample = sample_through_always_set.should_sample( + ctxt, + get_current_span(ctxt).get_span_context().trace_id, + "respects_sw_random_trace_id_not_sampled_02", + ) + assert not sample.decision.is_sampled() + assert sample.decision.is_recording() + assert sample.attributes.get("sw.tracestate_parent_id") == format( + get_current_span(ctxt).get_span_context().span_id, "016x" + ) + assert "-02" in sample.attributes.get(TRACESTATE_CAPTURE_ATTRIBUTE) + check_counters(sample_through_always_set, ["trace.service.request_count"]) + + def test_respects_sw_random_trace_id_sampled_03( + self, sample_through_always_set + ): + generator = RandomIdGenerator() + trace_id = generator.generate_trace_id() + span_id = generator.generate_span_id() + trace_state = TraceState([("sw", f"{span_id:016x}-03")]) + span_context = trace.SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=True, + trace_flags=TraceFlags.DEFAULT, + trace_state=trace_state, + ) + ctxt = trace.set_span_in_context(trace.NonRecordingSpan(span_context)) + sample = sample_through_always_set.should_sample( + ctxt, + get_current_span(ctxt).get_span_context().trace_id, + "respects_sw_random_trace_id_sampled_03", + ) + assert sample.decision.is_sampled() + assert sample.decision.is_recording() + assert sample.attributes.get("sw.tracestate_parent_id") == format( + get_current_span(ctxt).get_span_context().span_id, "016x" + ) + assert "-03" in sample.attributes.get(TRACESTATE_CAPTURE_ATTRIBUTE) + check_counters( + sample_through_always_set, + [ + "trace.service.request_count", + "trace.service.tracecount", + "trace.service.through_trace_count", + ], + ) + class TestEntrySpanWithValidSwContextSampleThroughAlwaysUnset: def test_records_but_does_not_sample_when_SAMPLE_START_set(self): diff --git a/tests/unit/test_propagator.py b/tests/unit/test_propagator.py index 03eb71198..f3ba8f785 100644 --- a/tests/unit/test_propagator.py +++ b/tests/unit/test_propagator.py @@ -57,7 +57,7 @@ def test_extract_existing_context(self): assert actual_xto.options_header == "foo" assert actual_xto.signature == "bar" - def mock_otel_context(self, mocker, valid_span_id=True): + def mock_otel_context(self, mocker, valid_span_id=True, trace_flags=0x01): """Shared mocks for OTel trace context""" # Mock OTel trace API and current span context mock_get_span_context = mocker.Mock() @@ -65,7 +65,7 @@ def mock_otel_context(self, mocker, valid_span_id=True): mock_get_span_context.configure_mock( **{ "span_id": 0x1000100010001000, - "trace_flags": 0x01, + "trace_flags": trace_flags, } ) else: @@ -136,6 +136,60 @@ def test_inject_no_tracestate_new_tracestate(self, mocker): ), ]) + def test_inject_no_tracestate_new_tracestate_random_not_sampled_flag( + self, mocker + ): + """New tracestate preserves non-sampled random flag 02""" + self.mock_otel_context(mocker, True, trace_flags=0x02) + mock_carrier = dict() + mock_context = mocker.Mock() + mock_setter = mocker.Mock() + mock_set = mocker.Mock() + mock_setter.configure_mock( + **{ + "set": mock_set + } + ) + SolarWindsPropagator().inject( + mock_carrier, + mock_context, + mock_setter, + ) + mock_set.assert_has_calls([ + call( + mock_carrier, + "tracestate", + TraceState([("sw", "1000100010001000-02")]).to_header(), + ), + ]) + + def test_inject_no_tracestate_new_tracestate_random_sampled_flag( + self, mocker + ): + """New tracestate preserves sampled random flag 03""" + self.mock_otel_context(mocker, True, trace_flags=0x03) + mock_carrier = dict() + mock_context = mocker.Mock() + mock_setter = mocker.Mock() + mock_set = mocker.Mock() + mock_setter.configure_mock( + **{ + "set": mock_set + } + ) + SolarWindsPropagator().inject( + mock_carrier, + mock_context, + mock_setter, + ) + mock_set.assert_has_calls([ + call( + mock_carrier, + "tracestate", + TraceState([("sw", "1000100010001000-03")]).to_header(), + ), + ]) + def test_inject_existing_tracestate_no_sw(self, mocker): """sw added to start, foo=bar kept, xtrace_options_response removed""" self.mock_otel_context(mocker, True) diff --git a/tests/unit/test_w3c_transformer.py b/tests/unit/test_w3c_transformer.py index 6327604cc..8701aaeae 100644 --- a/tests/unit/test_w3c_transformer.py +++ b/tests/unit/test_w3c_transformer.py @@ -34,8 +34,18 @@ def test_span_id_from_sw(self): def test_span_id_from_sw_invalid_type_returns_zero_fallback(self): assert W3CTransformer.span_id_from_sw(None) == "{:016x}".format(0) - def test_trace_flags_from_int(self): - assert W3CTransformer.trace_flags_from_int(1) == "01" + @pytest.mark.parametrize( + "trace_flags,expected", + [ + (0x00, "00"), + (0x01, "01"), + (0x02, "02"), + (0x03, "03"), + (0xAB, "ab"), + ], + ) + def test_trace_flags_from_int(self, trace_flags, expected): + assert W3CTransformer.trace_flags_from_int(trace_flags) == expected def test_traceparent_from_context(self, span_context): assert W3CTransformer.traceparent_from_context(span_context) \