# Copyright 2019 MilaGraph. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Zhaocheng Zhu
"""
Dataset module of GraphVite
Graph
- :class:`BlogCatalog`
- :class:`Youtube`
- :class:`Flickr`
- :class:`Hyperlink2012`
- :class:`Friendster`
- :class:`Wikipedia`
Knowledge Graph
- :class:`FB15k`
- :class:`FB15k237`
- :class:`WN18`
- :class:`WN18RR`
- :class:`Freebase`
Visualization
- :class:`MNIST`
- :class:`CIFAR10`
- :class:`ImageNet`
"""
from __future__ import absolute_import
import os
import glob
import shutil
import logging
import gzip, zipfile, tarfile
import multiprocessing
from collections import defaultdict
import numpy as np
from . import cfg
logger = logging.getLogger(__name__)
class Dataset(object):
    """
    Graph dataset.

    Parameters:
        name (str): name of dataset
        urls (dict, optional): url(s) for each split,
            can be either str or list of str
        members (dict, optional): zip member(s) for each split,
            leave empty for default

    Datasets contain several splits, such as train, valid and test.
    For each split, there are one or more URLs, specifying the file to download.
    You may also specify the zip member to extract.
    When a split is accessed, it will be automatically downloaded and decompressed
    if it is not present.

    You can assign a preprocess for each split, by defining a function with name [split]_preprocess::

        class MyDataset(Dataset):
            def __init__(self):
                super(MyDataset, self).__init__(
                    "my_dataset",
                    urls={
                        "train": "url/to/train/split",
                        "test": "url/to/test/split"
                    }
                )

            def train_preprocess(self, input_file, output_file):
                with open(input_file, "r") as fin, open(output_file, "w") as fout:
                    fout.write(fin.read())

        f = open(MyDataset().train)

    The preprocess receives the extracted file(s) of the split followed by the
    target file name. If it returns a non-trivial value, that value is returned
    when the split is accessed (this class does not cache it); otherwise the
    target file name is returned.
    By convention, only splits ending with ``_data`` have non-trivial return value.

    See also:
        Pre-defined preprocess functions
        :func:`csv2txt`,
        :func:`top_k_label`,
        :func:`induced_graph`,
        :func:`link_prediction_split`,
        :func:`image_feature_data`
    """

    def __init__(self, name, urls=None, members=None):
        self.name = name
        # Copy, so normalizing str -> [str] below never mutates the caller's dicts.
        self.urls = dict(urls or {})
        self.members = dict(members or {})
        for key in self.urls:
            if isinstance(self.urls[key], str):
                self.urls[key] = [self.urls[key]]
            if key not in self.members:
                self.members[key] = [None] * len(self.urls[key])
            elif isinstance(self.members[key], str):
                self.members[key] = [self.members[key]]
            if len(self.urls[key]) != len(self.members[key]):
                raise ValueError("Number of members is inconsistent with number of urls in `%s`" % key)
        self.path = os.path.join(cfg.dataset_path, self.name)

    def relpath(self, path):
        """Return ``path`` relative to the dataset directory (used for log messages)."""
        return os.path.relpath(path, self.path)

    def download(self, url):
        """
        Download ``url`` into the dataset directory, unless it is already there.

        The data is written to ``<file>.part`` first and renamed only after the
        transfer finishes. An interrupted download therefore never looks like a
        complete file on the next access.

        Parameters:
            url (str): url to download

        Returns:
            str: path of the downloaded file
        """
        from six.moves.urllib.request import urlretrieve

        save_file = os.path.join(self.path, os.path.basename(url))
        # local_files() also creates the dataset directory when it is missing
        if save_file in self.local_files():
            return save_file
        logger.info("downloading %s to %s" % (url, self.relpath(save_file)))
        partial_file = save_file + ".part"
        urlretrieve(url, partial_file)
        shutil.move(partial_file, save_file)
        return save_file

    def extract(self, zip_file, member=None):
        """
        Decompress ``zip_file``, or extract a single ``member`` from it.

        Supported extensions are ``.gz``, ``.tar``, ``.tar.gz`` and ``.zip``.
        ``.txt`` files are returned untouched. Without ``member``, ``.gz``
        decompresses to the file name minus the extension, and archives are
        unpacked into a directory with that name. With ``member``, the member is
        written next to the archive under its base name. Output goes to a
        ``.part`` path first and is renamed when complete.

        Parameters:
            zip_file (str): compressed file or archive
            member (str, optional): archive member to extract

        Returns:
            str: path of the extracted file or directory

        Raises:
            ValueError: if the file extension is not supported
        """
        zip_name, extension = os.path.splitext(zip_file)
        if zip_name.endswith(".tar"):
            extension = ".tar" + extension
            zip_name = zip_name[:-4]

        if extension == ".txt":
            return zip_file
        elif member is None:
            save_file = zip_name
        else:
            save_file = os.path.join(os.path.dirname(zip_name), os.path.basename(member))
        # Check the path itself: archives nested below self.path (e.g. per-class
        # tars inside an image folder) are never listed by local_files().
        if os.path.exists(save_file):
            return save_file

        partial_file = save_file + ".part"
        if extension == ".gz":
            logger.info("extracting %s to %s" % (self.relpath(zip_file), self.relpath(save_file)))
            with gzip.open(zip_file, "rb") as fin, open(partial_file, "wb") as fout:
                shutil.copyfileobj(fin, fout)
        elif extension == ".tar.gz" or extension == ".tar":
            if member is None:
                logger.info("extracting %s to %s" % (self.relpath(zip_file), self.relpath(save_file)))
            else:
                logger.info("extracting %s from %s to %s" % (member, self.relpath(zip_file), self.relpath(save_file)))
            # the archive is closed by the with-statement in both branches
            with tarfile.open(zip_file, "r") as archive:
                if member is None:
                    # NOTE(review): extractall trusts member paths; only use with trusted archives
                    archive.extractall(partial_file)
                else:
                    with archive.extractfile(member) as fin, open(partial_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
        elif extension == ".zip":
            if member is None:
                logger.info("extracting %s to %s" % (self.relpath(zip_file), self.relpath(save_file)))
            else:
                logger.info("extracting %s from %s to %s" % (member, self.relpath(zip_file), self.relpath(save_file)))
            with zipfile.ZipFile(zip_file) as archive:
                if member is None:
                    archive.extractall(partial_file)
                else:
                    with archive.open(member, "r") as fin, open(partial_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
        else:
            raise ValueError("Unknown file extension `%s`" % extension)
        shutil.move(partial_file, save_file)
        return save_file

    def get_file(self, key):
        """
        Materialize split ``key`` and return its file name or preprocessed value.

        The split is cached as ``<path>/<name>_<key>.txt``. If that file exists it is
        returned directly. Otherwise every url of the split is downloaded and
        extracted. The result is then either passed to ``<key>_preprocess`` or, when
        no preprocess exists and the split is a single file, renamed to the cache name.

        Raises:
            AttributeError: if the split needs a ``<key>_preprocess`` method that is not defined
        """
        file_name = os.path.join(self.path, "%s_%s.txt" % (self.name, key))
        if file_name in self.local_files():
            return file_name

        urls = self.urls[key]
        members = self.members[key]
        preprocess_name = key + "_preprocess"
        preprocess = getattr(self, preprocess_name, None)
        if len(urls) > 1 and preprocess is None:
            raise AttributeError(
                "There are non-trivial number of files, but function `%s` is not found" % preprocess_name)

        extract_files = [self.extract(self.download(url), member) for url, member in zip(urls, members)]

        if preprocess is not None:
            result = preprocess(*(extract_files + [file_name]))
            if result is not None:
                return result
        elif extract_files and os.path.isfile(extract_files[0]):
            logger.info("renaming %s to %s" % (self.relpath(extract_files[0]), self.relpath(file_name)))
            shutil.move(extract_files[0], file_name)
        else:
            # no urls, or the single url extracted to a directory
            raise AttributeError(
                "Split `%s` does not resolve to a single file, but function `%s` is not found"
                % (key, preprocess_name))
        return file_name

    def local_files(self):
        """Return the set of paths directly inside the dataset directory, creating it if needed."""
        if not os.path.exists(self.path):
            # makedirs: the configured dataset root may not exist yet either
            os.makedirs(self.path)
        return set(glob.glob(os.path.join(self.path, "*")))

    def __getattr__(self, key):
        """Resolve unknown attributes that name a split by materializing that split."""
        # __getattr__ only runs when normal lookup fails. Reading `urls` through
        # __dict__ avoids infinite recursion on instances that have no `urls` yet
        # (e.g. created via __new__, copy or unpickling).
        urls = self.__dict__.get("urls", {})
        if key in urls:
            return self.get_file(key)
        raise AttributeError("Can't resolve split `%s`" % key)

    def csv2txt(self, csv_file, txt_file):
        """
        Convert ``csv`` to ``txt``.

        Every comma is replaced with a tab; no csv quoting rules are applied.

        Parameters:
            csv_file: csv file
            txt_file: txt file
        """
        logger.info("converting %s to %s" % (self.relpath(csv_file), self.relpath(txt_file)))
        with open(csv_file, "r") as fin, open(txt_file, "w") as fout:
            for line in fin:
                fout.write(line.replace(",", "\t"))

    def top_k_label(self, label_file, save_file, k, format="node-label"):
        """
        Extract top-k labels.

        Keeps the ``k`` labels with the most nodes. Writes one tab-separated
        ``node label`` line per membership, sorted by label and then by node.

        Parameters:
            label_file (str): label file
            save_file (str): save file
            k (int): top-k labels will be extracted
            format (str, optional): format of label file,
                can be 'node-label' or '(label)-nodes':

                - **node-label**: each line is [node] [label]
                - **(label)-nodes**: each line is [node]..., no explicit label;
                  the 0-based line index is used as the label

        Raises:
            ValueError: if ``format`` is unknown
        """
        logger.info("extracting top-%d labels of %s to %s" % (k, self.relpath(label_file), self.relpath(save_file)))
        if format == "node-label":
            label2nodes = defaultdict(list)
            with open(label_file, "r") as fin:
                for line in fin:
                    node, label = line.split()
                    label2nodes[label].append(node)
        elif format == "(label)-nodes":
            label2nodes = {}
            with open(label_file, "r") as fin:
                for i, line in enumerate(fin):
                    label2nodes[i] = line.split()
        else:
            raise ValueError("Unknown file format `%s`" % format)

        labels = sorted(label2nodes, key=lambda x: len(label2nodes[x]), reverse=True)[:k]
        with open(save_file, "w") as fout:
            for label in sorted(labels):
                for node in sorted(label2nodes[label]):
                    fout.write("%s\t%s\n" % (node, label))

    def induced_graph(self, graph_file, label_file, save_file):
        """
        Induce a subgraph from labeled nodes. Only edges whose two endpoints are
        both labeled nodes are kept.

        Every whitespace-separated token of ``label_file`` counts as a labeled node.
        Lines of ``graph_file`` starting with ``#`` are skipped.

        Parameters:
            graph_file (str): graph file
            label_file (str): label file
            save_file (str): save file
        """
        logger.info("extracting subgraph of %s induced by %s to %s" %
                    (self.relpath(graph_file), self.relpath(label_file), self.relpath(save_file)))
        nodes = set()
        with open(label_file, "r") as fin:
            for line in fin:
                nodes.update(line.split())
        with open(graph_file, "r") as fin, open(save_file, "w") as fout:
            for line in fin:
                if not line.startswith("#"):
                    u, v = line.split()
                    if u not in nodes or v not in nodes:
                        continue
                    fout.write("%s\t%s\n" % (u, v))

    def link_prediction_split(self, graph_file, train_file, test_file, portion):
        """
        Split a graph for link prediction use. The test split will contain half true and half false edges.

        Each edge is sent to the test split with probability ``portion`` and written
        as ``u v 1``. The rest go to the train split as ``u v``. For every test edge,
        a random pair of distinct nodes that is not an edge in either direction is
        appended as ``u v 0``. All fields are tab-separated. A private, fixed-seed
        generator is used, so the split is reproducible and numpy's global random
        state is left untouched.

        Parameters:
            graph_file (str): graph file; blank lines and ``#`` lines are skipped
            train_file (str): train file
            test_file (str): test file
            portion (float): portion of test edges
        """
        logger.info("splitting graph %s into %s and %s" %
                    (self.relpath(graph_file), self.relpath(train_file), self.relpath(test_file)))
        # same stream as np.random.seed(1024) followed by np.random.rand(), without the global side effect
        rng = np.random.RandomState(1024)
        nodes = set()
        edges = set()
        num_test = 0
        with open(graph_file, "r") as fin, open(train_file, "w") as ftrain, open(test_file, "w") as ftest:
            for line in fin:
                if not line.strip() or line.startswith("#"):
                    continue
                u, v = line.split()
                nodes.update([u, v])
                edges.add((u, v))
                if rng.rand() > portion:
                    ftrain.write("%s\t%s\n" % (u, v))
                else:
                    ftest.write("%s\t%s\t1\n" % (u, v))
                    num_test += 1

        # Sorted: set iteration order of str elements varies between processes
        # (hash randomization), which would make the negatives non-reproducible.
        nodes = sorted(nodes)
        with open(test_file, "a") as ftest:
            for _ in range(num_test):
                while True:
                    u = nodes[int(rng.rand() * len(nodes))]
                    v = nodes[int(rng.rand() * len(nodes))]
                    if u != v and (u, v) not in edges and (v, u) not in edges:
                        break
                ftest.write("%s\t%s\t0\n" % (u, v))

    def image_feature_data(self, dataset, model="resnet50", batch_size=128):
        """
        Infer feature vectors on a dataset using a neural network.

        Runs on the GPU when CUDA is available and falls back to the CPU otherwise.

        Parameters:
            dataset (torch.utils.data.Dataset): dataset yielding (image, label) pairs
            model (str or torch.nn.Module, optional): pretrained model.
                If it is a str, use the last hidden layer of that model.
            batch_size (int, optional): batch size

        Returns:
            numpy.ndarray: 2-D array with one flattened feature vector per sample, in dataset order
        """
        import torch
        import torchvision
        from torch import nn

        logger.info("computing %s feature" % model)
        if isinstance(model, str):
            full_model = getattr(torchvision.models, model)(pretrained=True)
            # drop the last child module so the output is the layer before it
            model = nn.Sequential(*list(full_model.children())[:-1])
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_worker = multiprocessing.cpu_count()
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size, num_workers=num_worker, shuffle=False)
        model = model.to(device)
        model.eval()

        features = []
        with torch.no_grad():
            for i, (batch_images, batch_labels) in enumerate(data_loader):
                if i % 100 == 0:
                    logger.info("%g%%" % (100.0 * i * batch_size / len(dataset)))
                batch_images = batch_images.to(device)
                batch_features = model(batch_images).view(batch_images.size(0), -1).cpu().numpy()
                features.append(batch_features)
        features = np.concatenate(features)
        return features
class BlogCatalog(Dataset):
    """
    BlogCatalog social network dataset.

    Both splits are members of one zip archive and are converted from
    comma-separated to tab-separated text.

    Splits:
        train, label
    """

    def __init__(self):
        archive = "http://socialcomputing.asu.edu/uploads/1283153973/BlogCatalog-dataset.zip"
        edge_member = "BlogCatalog-dataset/data/edges.csv"
        group_member = "BlogCatalog-dataset/data/group-edges.csv"
        super(BlogCatalog, self).__init__(
            "blogcatalog",
            urls={"train": archive, "label": archive},
            members={"train": edge_member, "label": group_member},
        )

    def train_preprocess(self, raw_file, save_file):
        # csv -> tab-separated txt
        self.csv2txt(raw_file, save_file)

    # the label split needs exactly the same conversion
    label_preprocess = train_preprocess
class Youtube(Dataset):
    """
    Youtube social network dataset.

    Splits:
        train, label
    """

    def __init__(self):
        host = "http://socialnetworks.mpi-sws.mpg.de/data/"
        super(Youtube, self).__init__(
            "youtube",
            urls={
                "train": host + "youtube-links.txt.gz",
                "label": host + "youtube-groupmemberships.txt.gz",
            },
        )

    def label_preprocess(self, raw_file, save_file):
        # keep only the 47 labels with the most nodes
        self.top_k_label(raw_file, save_file, k=47)
class Flickr(Dataset):
    """
    Flickr social network dataset.

    Splits:
        train, label
    """

    def __init__(self):
        host = "http://socialnetworks.mpi-sws.mpg.de/data/"
        super(Flickr, self).__init__(
            "flickr",
            urls={
                "train": host + "flickr-links.txt.gz",
                "label": host + "flickr-groupmemberships.txt.gz",
            },
        )

    def label_preprocess(self, label_file, save_file):
        # keep only the 5 labels with the most nodes
        self.top_k_label(label_file, save_file, k=5)
class Hyperlink2012(Dataset):
    """
    Hyperlink 2012 graph dataset.

    Both splits are written together by :meth:`link_prediction_split` from the
    same arc list, so whichever split is accessed first produces both files.

    Splits:
        pld_train, pld_test
    """

    def __init__(self):
        arc_list = "http://data.dws.informatik.uni-mannheim.de/hyperlinkgraph/2012-08/pld-arc.gz"
        super(Hyperlink2012, self).__init__(
            "hyperlink2012",
            urls={"pld_train": arc_list, "pld_test": arc_list},
        )

    @staticmethod
    def _sibling(path, old_suffix, new_suffix):
        # cut `path` at the last occurrence of old_suffix and append new_suffix
        cut = path.rfind(old_suffix)
        return path[:cut] + new_suffix

    def pld_train_preprocess(self, graph_file, train_file):
        test_file = self._sibling(train_file, "pld_train.txt", "pld_test.txt")
        self.link_prediction_split(graph_file, train_file, test_file, portion=1e-4)

    def pld_test_preprocess(self, graph_file, test_file):
        train_file = self._sibling(test_file, "pld_test.txt", "pld_train.txt")
        self.link_prediction_split(graph_file, train_file, test_file, portion=1e-4)
class Friendster(Dataset):
    """
    Friendster social network dataset.

    ``small_train`` is the subgraph induced by the nodes of the all-communities file.

    Splits:
        train, small_train, label
    """

    def __init__(self):
        base = "https://snap.stanford.edu/data/bigdata/communities/"
        graph = base + "com-friendster.ungraph.txt.gz"
        super(Friendster, self).__init__(
            "friendster",
            urls={
                "train": graph,
                "small_train": [graph, base + "com-friendster.all.cmty.txt.gz"],
                "label": base + "com-friendster.top5000.cmty.txt.gz",
            },
        )

    def small_train_preprocess(self, graph_file, label_file, save_file):
        self.induced_graph(graph_file, label_file, save_file)

    def label_preprocess(self, label_file, save_file):
        # one community per line; keep the 100 largest
        self.top_k_label(label_file, save_file, k=100, format="(label)-nodes")
class Wikipedia(Dataset):
    """
    Wikipedia dump for word embedding.

    Splits:
        train
    """

    def __init__(self):
        dump = "https://www.dropbox.com/s/mwt4uu1qu9fflfk/enwiki-latest-pages-articles-sentences.txt.gz"
        super(Wikipedia, self).__init__("wikipedia", urls={"train": dump})
class FB15k(Dataset):
    """
    FB15k knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        base = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k/"
        super(FB15k, self).__init__(
            "fb15k",
            urls={split: base + split + ".txt" for split in ("train", "valid", "test")},
        )
class FB15k237(Dataset):
    """
    FB15k-237 knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        base = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k-237/"
        super(FB15k237, self).__init__(
            "fb15k-237",
            urls={split: base + split + ".txt" for split in ("train", "valid", "test")},
        )
class WN18(Dataset):
    """
    WN18 knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        base = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/"
        super(WN18, self).__init__(
            "wn18",
            urls={split: base + split + ".txt" for split in ("train", "valid", "test")},
        )
class WN18RR(Dataset):
    """
    WN18RR knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        base = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/"
        super(WN18RR, self).__init__(
            "wn18rr",
            urls={split: base + split + ".txt" for split in ("train", "valid", "test")},
        )
class Freebase(Dataset):
    """
    Freebase knowledge graph dataset.

    Splits:
        train
    """

    def __init__(self):
        dump = "http://commondatastorage.googleapis.com/freebase-public/rdf/freebase-rdf-latest.gz"
        super(Freebase, self).__init__("freebase", urls={"train": dump})
class MNIST(Dataset):
    """
    MNIST dataset for visualization.

    The ``*_data`` splits return numpy arrays rather than file names.

    Splits:
        train_image_data, train_label_data, test_image_data, test_label_data, image_data, label_data
    """

    # number of leading header bytes skipped before the raw uint8 payload
    _IMAGE_HEADER = 16
    _LABEL_HEADER = 8

    def __init__(self):
        base = "http://yann.lecun.com/exdb/mnist/"
        super(MNIST, self).__init__(
            "mnist",
            urls={
                "train_image_data": base + "train-images-idx3-ubyte.gz",
                "train_label_data": base + "train-labels-idx1-ubyte.gz",
                "test_image_data": base + "t10k-images-idx3-ubyte.gz",
                "test_label_data": base + "t10k-labels-idx1-ubyte.gz",
                "image_data": [],  # depends on `train_image_data` & `test_image_data`
                "label_data": [],  # depends on `train_label_data` & `test_label_data`
            },
        )

    def train_image_data_preprocess(self, raw_file, save_file):
        payload = np.fromfile(raw_file, dtype=np.uint8)[self._IMAGE_HEADER:]
        # one flattened 28x28 image per row
        return payload.reshape(-1, 28 * 28)

    def train_label_data_preprocess(self, raw_file, save_file):
        return np.fromfile(raw_file, dtype=np.uint8)[self._LABEL_HEADER:]

    # the test files share the training file layout
    test_image_data_preprocess = train_image_data_preprocess
    test_label_data_preprocess = train_label_data_preprocess

    def image_data_preprocess(self, save_file):
        train, test = self.train_image_data, self.test_image_data
        return np.concatenate([train, test])

    def label_data_preprocess(self, save_file):
        train, test = self.train_label_data, self.test_label_data
        return np.concatenate([train, test])
class CIFAR10(Dataset):
    """
    CIFAR10 dataset for visualization.

    The ``*_data`` splits return numpy arrays rather than file names.

    Splits:
        train_image_data, train_label_data, test_image_data, test_label_data, image_data, label_data
    """

    def __init__(self):
        super(CIFAR10, self).__init__(
            "cifar10",
            urls={
                "train_image_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "train_label_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "test_image_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "test_label_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "image_data": [],  # depends on `train_image_data` & `test_image_data`
                "label_data": []  # depends on `train_label_data` & `test_label_data`
            },
        )

    def _read_batch(self, batch_file):
        """Read one binary batch as a uint8 array with one ``[label, 32*32*3 pixels]`` record per row."""
        batch = np.fromfile(batch_file, dtype=np.uint8)
        return batch.reshape(-1, 32*32*3 + 1)

    def _train_batch_files(self, raw_path):
        """Return the training batch files in sorted order."""
        # glob returns files in arbitrary order. Sorting makes the concatenation
        # deterministic, so train images and train labels (loaded by two separate
        # calls) are guaranteed to line up.
        return sorted(glob.glob(os.path.join(raw_path, "cifar-10-batches-bin/data_batch_*.bin")))

    def load_images(self, *batch_files):
        """Concatenate the pixel columns of the given batch files, in argument order."""
        return np.concatenate([self._read_batch(batch_file)[:, 1:] for batch_file in batch_files])

    def load_labels(self, meta_file, *batch_files):
        """Map the label byte of every record to its class name listed in ``meta_file``."""
        with open(meta_file, "r") as fin:
            classes = [line.strip() for line in fin if line.strip()]
        classes = np.asarray(classes)
        labels = [self._read_batch(batch_file)[:, 0] for batch_file in batch_files]
        return classes[np.concatenate(labels)]

    def train_image_data_preprocess(self, raw_path, save_file):
        return self.load_images(*self._train_batch_files(raw_path))

    def train_label_data_preprocess(self, raw_path, save_file):
        meta_file = os.path.join(raw_path, "cifar-10-batches-bin/batches.meta.txt")
        return self.load_labels(meta_file, *self._train_batch_files(raw_path))

    def test_image_data_preprocess(self, raw_path, save_path):
        batch_file = os.path.join(raw_path, "cifar-10-batches-bin/test_batch.bin")
        return self.load_images(batch_file)

    def test_label_data_preprocess(self, raw_path, save_path):
        meta_file = os.path.join(raw_path, "cifar-10-batches-bin/batches.meta.txt")
        batch_file = os.path.join(raw_path, "cifar-10-batches-bin/test_batch.bin")
        return self.load_labels(meta_file, batch_file)

    def image_data_preprocess(self, save_path):
        return np.concatenate([self.train_image_data, self.test_image_data])

    def label_data_preprocess(self, save_path):
        return np.concatenate([self.train_label_data, self.test_label_data])
class ImageNet(Dataset):
    """
    ImageNet dataset for visualization.

    Splits:
        train_image, train_feature_data, train_label, train_hierarchical_label,
        valid_image, valid_feature_data, valid_label, valid_hierarchical_label,
        feature_data, label, hierarchical_label
    """

    def __init__(self):
        super(ImageNet, self).__init__(
            "imagenet",
            urls={
                "train_image": "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar",
                "train_feature_data": [],  # depends on `train_image`
                "train_label": [],  # depends on `train_image`
                "train_hierarchical_label": [],  # depends on `train_image`
                "valid_image": ["http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar",
                                "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz"],
                "valid_feature_data": [],  # depends on `valid_image`
                "valid_label": "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz",
                "valid_hierarchical_label":
                    "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz",
                "feature_data": [],  # depends on `train_feature_data` & `valid_feature_data`
                "label": [],  # depends on `train_label` & `valid_label`
                "hierarchical_label": [],  # depends on `train_hierarchical_label` & `valid_hierarchical_label`
            }
        )

    def import_wordnet(self):
        """Import NLTK's WordNet corpus, downloading it when it is not installed."""
        import nltk
        try:
            nltk.data.find("corpora/wordnet")
        except LookupError:
            nltk.download("wordnet")
        from nltk.corpus import wordnet
        # some NLTK versions only expose the private name; alias it
        try:
            wordnet.synset_from_pos_and_offset
        except AttributeError:
            wordnet.synset_from_pos_and_offset = wordnet._synset_from_pos_and_offset
        return wordnet

    def get_name(self, synset):
        """Return the lemma part of a synset name, e.g. ``dog`` for ``dog.n.01``."""
        # Synset names are <lemma>.<pos>.<sense>. Split from the right because the
        # lemma itself may contain dots (NLTK parses names the same way).
        return synset.name().rsplit(".", 2)[0]

    def readable_label(self, labels, save_file, hierarchy=False):
        """
        Write human-readable WordNet names for ``labels``, one line per label.

        Parameters:
            labels (list of str): ids made of a POS letter followed by an integer offset
            save_file (str): output file
            hierarchy (bool, optional): if true, each line lists tab-separated names
                from the coarsest ancestor level down to the label itself
        """
        wordnet = self.import_wordnet()
        if hierarchy:
            logger.info("generating human-readable hierarchical labels")
        else:
            logger.info("generating human-readable labels")

        synsets = [wordnet.synset_from_pos_and_offset(label[0], int(label[1:])) for label in labels]
        depth = max(synset.max_depth() for synset in synsets)
        names = [self.get_name(synset) for synset in synsets]
        hierarchies = [names]
        # Climb one depth level per iteration until all samples share one ancestor.
        while hierarchy and len(set(names)) > 1:
            depth -= 1
            for i, synset in enumerate(synsets):
                if synset.max_depth() > depth:
                    # Only the first hypernym is followed. Synsets that only have
                    # instance hypernyms (also counted by max_depth) use those instead.
                    parents = synset.hypernyms() or synset.instance_hypernyms()
                    synsets[i] = parents[0]
            names = [self.get_name(synset) for synset in synsets]
            hierarchies.append(names)
        # coarsest level first
        hierarchies = hierarchies[::-1]
        with open(save_file, "w") as fout:
            for row in zip(*hierarchies):
                fout.write("%s\n" % "\t".join(row))

    def cached_feature_data(self, image_path, save_file):
        """Compute features of an image folder, cached as a ``.npy`` file next to ``save_file``."""
        numpy_file = os.path.splitext(save_file)[0] + ".npy"
        if os.path.exists(numpy_file):
            return np.load(numpy_file)

        import torchvision
        from torchvision import transforms

        augmentation = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        dataset = torchvision.datasets.ImageFolder(image_path, augmentation)
        features = self.image_feature_data(dataset)
        np.save(numpy_file, features)
        return features

    def train_image_preprocess(self, image_path, save_file):
        """Unpack every nested ``.tar`` into a folder of the same name, then delete the tar."""
        for tar_file in glob.glob(os.path.join(image_path, "*.tar")):
            self.extract(tar_file)
            os.remove(tar_file)
        return image_path

    def train_feature_data_preprocess(self, save_file):
        return self.cached_feature_data(self.train_image, save_file)

    def _train_wnids(self):
        """Return the class folder name of every training image, sorted."""
        image_files = glob.glob(os.path.join(self.train_image, "*/*.JPEG"))
        # be consistent with the order in torch.utils.data.DataLoader
        return sorted(os.path.basename(os.path.dirname(image_file)) for image_file in image_files)

    def train_label_preprocess(self, save_file):
        self.readable_label(self._train_wnids(), save_file)

    def train_hierarchical_label_preprocess(self, save_file):
        self.readable_label(self._train_wnids(), save_file, hierarchy=True)

    def _validation_wnids(self, meta_path):
        """Return the ground-truth class of each validation image, in ground-truth file order."""
        from scipy.io import loadmat

        meta_file = os.path.join(meta_path, "ILSVRC2012_devkit_t12/data/meta.mat")
        id_file = os.path.join(meta_path, "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
        metas = loadmat(meta_file, squeeze_me=True)["synsets"][:1000]
        id2class = {meta[0]: meta[1] for meta in metas}
        ids = np.loadtxt(id_file, dtype=np.int32)
        return [id2class[class_id] for class_id in ids]

    def valid_image_preprocess(self, image_path, meta_path, save_file):
        """Move validation images into one sub-folder per class (no-op once done)."""
        image_files = glob.glob(os.path.join(image_path, "*.JPEG"))
        if not image_files:
            return image_path
        logger.info("re-arranging images into sub-folders")
        labels = self._validation_wnids(meta_path)
        # sorted file names are paired line by line with the ground-truth ids
        for image_file, label in zip(sorted(image_files), labels):
            class_path = os.path.join(image_path, label)
            if not os.path.exists(class_path):
                os.mkdir(class_path)
            shutil.move(image_file, class_path)
        return image_path

    def valid_feature_data_preprocess(self, save_file):
        return self.cached_feature_data(self.valid_image, save_file)

    def valid_label_preprocess(self, meta_path, save_file):
        # be consistent with the order in torch.utils.data.DataLoader
        labels = sorted(self._validation_wnids(meta_path))
        self.readable_label(labels, save_file)

    def valid_hierarchical_label_preprocess(self, meta_path, save_file):
        # be consistent with the order in torch.utils.data.DataLoader
        labels = sorted(self._validation_wnids(meta_path))
        self.readable_label(labels, save_file, hierarchy=True)

    def feature_data_preprocess(self, save_file):
        return np.concatenate([self.train_feature_data, self.valid_feature_data])

    def _concatenate(self, save_file, *input_files):
        """Write the contents of ``input_files`` into ``save_file``, one after another."""
        with open(save_file, "w") as fout:
            for input_file in input_files:
                with open(input_file, "r") as fin:
                    shutil.copyfileobj(fin, fout)

    def label_preprocess(self, save_file):
        # Resolve the inputs before creating save_file. A failure while building
        # them must not leave an empty file behind that would be reused as the cache.
        self._concatenate(save_file, self.train_label, self.valid_label)

    def hierarchical_label_preprocess(self, save_file):
        self._concatenate(save_file, self.train_hierarchical_label, self.valid_hierarchical_label)
# Ready-made instances. Construction only records names, urls and paths;
# files are fetched when a split is first accessed.
# graph
blogcatalog, youtube, flickr = BlogCatalog(), Youtube(), Flickr()
hyperlink2012, friendster, wikipedia = Hyperlink2012(), Friendster(), Wikipedia()
# knowledge graph
fb15k, fb15k237, wn18, wn18rr, freebase = FB15k(), FB15k237(), WN18(), WN18RR(), Freebase()
# visualization
mnist, cifar10, imagenet = MNIST(), CIFAR10(), ImageNet()

__all__ = [
    "Dataset",
    # graph
    "BlogCatalog", "Youtube", "Flickr", "Hyperlink2012", "Friendster", "Wikipedia",
    # knowledge graph
    "FB15k", "FB15k237", "WN18", "WN18RR", "Freebase",
    # visualization
    "MNIST", "CIFAR10", "ImageNet",
]