104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
|
|
|
|
import os
|
|
import zipfile
|
|
import io
|
|
import numpy as np
|
|
from PIL import Image
|
|
from PIL import ImageFile
|
|
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
def is_zip_path(img_or_path):
    """Return True if *img_or_path* is a zip-style path string.

    A zip-style path is recognised by the marker ``'.zip@'``, i.e. an
    archive file name followed by ``@`` and a location inside the archive
    (for example ``/data/train.zip@/cat/001.jpg``).

    Args:
        img_or_path: The value to inspect. The parameter name suggests that
            callers may hand in something other than a path string; any
            non-string value is reported as "not a zip path" instead of
            raising ``TypeError`` from the ``in`` operator.

    Returns:
        bool: True only for strings that contain ``'.zip@'``.
    """
    return isinstance(img_or_path, str) and '.zip@' in img_or_path
|
|
|
|
|
|
class ZipReader(object):
    """Read files and images stored inside zip archives.

    Locations are given as "zip-style" paths of the form
    ``<archive path>@<path inside the archive>``, for example
    ``/data/train.zip@/cat/001.jpg``. Every archive is opened once and the
    handle is kept in ``zip_bank`` for reuse by later calls; this class never
    closes the cached handles.
    """

    # Cache of opened archives keyed by archive path, shared by the whole class.
    zip_bank = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def get_zipfile(path):
        """Return the cached ``zipfile.ZipFile`` for *path*, opening it on first use.

        Args:
            path: File system path of the zip archive.

        Returns:
            zipfile.ZipFile: The archive opened in read mode.
        """
        zip_bank = ZipReader.zip_bank
        if path not in zip_bank:
            zip_bank[path] = zipfile.ZipFile(path, 'r')
        return zip_bank[path]

    @staticmethod
    def split_zip_style_path(path):
        """Split a zip-style path at its first ``'@'``.

        Args:
            path: String of the form ``<archive path>@<inner path>``.

        Returns:
            tuple: ``(zip_path, inner_path)`` where ``inner_path`` has its
            leading and trailing ``'/'`` characters removed.

        Raises:
            ValueError: If *path* contains no ``'@'``.
        """
        # ``find`` (unlike ``index``) returns -1 when the marker is missing,
        # which lets us raise the descriptive message below.
        pos_at = path.find('@')
        if pos_at == -1:
            raise ValueError(
                "character '@' is not found from the given path '%s'" % path)

        zip_path = path[:pos_at]
        folder_path = path[pos_at + 1:].strip('/')
        return zip_path, folder_path

    @staticmethod
    def _relative_to_folder(entry_name, folder_path):
        """Return *entry_name* relative to *folder_path*, or None if it is not inside it."""
        entry_name = entry_name.strip('/')
        if not folder_path:
            # The archive root contains every non-empty entry.
            return entry_name or None
        # Match on a path boundary so that "train" does not capture "train2/...".
        prefix = folder_path + '/'
        if entry_name.startswith(prefix):
            return entry_name[len(prefix):]
        return None

    @staticmethod
    def list_folder(path):
        """List the folders found below a folder of an archive.

        An archive entry is considered a folder when its name has no file
        extension. Entries at any depth below the folder are reported.

        Args:
            path: Zip-style path ``<archive path>@<folder inside archive>``;
                an empty folder part means the archive root.

        Returns:
            list: Folder names relative to the requested folder, in archive
            order.

        Raises:
            ValueError: If *path* contains no ``'@'``.
        """
        zip_path, folder_path = ZipReader.split_zip_style_path(path)

        zfile = ZipReader.get_zipfile(zip_path)
        folder_list = []
        for entry_name in zfile.namelist():
            relative = ZipReader._relative_to_folder(entry_name, folder_path)
            if relative is None:
                continue
            if not os.path.splitext(relative)[-1]:
                folder_list.append(relative)

        return folder_list

    @staticmethod
    def list_files(path, extension=None):
        """List the files found below a folder of an archive, filtered by extension.

        Args:
            path: Zip-style path ``<archive path>@<folder inside archive>``;
                an empty folder part means the archive root.
            extension: A single extension string or an iterable of extension
                strings including the dot (e.g. ``['.jpg', '.png']``).
                Matching is case-insensitive. ``'.*'`` acts as a wildcard for
                any entry that has an extension. Defaults to ``['.*']``.

        Returns:
            list: File names relative to the requested folder, in archive
            order.

        Raises:
            ValueError: If *path* contains no ``'@'``.
        """
        if extension is None:
            extension = ['.*']
        elif isinstance(extension, str):
            # A bare string would otherwise be matched by substring.
            extension = [extension]
        wanted = set(ext.lower() for ext in extension)
        match_any = '.*' in wanted

        zip_path, folder_path = ZipReader.split_zip_style_path(path)

        zfile = ZipReader.get_zipfile(zip_path)
        file_lists = []
        for entry_name in zfile.namelist():
            relative = ZipReader._relative_to_folder(entry_name, folder_path)
            if relative is None:
                continue
            entry_ext = os.path.splitext(relative)[-1].lower()
            # Entries without an extension are treated as folders, so the
            # wildcard only accepts entries that actually have one.
            if entry_ext in wanted or (match_any and entry_ext):
                file_lists.append(relative)

        return file_lists

    @staticmethod
    def read(path):
        """Return the raw bytes of one archive member.

        Args:
            path: Zip-style path ``<archive path>@<member inside archive>``.

        Returns:
            bytes: The content of the member.

        Raises:
            ValueError: If *path* contains no ``'@'``.
        """
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        return zfile.read(path_img)

    @staticmethod
    def imread(path):
        """Open one archive member as a PIL image.

        If ``Image.open`` fails on the member's bytes, an error line is
        printed and a 224x224 three-channel image filled with random values
        is returned instead, so the caller always receives an image.

        Args:
            path: Zip-style path ``<archive path>@<image inside archive>``.

        Returns:
            PIL.Image.Image: The opened image, or the random placeholder.

        Raises:
            ValueError: If *path* contains no ``'@'``.
        """
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        data = zfile.read(path_img)
        try:
            im = Image.open(io.BytesIO(data))
        except Exception:
            # Deliberate best-effort fallback; ``Exception`` (not a bare
            # ``except``) keeps KeyboardInterrupt/SystemExit propagating.
            print("ERROR IMG LOADED: ", path_img)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im
|