|
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 |
|
7 | | -from hls4ml.backends import VitisBackend, VivadoBackend |
| 7 | +from hls4ml.backends import VitisBackend, VivadoBackend, VitisAcceleratorConfig |
8 | 8 | from hls4ml.model.flow import get_flow, register_flow |
9 | | - |
| 9 | +import ctypes |
10 | 10 |
|
11 | 11 | class VitisAcceleratorBackend(VitisBackend): |
12 | 12 | def __init__(self): |
@@ -114,24 +114,67 @@ def dat_to_numpy(self, model): |
114 | 114 | y = np.loadtxt(output_file, dtype=float).reshape(-1, expected_shape) |
115 | 115 | return y |
116 | 116 |
|
def hardware_predict(self, model, x, target="hw", debug=False, profilingRepeat=-1, method="file"):
    """Run a prediction through the accelerator host code and return the outputs.

    Two communication methods are supported:

    * ``"file"``: the input is written with ``self.numpy_to_dat``, ``make run``
      is executed inside the model's output directory, and the result is read
      back with ``self.dat_to_numpy``.
    * ``"lib"``: the input is handed in memory to the ``predict`` function of
      ``lib_host.so``, loaded from the model's output directory.

    Args:
        model: Model whose ``config.get_output_dir()`` locates the generated
            project.
        x: Input samples, first axis indexing the samples.
        target: Value of the ``TARGET`` variable passed to ``make run``. Only
            used (and validated) by the ``"file"`` method.
        debug: If true, ``DEBUG=1`` is passed to ``make run``. Only used by the
            ``"file"`` method.
        profilingRepeat: If a positive int, passed to ``make run`` as
            ``PROFILING_DATA_REPEAT_COUNT``. Only used by the ``"file"`` method.
        method: ``"file"`` or ``"lib"``.

    Returns:
        For ``"file"``, whatever ``self.dat_to_numpy(model)`` returns. For
        ``"lib"``, a float64 array of shape ``(num_samples, output_size)``.

    Raises:
        ValueError: If ``method`` is neither ``"file"`` nor ``"lib"``.
    """
    if method == "file":
        command = ""
        if debug:
            command += "DEBUG=1 "
        if isinstance(profilingRepeat, int) and profilingRepeat > 0:
            # The count is an int; it must be converted before concatenation.
            command += "PROFILING_DATA_REPEAT_COUNT=" + str(profilingRepeat) + " "
        self._validate_target(target)

        self.numpy_to_dat(model, x)

        command += "TARGET=" + target + " make run"
        currdir = os.getcwd()
        os.chdir(model.config.get_output_dir())
        try:
            os.system(command)
        finally:
            # Always restore the caller's working directory.
            os.chdir(currdir)

        return self.dat_to_numpy(model)

    if method == "lib":
        # The C side receives a double*, so force float64 and a contiguous layout.
        x_contig = np.ascontiguousarray(x, dtype=np.float64)

        config = VitisAcceleratorConfig(model.config)
        batchsize = config.get_batchsize()
        sample_count = x_contig.shape[0]
        num_batches = int(np.ceil(sample_count / batchsize))
        sample_output_size = model.get_output_variables()[0].size()

        # The output buffer is sized for whole batches; the padding rows are
        # dropped again before returning.
        predictions = np.zeros(num_batches * batchsize * sample_output_size, dtype=np.float64)
        predictions_ptr = predictions.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

        # Keep the flat array bound to its own name so the buffer behind the
        # pointer stays referenced for the duration of the call.
        x_flat = x_contig.ravel()
        x_ptr = x_flat.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

        # Run from the project directory, as the "file" method does, and load
        # the library relative to it.
        currdir = os.getcwd()
        os.chdir(model.config.get_output_dir())
        try:
            lib = ctypes.cdll.LoadLibrary('./lib_host.so')
            lib.predict.argtypes = [
                ctypes.POINTER(ctypes.c_double),
                ctypes.c_size_t,
                ctypes.POINTER(ctypes.c_double),
                ctypes.c_size_t,
            ]
            lib.predict(x_ptr, x_flat.shape[0], predictions_ptr, predictions.shape[0])
        finally:
            # Always restore the caller's working directory.
            os.chdir(currdir)

        return predictions.reshape(-1, sample_output_size)[:sample_count, :]

    raise ValueError(f"Unsupported method {method} for hardware prediction")
135 | 178 |
|
136 | 179 | def _register_flows(self): |
137 | 180 | validation_passes = [ |
|
0 commit comments