@@ -89,11 +89,149 @@ uint32_t crc_update(uint32_t crc, const void * data, size_t data_len)
8989 MAIN
9090 **************************************************************************************/
9191
92+ union HeaderVersion
93+ {
94+ struct __attribute__ ((packed))
95+ {
96+ uint32_t header_version : 6 ;
97+ uint32_t compression : 1 ;
98+ uint32_t signature : 1 ;
99+ uint32_t spare : 4 ;
100+ uint32_t payload_target : 4 ;
101+ uint32_t payload_major : 8 ;
102+ uint32_t payload_minor : 8 ;
103+ uint32_t payload_patch : 8 ;
104+ uint32_t payload_build_num : 24 ;
105+ } field;
106+ uint8_t buf[sizeof (field)];
107+ static_assert (sizeof (buf) == 8 , " Error: sizeof(HEADER.VERSION) != 8" );
108+ };
109+
110+ union OTAHeader
111+ {
112+ struct __attribute__ ((packed))
113+ {
114+ uint32_t len;
115+ uint32_t crc32;
116+ uint32_t magic_number;
117+ HeaderVersion hdr_version;
118+ } header;
119+ uint8_t buf[sizeof (header)];
120+ static_assert (sizeof (buf) == 20 , " Error: sizeof(HEADER) != 20" );
121+ };
122+
92123int Arduino_Portenta_OTA::download (const char * url, bool const is_https, MbedSocketClass * socket)
93124{
94125 return socket->download ((char *)url, UPDATE_FILE_NAME_LZSS, is_https);
95126}
96127
128+ int Arduino_Portenta_OTA::downloadAndDecompress (const char * url, bool const is_https, MbedSocketClass * socket) {
129+ int res=0 ;
130+
131+ FILE* decompressed = fopen (UPDATE_FILE_NAME, " wb" );
132+ OTAHeader ota_header;
133+
134+ LZSSDecoder decoder ([&decompressed](const uint8_t c) {
135+ fwrite (&c, 1 , 1 , decompressed);
136+ });
137+
138+ enum OTA_DOWNLOAD_STATE: uint8_t {
139+ OTA_DOWNLOAD_HEADER=0 ,
140+ OTA_DOWNLOAD_FILE,
141+ OTA_DOWNLOAD_ERR
142+ };
143+
144+ // since mbed::Callback requires a function to not exceed a certain size, we group the following parameters in a struct
145+ struct {
146+ uint32_t crc32 = 0xFFFFFFFF ;
147+ uint32_t header_copied_bytes = 0 ;
148+ OTA_DOWNLOAD_STATE state=OTA_DOWNLOAD_HEADER;
149+ } ota_progress;
150+
151+ int bytes = socket->download (url, is_https, [&decoder, &ota_header, &ota_progress](const char * buffer, uint32_t size) {
152+ for (char * cursor=(char *)buffer; cursor<buffer+size; ) {
153+ switch (ota_progress.state ) {
154+ case OTA_DOWNLOAD_HEADER: {
155+ // read to ota_header.buf
156+ // the header could be split into two arrivals, we must handle that
157+ uint32_t copied = size < sizeof (ota_header.buf ) ? size : sizeof (ota_header.buf );
158+ memcpy (ota_header.buf , buffer, copied);
159+ cursor += copied;
160+ ota_progress.header_copied_bytes += copied;
161+
162+ // when finished go to next state
163+ if (sizeof (ota_header.buf ) == ota_progress.header_copied_bytes ) {
164+ ota_progress.state = OTA_DOWNLOAD_FILE;
165+
166+ ota_progress.crc32 = crc_update (
167+ ota_progress.crc32 ,
168+ &(ota_header.header .magic_number ),
169+ sizeof (ota_header) - offsetof (OTAHeader, header.magic_number )
170+ );
171+
172+ }
173+ break ;
174+ }
175+ case OTA_DOWNLOAD_FILE:
176+ // continue to download the payload, decompressing it and calculate crc
177+ decoder.decompress ((uint8_t *)cursor, size - (cursor-buffer));
178+ ota_progress.crc32 = crc_update (
179+ ota_progress.crc32 ,
180+ cursor,
181+ size - (cursor-buffer)
182+ );
183+
184+ cursor += size - (cursor-buffer);
185+ break ;
186+ default :
187+ ota_progress.state = OTA_DOWNLOAD_ERR;
188+ }
189+ }
190+ });
191+
192+ // if download fails it return a negative error code
193+ if (bytes <= 0 ) {
194+ res = bytes;
195+ goto exit;
196+ }
197+
198+ // if state is download finished and completed correctly the state should be OTA_DOWNLOAD_FILE
199+ if (ota_progress.state != OTA_DOWNLOAD_FILE) {
200+ res = static_cast <int >(Error::OtaDownload);
201+ goto exit;
202+ }
203+
204+ if (ota_header.header .len == (bytes-sizeof (ota_header.buf ))) {
205+ res = static_cast <int >(Error::OtaHeaderLength);
206+ goto exit;
207+ }
208+
209+ // verify magic number: it may be done in the download function and stop the download immediately
210+ if (ota_header.header .magic_number != ARDUINO_PORTENTA_OTA_MAGIC) {
211+ res = static_cast <int >(Error::OtaHeaterMagicNumber);
212+ goto exit;
213+ }
214+
215+ // finalize CRC and verify it
216+ ota_progress.crc32 ^= 0xFFFFFFFF ;
217+ if (ota_header.header .crc32 != ota_progress.crc32 ) {
218+ res = static_cast <int >(Error::OtaHeaderCrc);
219+ goto exit;
220+ }
221+
222+ res = ftell (decompressed);
223+
224+ exit:
225+ fclose (decompressed);
226+
227+ if (res < 0 ) {
228+ remove (UPDATE_FILE_NAME);
229+ }
230+
231+ return res;
232+ }
233+
234+
97235int Arduino_Portenta_OTA::decompress ()
98236{
99237 struct stat stat_buf;
@@ -103,36 +241,7 @@ int Arduino_Portenta_OTA::decompress()
103241 /* For UPDATE.BIN.LZSS - LZSS compressed binary files. */
104242 FILE* update_file = fopen (UPDATE_FILE_NAME_LZSS, " rb" );
105243
106- union HeaderVersion
107- {
108- struct __attribute__ ((packed))
109- {
110- uint32_t header_version : 6 ;
111- uint32_t compression : 1 ;
112- uint32_t signature : 1 ;
113- uint32_t spare : 4 ;
114- uint32_t payload_target : 4 ;
115- uint32_t payload_major : 8 ;
116- uint32_t payload_minor : 8 ;
117- uint32_t payload_patch : 8 ;
118- uint32_t payload_build_num : 24 ;
119- } field;
120- uint8_t buf[sizeof (field)];
121- static_assert (sizeof (buf) == 8 , " Error: sizeof(HEADER.VERSION) != 8" );
122- };
123-
124- union
125- {
126- struct __attribute__ ((packed))
127- {
128- uint32_t len;
129- uint32_t crc32;
130- uint32_t magic_number;
131- HeaderVersion hdr_version;
132- } header;
133- uint8_t buf[sizeof (header)];
134- static_assert (sizeof (buf) == 20 , " Error: sizeof(HEADER) != 20" );
135- } ota_header;
244+ OTAHeader ota_header;
136245 uint32_t crc32, bytes_read;
137246 uint8_t crc_buf[128 ];
138247
0 commit comments