|
63 | 63 |
|
64 | 64 | timer = dpctl.SyclTimer(time_scale=1e3) |
65 | 65 |
|
66 | | -iters = [] |
67 | | -for i in range(6): |
68 | | - with timer(api_dev.sycl_queue): |
69 | | - x, conv_in = solve.cg_solve(A, b) |
70 | | - |
71 | | - print(i, "(host_dt, device_dt)=", timer.dt) |
72 | | - iters.append(conv_in) |
73 | | - assert x.usm_type == A.usm_type |
74 | | - assert x.usm_type == b.usm_type |
75 | | - assert x.sycl_queue == A.sycl_queue |
76 | | - assert x.sycl_queue == b.sycl_queue |
77 | | - |
78 | | -print("Converged in: ", iters) |
79 | | - |
80 | | -hev, ev = sycl_gemm.gemv(q, A, x, r) |
81 | | -hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) |
82 | | -rs = sycl_gemm.norm_squared_blocking(q, delta) |
83 | | -dpctl.SyclEvent.wait_for([hev, hev2]) |
84 | | -print(f"Python solution residual norm squared: {rs}") |
| 66 | + |
| 67 | +def time_python_solver(num_iters=6): |
| 68 | + """ |
| 69 | + Time solver implemented in Python with use of asynchronous |
| 70 | + SYCL kernel submission. |
| 71 | + """ |
| 72 | + global x |
| 73 | + iters = [] |
| 74 | + for i in range(num_iters): |
| 75 | + with timer(api_dev.sycl_queue): |
| 76 | + x, conv_in = solve.cg_solve(A, b) |
| 77 | + |
| 78 | + print(i, "(host_dt, device_dt)=", timer.dt) |
| 79 | + iters.append(conv_in) |
| 80 | + assert x.usm_type == A.usm_type |
| 81 | + assert x.usm_type == b.usm_type |
| 82 | + assert x.sycl_queue == A.sycl_queue |
| 83 | + assert x.sycl_queue == b.sycl_queue |
| 84 | + |
| 85 | + return iters |
| 86 | + |
| 87 | + |
| 88 | +def time_cpp_solver(num_iters=6): |
| 89 | + """ |
| 90 | + Time solver implemented in C++ but callable from Python. |
| 91 | + C++ implementation uses the same algorithm and submits same |
| 92 | + kernels asynchronously, but bypasses Python binding overhead |
| 93 | + incurred when algorithm is driver from Python. |
| 94 | + """ |
| 95 | + global x_cpp |
| 96 | + x_cpp = dpt.empty_like(b) |
| 97 | + iters = [] |
| 98 | + for i in range(num_iters): |
| 99 | + with timer(api_dev.sycl_queue): |
| 100 | + conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp) |
| 101 | + |
| 102 | + print(i, "(host_dt, device_dt)=", timer.dt) |
| 103 | + iters.append(conv_in) |
| 104 | + |
| 105 | + return iters |
| 106 | + |
| 107 | + |
| 108 | +def compute_residual(x): |
| 109 | + """ |
| 110 | + Computes quality of the solution, `norm_squared(A@x - b)`. |
| 111 | + """ |
| 112 | + assert isinstance(x, dpt.usm_ndarray) |
| 113 | + q = A.sycl_queue |
| 114 | + hev, ev = sycl_gemm.gemv(q, A, x, r) |
| 115 | + hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) |
| 116 | + rs = sycl_gemm.norm_squared_blocking(q, delta) |
| 117 | + dpctl.SyclEvent.wait_for([hev, hev2]) |
| 118 | + return rs |
| 119 | + |
| 120 | + |
| 121 | +print("Converged in: ", time_python_solver()) |
| 122 | +print(f"Python solution residual norm squared: {compute_residual(x)}") |
85 | 123 |
|
86 | 124 | assert q == api_dev.sycl_queue |
87 | 125 | print("") |
88 | 126 |
|
89 | | -x_cpp = dpt.empty_like(b) |
90 | | -iters = [] |
91 | | -for i in range(6): |
92 | | - with timer(api_dev.sycl_queue): |
93 | | - conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp) |
94 | | - |
95 | | - print(i, "(host_dt, device_dt)=", timer.dt) |
96 | | - iters.append(conv_in) |
97 | | - |
98 | | -print("Converged in: ", iters) |
99 | | -hev, ev = sycl_gemm.gemv(q, A, x_cpp, r) |
100 | | -hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) |
101 | | -rs = sycl_gemm.norm_squared_blocking(q, delta) |
102 | | -dpctl.SyclEvent.wait_for([hev, hev2]) |
103 | | -print(f"cpp_cg_solve solution residual norm squared: {rs}") |
| 127 | +print("Converged in: ", time_cpp_solver()) |
| 128 | +print(f"cpp_cg_solve solution residual norm squared: {compute_residual(x_cpp)}") |
0 commit comments