Skip to content

Conversation

@wiredfool
Copy link
Member

A recent run of oss-fuzz here ended up entirely finding dds parser timeout issues. There were 40 or so of them.

This PR has two significant changes:

  1. Hoist a divide + comparison out of the hot loop. Using the line profiler, this nets us about 10% on the inner loop. Not great, but not bad though.

  2. Only go pixel by pixel for the data we have in the file, and bulk fill with 0 once we hit the end. This gets us the 500x speed increase we need, but this only really helps semi-invalid images where the file data is much smaller than the image data that we're generating. We probably won't hit this on real images, but fuzzed ones hit it constantly. On the downside, we give up the first 10% gain in the hot loop checking to see if we read any data.

initial lineperf:

(vpy313) erics@wf:~/test$ python -m line_profiler -rtmz profile_output.lprof
Timer unit: 1e-06 s

Total time: 194.62 s
File: /home/erics/vpy313/lib/python3.13/site-packages/PIL/DdsImagePlugin.py
Function: DdsRgbDecoder.decode at line 492

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   492                                               @line_profiler.profile
   493                                               def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
   494         1          1.1      1.1      0.0          assert self.fd is not None
   495         1          0.9      0.9      0.0          bitcount, masks = self.args
   496                                           
   497                                                   # Some masks will be padded with zeros, e.g. R 0b11 G 0b1100
   498                                                   # Calculate how many zeros each mask is padded with
   499         1          0.6      0.6      0.0          mask_offsets = []
   500                                                   # And the maximum value of each channel without the padding
   501         1          0.7      0.7      0.0          mask_totals = []
   502         4          3.4      0.8      0.0          for mask in masks:
   503         3          1.5      0.5      0.0              offset = 0
   504         3          1.7      0.6      0.0              if mask != 0:
   505                                                           while mask >> (offset + 1) << (offset + 1) == mask:
   506                                                               offset += 1
   507         3          2.3      0.8      0.0              mask_offsets.append(offset)
   508         3          2.0      0.7      0.0              mask_totals.append(mask >> offset)
   509                                           
   510         1          1.5      1.5      0.0          data = bytearray()
   511         1          0.7      0.7      0.0          bytecount = bitcount // 8
   512         1          2.1      2.1      0.0          dest_length = self.state.xsize * self.state.ysize * len(masks)
   513  16777218    8494592.6      0.5      4.4          while len(data) < dest_length:
   514  16777217    9764454.4      0.6      5.0              value = int.from_bytes(self.fd.read(bytecount), "little")
   515  67108868   32003938.7      0.5     16.4              for i, mask in enumerate(masks):
   516  50331651   22649228.0      0.4     11.6                  masked_value = value & mask
   517                                                           # Remove the zero padding, and scale it to 8 bits
   518 100663302   79428755.0      0.8     40.8                  data += o8(
   519                                                               int(((masked_value >> mask_offsets[i]) / mask_totals[i]) * 255)
   520  50331651   21372037.4      0.4     11.0                      if mask_totals[i]
   521  50331651   20668863.3      0.4     10.6                      else 0
   522                                                           )
   523         1     238290.4 238290.4      0.1          self.set_as_raw(data)
   524         1          0.8      0.8      0.0          return -1, 0

Premultiplied:

(vpy313) erics@wf:~/test$ python -m line_profiler -rtmz profile_output.lprof
Timer unit: 1e-06 s

Total time: 175.299 s
File: /home/erics/vpy313/lib/python3.13/site-packages/PIL/DdsImagePlugin.py
Function: DdsRgbDecoder.decode at line 492

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   492                                               @line_profiler.profile
   493                                               def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
   494         1          1.4      1.4      0.0          assert self.fd is not None
   495         1          0.7      0.7      0.0          bitcount, masks = self.args
   496                                           
   497                                                   # Some masks will be padded with zeros, e.g. R 0b11 G 0b1100
   498                                                   # Calculate how many zeros each mask is padded with
   499         1          0.5      0.5      0.0          mask_offsets = []
   500                                                   # And the maximum value of each channel without the padding
   501         1          0.6      0.6      0.0          mask_totals = []
   502         4          2.7      0.7      0.0          for mask in masks:
   503         3          1.4      0.5      0.0              offset = 0
   504         3          1.7      0.6      0.0              if mask != 0:
   505                                                           while mask >> (offset + 1) << (offset + 1) == mask:
   506                                                               offset += 1
   507         3          2.2      0.7      0.0              mask_offsets.append(offset)
   508         3          1.8      0.6      0.0              mask_total = mask >> offset
   509         3          1.6      0.5      0.0              if not mask_total:
   510         3          1.4      0.5      0.0                  mask_totals.append(0)
   511                                                       else:
   512                                                           mask_totals.append(255/mask_total)
   513                                           
   514         1          1.4      1.4      0.0          data = bytearray()
   515         1          0.8      0.8      0.0          bytecount = bitcount // 8
   516         1          1.5      1.5      0.0          dest_length = self.state.xsize * self.state.ysize * len(masks)
   517                                           #        consolidate_mask = zip(masks, mask_offsets, mask_totals)
   518  16777218    8314582.1      0.5      4.7          while len(data) < dest_length:
   519  16777217    9686548.8      0.6      5.5              value = int.from_bytes(self.fd.read(bytecount), "little")
   520  67108868   31127505.5      0.5     17.8              for i, mask in enumerate(masks):
   521  50331651   21402455.3      0.4     12.2                  masked_value = value & mask
   522                                                           # Remove the zero padding, and scale it to 8 bits
   523 100663302   81338723.3      0.8     46.4                  data += o8(
   524  50331651   23194260.0      0.5     13.2                      int((masked_value >> mask_offsets[i]) * mask_totals[i])
   525                                                           )
   526         1     234889.2 234889.2      0.1          self.set_as_raw(data)
   527         1          1.1      1.1      0.0          return -1, 0

