Skip to content

Commit cea418a

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Only log elapsed time during tracing for top level jit
PiperOrigin-RevId: 890643915
1 parent 6735ea0 commit cea418a

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

jax/_src/pjit.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from collections import defaultdict
1818
from collections.abc import Callable, Sequence, Iterable
19+
import contextlib
1920
from dataclasses import dataclass, replace
2021
from functools import partial
2122
import inspect
@@ -511,9 +512,12 @@ def _trace_for_jit(
511512

512513
qdd_token = _qdd_cache_index(fun, in_type.vals) # represents qdd state context
513514

514-
with dispatch.log_elapsed_time(
515-
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
516-
fun_name(fun), event=dispatch.JAXPR_TRACE_EVENT):
515+
elapsed_time_ctx = (
516+
dispatch.log_elapsed_time(
517+
"Finished tracing {fun_name} for jit in {elapsed_time:.9f} sec",
518+
fun_name(fun), event=dispatch.JAXPR_TRACE_EVENT)
519+
if core.trace_state_clean() else contextlib.nullcontext())
520+
with elapsed_time_ctx:
517521
if ji.use_resource_env: # pjit
518522
with (_internal_use_concrete_mesh(ctx_mesh),
519523
mesh_lib.use_abstract_mesh(ctx_mesh.abstract_mesh)):

tests/api_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4745,7 +4745,7 @@ def add(x):
47454745
add(inp)
47464746
tracing_add_count = 0
47474747
for m in cm.output:
4748-
if 'Finished tracing + transforming add for pjit' in m:
4748+
if 'Finished tracing add for jit' in m:
47494749
tracing_add_count += 1
47504750
self.assertEqual(tracing_add_count, 2)
47514751

tests/logging_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_debug_logging(self):
149149
with jax_debug_log_modules("jax"):
150150
with capture_jax_logs() as log_output:
151151
jax.jit(lambda x: x + 1)(1)
152-
self.assertIn("Finished tracing + transforming", log_output.getvalue())
152+
self.assertIn("Finished tracing", log_output.getvalue())
153153
self.assertIn("Compiling jit(<lambda>)", log_output.getvalue())
154154

155155
# Turn off all debug logging.
@@ -162,7 +162,7 @@ def test_debug_logging(self):
162162
with jax_debug_log_modules("jax._src.dispatch"):
163163
with capture_jax_logs() as log_output:
164164
jax.jit(lambda x: x + 1)(1)
165-
self.assertIn("Finished tracing + transforming", log_output.getvalue())
165+
self.assertIn("Finished tracing", log_output.getvalue())
166166
self.assertNotIn("Compiling jit(<lambda>)", log_output.getvalue())
167167

168168
# Turn everything off again.
@@ -233,7 +233,7 @@ def test_subprocess_toggling_logging_level(self):
233233
log_output_verbose = log_output[:m.start()]
234234
log_output_silent = log_output[m.end():]
235235

236-
self.assertIn("Finished tracing + transforming <lambda> for pjit",
236+
self.assertIn("Finished tracing <lambda> for jit",
237237
log_output_verbose)
238238
self.assertEqual(log_output_silent, "")
239239

@@ -256,7 +256,7 @@ def test_subprocess_double_logging_absent(self):
256256
# only one tracing line should be printed, if there's more than one
257257
# then logs are printing duplicated
258258
self.assertLen([line for line in log_lines
259-
if "Finished tracing + transforming" in line], 1)
259+
if "Finished tracing" in line], 1)
260260

261261
@jtu.skip_on_devices("tpu")
262262
@unittest.skipIf(platform.system() == "Windows",

0 commit comments

Comments
 (0)