@@ -358,9 +358,9 @@ public boolean isReduceRetracingParamExists() {
358358 private static Map <TensorTypeAnalysis , Set <InstanceKey >> tensorContainersCache = Maps .newConcurrentMap ();
359359
360360 /**
361- * Containing {@link IDocument }s that have had import statements added to them during transformation.
361+ * Containing {@link File }s that have had import statements added to them during transformation.
362362 */
363- private static Set <IDocument > documentsWithAddedImport = new HashSet <>();
363+ private static Set <File > filesWithAddedImport = new HashSet <>();
364364
365365 private static final String TF_FUNCTION_FQN = "tensorflow.python.eager.def_function.function" ;
366366
@@ -472,7 +472,7 @@ private static boolean allCreationsWithinClosureInteral2(MethodReference methodR
472472 public static void clearCaches () {
473473 creationsCache .clear ();
474474 tensorContainersCache .clear ();
475- documentsWithAddedImport .clear ();
475+ filesWithAddedImport .clear ();
476476 }
477477
478478 /**
@@ -1959,15 +1959,17 @@ private List<TextEdit> convertToHybrid() throws BadLocationException {
19591959
19601960 if (prefix == null ) {
19611961 // need to add an import if it doesn't already exist.
1962- if (!documentsWithAddedImport .contains (doc )) {
1962+ File file = this .getContainingFile ();
1963+
1964+ if (!filesWithAddedImport .contains (file )) {
19631965 int line = getLineToInsertImport (doc );
19641966 int lineOffset = doc .getLineOffset (line );
19651967
19661968 TextEdit edit = new InsertEdit (lineOffset , "from tensorflow import function\n " );
19671969 MultiTextEdit mte = new MultiTextEdit ();
19681970 mte .addChild (edit );
19691971 ret .add (mte );
1970- documentsWithAddedImport .add (doc );
1972+ filesWithAddedImport .add (file );
19711973 }
19721974
19731975 prefix = "" ; // no prefix needed.
0 commit comments