diff --git a/tests/test_zipstream.py b/tests/test_zipstream.py index 56e1ef2..82bc2a6 100644 --- a/tests/test_zipstream.py +++ b/tests/test_zipstream.py @@ -6,6 +6,18 @@ import unittest import zipstream import zipfile +import socket +import functools +from nose.plugins.skip import SkipTest + + +def skipIfNotPosix(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + if os.name == "posix": + return f(*args, **kwargs) + raise SkipTest("requires POSIX") + return wrapper class ZipInfoTestCase(unittest.TestCase): @@ -54,6 +66,78 @@ def test_write_file(self): os.remove(f.name) + def test_write_fp(self): + z = zipstream.ZipFile(mode='w') + for fileobj in self.fileobjs: + z.write_stream(fileobj) + + f = tempfile.NamedTemporaryFile(suffix='zip', delete=False) + for chunk in z: + f.write(chunk) + f.close() + + z2 = zipfile.ZipFile(f.name, 'r') + z2.testzip() + + os.remove(f.name) + + def test_write_fp_with_stat(self): + z = zipstream.ZipFile(mode='w') + # test mtime + z.write_stream(self.fileobjs[0], arcname="mtime", + mtime=315532900) + + # test with a specific file size + fdata = tempfile.NamedTemporaryFile(suffix='.data') + fdata.write(" "*15) + fdata.seek(0) + z.write_stream(fdata, arcname="size", size=15) + + # test isdir + z.write_stream(None, arcname="isdir", isdir=True) + + f = tempfile.NamedTemporaryFile(suffix='zip', delete=False) + for chunk in z: + f.write(chunk) + f.close() + fdata.close() + + z2 = zipfile.ZipFile(f.name, 'r') + z2.testzip() + self.assertEqual( + [zi.filename for zi in z2.filelist], + ['mtime', 'size', 'isdir/']) + self.assertEqual(z2.filelist[0].date_time[5], 40) + self.assertEqual(z2.filelist[1].file_size, 15) + + os.remove(f.name) + + @skipIfNotPosix + def test_write_socket(self): + z = zipstream.ZipFile(mode='w') + s, c = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + try: + txt = "FILE CONTENTS" + s.send(txt.encode("ascii")) + try: + inf = c.makefile(mode='rb') + except TypeError: + inf = c.makefile() + z.write_stream(inf) + s.close() + + f = tempfile.NamedTemporaryFile(suffix='zip', delete=False) + for chunk in z: + f.write(chunk) + f.close() + + z2 = zipfile.ZipFile(f.name, 'r') + z2.testzip() + + os.remove(f.name) + finally: + c.close() + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/zipstream/__init__.py b/zipstream/__init__.py index 811e2cf..1361a47 100644 --- a/zipstream/__init__.py +++ b/zipstream/__init__.py @@ -16,7 +16,7 @@ import zipfile from .compat import ( - str, bytes, + str, bytes, basestring, ZIP64_VERSION, ZIP_BZIP2, BZIP2_VERSION, ZIP_LZMA, LZMA_VERSION) @@ -32,6 +32,19 @@ stringDataDescriptor = b'PK\x07\x08' # magic number for data descriptor +def _stream_stat(mtime, isdir, size): + st = [0]*10 + st[0] = stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH # mode + st[7] = st[8] = st[9] = 315532800 # times + if isdir is True: + st[0] |= stat.S_IFDIR + if size is not None: + st[6] = size + if mtime is not None: + st[8] = mtime + return os.stat_result(st) + + def _get_compressor(compress_type): if compress_type == ZIP_DEFLATED: return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, @@ -208,16 +221,32 @@ def write(self, filename, arcname=None, compress_type=None): ((filename, ), {'arcname': arcname, 'compress_type': compress_type}), ) - def __write(self, filename, arcname=None, compress_type=None): + def write_stream(self, fp, arcname=None, compress_type=None, + mtime=None, isdir=None, size=None): + self.paths_to_write.append( + ((fp, ), {'arcname': arcname, 'compress_type': compress_type, + 'st': _stream_stat(mtime, isdir, size)}), + ) + + + def __write(self, fp, arcname=None, compress_type=None, st=None): """Put the bytes from filename into the archive under the name arcname.""" if not self.fp: raise RuntimeError( "Attempt to write to ZIP archive that was already closed") - st = os.stat(filename) + if isinstance(fp, basestring): + filename, fp = (fp, None) + st = st or os.stat(filename) + else: + filename = '' + st = st or os.stat(0) + isdir = stat.S_ISDIR(st.st_mode) mtime = time.localtime(st.st_mtime) + if (mtime.tm_year < 1980): + mtime = time.localtime() date_time = mtime[0:6] # Create ZipInfo instance to store file information if arcname is None: @@ -255,7 +284,8 @@ def __write(self, filename, arcname=None, compress_type=None): return cmpr = _get_compressor(zinfo.compress_type) - with open(filename, 'rb') as fp: + fp = fp or open(filename, 'rb') + with fp: # Must overwrite CRC and sizes with correct data later zinfo.CRC = CRC = 0 zinfo.compress_size = compress_size = 0 @@ -265,7 +295,12 @@ def __write(self, filename, arcname=None, compress_type=None): yield self.fp.write(zinfo.FileHeader(zip64)) file_size = 0 while 1: - buf = fp.read(1024 * 8) + sz = 1024 * 8 + if zinfo.file_size > 0: # known size, read only that much + if zinfo.file_size == file_size: + break + sz = min(zinfo.file_size - file_size, sz) + buf = fp.read(sz) if not buf: break file_size = file_size + len(buf) @@ -282,7 +317,10 @@ def __write(self, filename, arcname=None, compress_type=None): else: zinfo.compress_size = file_size zinfo.CRC = CRC - zinfo.file_size = file_size + if zinfo.file_size > 0 and zinfo.file_size != file_size: + raise RuntimeError('File size changed during compressing') + else: + zinfo.file_size = file_size if not zip64 and self._allowZip64: if file_size > ZIP64_LIMIT: raise RuntimeError('File size has increased during compressing')