And the final config:

Timer unit: 1e-06 s

Total time: 0.299102 s
File: /home/erics/vpy313/lib/python3.13/site-packages/PIL/DdsImagePlugin.py
Function: DdsRgbDecoder.decode at line 493

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   493                                               @line_profiler.profile
   494                                               def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
   495         1          1.9      1.9      0.0          assert self.fd is not None
   496         1          0.9      0.9      0.0          bitcount, masks = self.args
   497                                           
   498                                                   # Some masks will be padded with zeros, e.g. R 0b11 G 0b1100
   499                                                   # Calculate how many zeros each mask is padded with
   500         1          0.5      0.5      0.0          mask_offsets = []
   501                                                   # And the maximum value of each channel without the padding
   502         1          0.7      0.7      0.0          mask_totals = []
   503         4          9.3      2.3      0.0          for mask in masks:
   504         3          2.1      0.7      0.0              offset = 0
   505         3          1.7      0.6      0.0              if mask != 0:
   506                                                           while mask >> (offset + 1) << (offset + 1) == mask:
   507                                                               offset += 1
   508         3          2.3      0.8      0.0              mask_offsets.append(offset)
   509         3          2.0      0.7      0.0              mask_total = mask >> offset
   510         3          1.5      0.5      0.0              if not mask_total:
   511         3          1.5      0.5      0.0                  mask_totals.append(0)
   512                                                       else:
   513                                                           mask_totals.append(255/mask_total)
   514                                           
   515         1          1.4      1.4      0.0          data = bytearray()
   516         1          0.8      0.8      0.0          bytecount = bitcount // 8
   517         1          1.5      1.5      0.0          dest_length = self.state.xsize * self.state.ysize * len(masks)
   518                                                   # consume the data
   519         1          0.5      0.5      0.0          has_more = True
   520         3          2.3      0.8      0.0          while len(data) < dest_length and has_more:
   521         2          2.6      1.3      0.0              chunk = self.fd.read(bytecount)
   522                                                       # work around BufferedIO not being seekable
   523         2          1.5      0.7      0.0              has_more = len(chunk) > 0
   524         2          2.8      1.4      0.0              value = int.from_bytes(chunk, "little")
   525         8          5.2      0.6      0.0              for i, mask, in enumerate(masks):
   526         6          3.1      0.5      0.0                  masked_value = value & mask
   527                                                           # Remove the zero padding, and scale it to 8 bits
   528        12         16.2      1.3      0.0                  data += o8(
   529         6          4.1      0.7      0.0                      int((masked_value >> mask_offsets[i]) * mask_totals[i])
   530                                                           )
   531                                           
   532                                                   # extra padding pixels -- always all 0
   533         1          0.6      0.6      0.0          if len(data) < dest_length:
   534         1          0.7      0.7      0.0              pixel = bytearray()
   535         1          1.4      1.4      0.0              pixel += o8(0)
   536         1          0.7      0.7      0.0              ct_bytes = dest_length - len(data)
   537         1      57766.5  57766.5     19.3              data += pixel * ct_bytes
   538                                           
   539                                           
   540         1     241264.5 241264.5     80.7          self.set_as_raw(data)
   541         1          1.2      1.2      0.0          return -1, 0

"Tests/images/timeout-c60a3d7314213624607bfb3e38d551a8b24a7435.dds",
],
)
def test_timeout(test_file) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_timeout(test_file) -> None:
def test_timeout(test_file: str) -> None:

data += o8(int((masked_value >> mask_offsets[i]) * mask_totals[i]))

# extra padding pixels -- always all 0
if len(data) < dest_length:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the file stopped before we were able to read all the bytes we needed, why not just let the data be shorter than expected, leading to a ValueError: not enough image data be raised? That error sounds like it is accurately describing the situation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question - I couldn't tell if it was intentional that it filled off the end. Given that it's plain black, it's at least a reasonable fill, but I didn't see a spec. I'd be happy enough erroring out as well. Main issue here is that I want to catch things other than dds timeouts in oss-fuzz.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created #9405

while len(data) < dest_length:
value = int.from_bytes(self.fd.read(bytecount), "little")
# consume the data
has_more = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
has_more = True
has_more = bytecount != 0

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bytecount is number of bytes per pixel, or the number of masks. It's not the data size.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but if bytecount is zero, then self.fd.read(bytecount) is always going to be empty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But bytecount should never be zero. It’s an image level compression property, how many bytes do we consume per pixel.

To me, this looks like a misunderstanding, so can you explain why it’s better to initialize with this rather than the simpler expression for true?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But bytecount should never be zero

...awkwardly, it does seem to be possible. I mean, it doesn't seem helpful value, and I do think it should lead to an error.

# pixel format
pfsize, pfflags, fourcc, bitcount = struct.unpack("<4I", header[68:84])

bytecount = bitcount // 8

Here's the documentation - https://learn.microsoft.com/en-us/windows/win32/direct3ddds/dds-pixelformat

dwRGBBitCount

Type: DWORD

Number of bits in an RGB (possibly including alpha) format. Valid when dwFlags includes DDPF_RGB, DDPF_LUMINANCE, or DDPF_YUV.

To demonstrate, I've added radarhere@c90d124 onto a fork of this branch - https://github.com/radarhere/Pillow/actions/runs/21312659863/job/61351156707#step:11:58

bitcount is set to 0 in _open
bitcount is 0 which means that bytecount is 0 within DdsRgbDecoder

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants