@@ -149,7 +149,7 @@ def test_debug_logging(self):
149149 with jax_debug_log_modules ("jax" ):
150150 with capture_jax_logs () as log_output :
151151 jax .jit (lambda x : x + 1 )(1 )
152- self .assertIn ("Finished tracing + transforming " , log_output .getvalue ())
152+ self .assertIn ("Finished tracing" , log_output .getvalue ())
153153 self .assertIn ("Compiling jit(<lambda>)" , log_output .getvalue ())
154154
155155 # Turn off all debug logging.
@@ -162,7 +162,7 @@ def test_debug_logging(self):
162162 with jax_debug_log_modules ("jax._src.dispatch" ):
163163 with capture_jax_logs () as log_output :
164164 jax .jit (lambda x : x + 1 )(1 )
165- self .assertIn ("Finished tracing + transforming " , log_output .getvalue ())
165+ self .assertIn ("Finished tracing" , log_output .getvalue ())
166166 self .assertNotIn ("Compiling jit(<lambda>)" , log_output .getvalue ())
167167
168168 # Turn everything off again.
@@ -233,7 +233,7 @@ def test_subprocess_toggling_logging_level(self):
233233 log_output_verbose = log_output [:m .start ()]
234234 log_output_silent = log_output [m .end ():]
235235
236- self .assertIn ("Finished tracing + transforming <lambda> for pjit " ,
236+ self .assertIn ("Finished tracing <lambda> for jit " ,
237237 log_output_verbose )
238238 self .assertEqual (log_output_silent , "" )
239239
@@ -256,7 +256,7 @@ def test_subprocess_double_logging_absent(self):
256256 # only one tracing line should be printed, if there's more than one
257257 # then logs are printing duplicated
258258 self .assertLen ([line for line in log_lines
259- if "Finished tracing + transforming " in line ], 1 )
259+ if "Finished tracing" in line ], 1 )
260260
261261 @jtu .skip_on_devices ("tpu" )
262262 @unittest .skipIf (platform .system () == "Windows" ,
0 commit comments