1717package tpu ;
1818
1919import static com .google .common .truth .Truth .assertThat ;
20+ import static org .junit .Assert .assertEquals ;
2021import static org .mockito .Mockito .any ;
2122import static org .mockito .Mockito .mock ;
2223import static org .mockito .Mockito .mockStatic ;
2526import static org .mockito .Mockito .when ;
2627
2728import com .google .api .gax .longrunning .OperationFuture ;
29+ import com .google .cloud .tpu .v2 .CreateNodeRequest ;
2830import com .google .cloud .tpu .v2 .DeleteNodeRequest ;
2931import com .google .cloud .tpu .v2 .GetNodeRequest ;
32+ import com .google .cloud .tpu .v2 .ListNodesRequest ;
3033import com .google .cloud .tpu .v2 .Node ;
3134import com .google .cloud .tpu .v2 .TpuClient ;
3235import com .google .cloud .tpu .v2 .TpuSettings ;
3336import java .io .ByteArrayOutputStream ;
3437import java .io .IOException ;
3538import java .io .PrintStream ;
39+ import java .util .Arrays ;
40+ import java .util .List ;
3641import java .util .concurrent .ExecutionException ;
3742import org .junit .jupiter .api .BeforeAll ;
3843import org .junit .jupiter .api .Test ;
@@ -47,6 +52,8 @@ public class TpuVmIT {
4752 private static final String PROJECT_ID = "project-id" ;
4853 private static final String ZONE = "asia-east1-c" ;
4954 private static final String NODE_NAME = "test-tpu" ;
55+ private static final String TPU_TYPE = "v2-8" ;
56+ private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1" ;
5057 private static ByteArrayOutputStream bout ;
5158
5259 @ BeforeAll
@@ -55,21 +62,45 @@ public static void setUp() {
5562 System .setOut (new PrintStream (bout ));
5663 }
5764
65+ @ Test
66+ public void testCreateTpuVm () throws Exception {
67+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
68+ Node mockNode = mock (Node .class );
69+ TpuClient mockTpuClient = mock (TpuClient .class );
70+ OperationFuture mockFuture = mock (OperationFuture .class );
71+
72+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
73+ .thenReturn (mockTpuClient );
74+ when (mockTpuClient .createNodeAsync (any (CreateNodeRequest .class )))
75+ .thenReturn (mockFuture );
76+ when (mockFuture .get ()).thenReturn (mockNode );
77+
78+ Node returnedNode = CreateTpuVm .createTpuVm (
79+ PROJECT_ID , ZONE , NODE_NAME ,
80+ TPU_TYPE , TPU_SOFTWARE_VERSION );
81+
82+ verify (mockTpuClient , times (1 ))
83+ .createNodeAsync (any (CreateNodeRequest .class ));
84+ verify (mockFuture , times (1 )).get ();
85+ assertEquals (returnedNode , mockNode );
86+ }
87+ }
88+
5889 @ Test
5990 public void testGetTpuVm () throws IOException {
6091 try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
6192 Node mockNode = mock (Node .class );
6293 TpuClient mockClient = mock (TpuClient .class );
63- GetTpuVm mockGetTpuVm = mock (GetTpuVm .class );
6494
6595 mockedTpuClient .when (TpuClient ::create ).thenReturn (mockClient );
6696 when (mockClient .getNode (any (GetNodeRequest .class ))).thenReturn (mockNode );
6797
6898 Node returnedNode = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
6999
70- verify (mockGetTpuVm , times (1 ))
71- .getTpuVm ( PROJECT_ID , ZONE , NODE_NAME );
100+ verify (mockClient , times (1 ))
101+ .getNode ( any ( GetNodeRequest . class ) );
72102 assertThat (returnedNode ).isEqualTo (mockNode );
103+ verify (mockClient , times (1 )).close ();
73104 }
74105 }
75106
@@ -91,4 +122,27 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte
91122 verify (mockTpuClient , times (1 )).deleteNodeAsync (any (DeleteNodeRequest .class ));
92123 }
93124 }
125+
126+ @ Test
127+ public void testListTpuVm () throws IOException {
128+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
129+ Node mockNode1 = mock (Node .class );
130+ Node mockNode2 = mock (Node .class );
131+ List <Node > mockListNodes = Arrays .asList (mockNode1 , mockNode2 );
132+
133+ TpuClient mockTpuClient = mock (TpuClient .class );
134+ mockedTpuClient .when (TpuClient ::create ).thenReturn (mockTpuClient );
135+ TpuClient .ListNodesPagedResponse mockListNodesResponse =
136+ mock (TpuClient .ListNodesPagedResponse .class );
137+ when (mockTpuClient .listNodes (any (ListNodesRequest .class ))).thenReturn (mockListNodesResponse );
138+ TpuClient .ListNodesPage mockListNodesPage = mock (TpuClient .ListNodesPage .class );
139+ when (mockListNodesResponse .getPage ()).thenReturn (mockListNodesPage );
140+ when (mockListNodesPage .getValues ()).thenReturn (mockListNodes );
141+
142+ TpuClient .ListNodesPage returnedListNodes = ListTpuVms .listTpuVms (PROJECT_ID , ZONE );
143+
144+ assertThat (returnedListNodes .getValues ()).isEqualTo (mockListNodes );
145+ verify (mockTpuClient , times (1 )).listNodes (any (ListNodesRequest .class ));
146+ }
147+ }
94148}
0 commit comments