# Source code for graphvite.dataset

# Copyright 2019 MilaGraph. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Zhaocheng Zhu

"""
Dataset module of GraphVite

Graph

- :class:`BlogCatalog`
- :class:`Youtube`
- :class:`Flickr`
- :class:`Hyperlink2012`
- :class:`Friendster`
- :class:`Wikipedia`

Knowledge Graph

- :class:`FB15k`
- :class:`FB15k237`
- :class:`WN18`
- :class:`WN18RR`
- :class:`Freebase`

Visualization

- :class:`MNIST`
- :class:`CIFAR10`
- :class:`ImageNet`
"""
from __future__ import absolute_import

import os
import glob
import shutil
import logging
import gzip, zipfile, tarfile
import multiprocessing
from collections import defaultdict

import numpy as np

from . import cfg

logger = logging.getLogger(__name__)


class Dataset(object):
    """
    Graph dataset.

    Parameters:
        name (str): name of dataset
        urls (dict, optional): url(s) for each split,
            can be either str or list of str
        members (dict, optional): zip member(s) for each split,
            leave empty for default

    Datasets contain several splits, such as train, valid and test.
    For each split, there are one or more URLs, specifying the file to download.
    You may also specify the zip member to extract.
    When a split is accessed, it will be automatically downloaded and decompressed
    if it is not present.

    You can assign a preprocess for each split, by defining a function with name [split]_preprocess::

        class MyDataset(Dataset):
            def __init__(self):
                super(MyDataset, self).__init__(
                    "my_dataset",
                    urls={
                        "train": "url/to/train/split",
                        "test": "url/to/test/split"
                    }
                )

            def train_preprocess(self, input_file, output_file):
                with open(input_file, "r") as fin, open(output_file, "w") as fout:
                    fout.write(fin.read())

        f = open(MyDataset().train)

    If the preprocess returns a non-trivial value, then it is assigned to the split,
    otherwise the file name is assigned.
    By convention, only splits ending with ``_data`` have non-trivial return value.

    See also:
        Pre-defined preprocess functions
        :func:`csv2txt`,
        :func:`top_k_label`,
        :func:`induced_graph`,
        :func:`image_feature_data`
    """
    # NOTE(review): Hyperlink2012 calls ``self.link_prediction_split(...)``, which is not
    # defined in this class as shown here -- confirm where it is supposed to live.

    def __init__(self, name, urls=None, members=None):
        self.name = name
        # Work on shallow copies: the normalisation below rewrites entries and
        # must not leak into the dicts owned by the caller.
        self.urls = dict(urls or {})
        self.members = dict(members or {})
        for key in self.urls:
            # a bare string is shorthand for a one-element list
            if isinstance(self.urls[key], str):
                self.urls[key] = [self.urls[key]]
            if key not in self.members:
                self.members[key] = [None] * len(self.urls[key])
            elif isinstance(self.members[key], str):
                self.members[key] = [self.members[key]]
            if len(self.urls[key]) != len(self.members[key]):
                raise ValueError("Number of members is inconsistent with number of urls in `%s`" % key)
        self.path = os.path.join(cfg.dataset_path, self.name)

    def relpath(self, path):
        """Return ``path`` relative to the dataset directory (used for log messages)."""
        return os.path.relpath(path, self.path)

    def download(self, url):
        """
        Download ``url`` into the dataset directory unless it is already there.

        Parameters:
            url (str): file to download

        Returns:
            str: path of the local copy, named after the last component of the url
        """
        from six.moves.urllib.request import urlretrieve

        save_file = os.path.join(self.path, os.path.basename(url))
        if save_file in self.local_files():
            return save_file

        logger.info("downloading %s to %s", url, self.relpath(save_file))
        try:
            urlretrieve(url, save_file)
        except BaseException:
            # A truncated file would be mistaken for a finished download on the
            # next access, so drop it before propagating the error.
            if os.path.exists(save_file):
                os.remove(save_file)
            raise
        return save_file

    def extract(self, zip_file, member=None):
        """
        Decompress ``zip_file`` next to itself unless the result already exists.

        Supports ``.gz``, ``.tar``, ``.tar.gz`` and ``.zip``. ``.txt`` files are returned as is.

        Parameters:
            zip_file (str): file to decompress
            member (str, optional): single archive member to extract.
                If omitted, the whole archive is extracted.

        Returns:
            str: path of the extracted file or directory

        Raises:
            ValueError: if the extension is unknown,
                or ``member`` is not a regular file of a tar archive
        """
        zip_name, extension = os.path.splitext(zip_file)
        if zip_name.endswith(".tar"):
            # treat ``.tar.gz`` and alike as a single extension
            extension = ".tar" + extension
            zip_name = zip_name[:-4]

        if extension == ".txt":
            return zip_file
        elif member is None:
            save_file = zip_name
        else:
            save_file = os.path.join(os.path.dirname(zip_name), os.path.basename(member))
        if save_file in self.local_files():
            return save_file

        if extension == ".gz":
            logger.info("extracting %s to %s", self.relpath(zip_file), self.relpath(save_file))
            with gzip.open(zip_file, "rb") as fin, open(save_file, "wb") as fout:
                shutil.copyfileobj(fin, fout)
        elif extension == ".tar.gz" or extension == ".tar":
            if member is None:
                logger.info("extracting %s to %s", self.relpath(zip_file), self.relpath(save_file))
                with tarfile.open(zip_file, "r") as fin:
                    # NOTE(review): extractall() trusts the member paths stored in the
                    # archive; only use it on archives from trusted sources.
                    fin.extractall(save_file)
            else:
                logger.info("extracting %s from %s to %s",
                            member, self.relpath(zip_file), self.relpath(save_file))
                # keep the archive itself in a ``with`` so that its handle is closed too
                with tarfile.open(zip_file, "r") as archive:
                    fin = archive.extractfile(member)
                    if fin is None:
                        raise ValueError("`%s` is not a regular file in `%s`" % (member, zip_file))
                    with fin, open(save_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
        elif extension == ".zip":
            if member is None:
                logger.info("extracting %s to %s", self.relpath(zip_file), self.relpath(save_file))
                with zipfile.ZipFile(zip_file) as fin:
                    fin.extractall(save_file)
            else:
                logger.info("extracting %s from %s to %s",
                            member, self.relpath(zip_file), self.relpath(save_file))
                # keep the archive itself in a ``with`` so that its handle is closed too
                with zipfile.ZipFile(zip_file) as archive:
                    with archive.open(member, "r") as fin, open(save_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
        else:
            raise ValueError("Unknown file extension `%s`" % extension)

        return save_file

    def get_file(self, key):
        """
        Resolve the split ``key``, downloading, extracting and preprocessing it on demand.

        Parameters:
            key (str): name of the split

        Returns:
            the value returned by ``[key]_preprocess`` if it is not None,
            otherwise the path ``[name]_[key].txt`` inside the dataset directory

        Raises:
            AttributeError: if there is no ``[key]_preprocess`` and the split
                is not exactly one regular file
        """
        file_name = os.path.join(self.path, "%s_%s.txt" % (self.name, key))
        if file_name in self.local_files():
            return file_name

        urls = self.urls[key]
        members = self.members[key]
        preprocess_name = key + "_preprocess"
        preprocess = getattr(self, preprocess_name, None)
        if len(urls) > 1 and preprocess is None:
            raise AttributeError(
                "There are non-trivial number of files, but function `%s` is not found" % preprocess_name)

        extract_files = []
        for url, member in zip(urls, members):
            download_file = self.download(url)
            extract_file = self.extract(download_file, member)
            extract_files.append(extract_file)

        if preprocess:
            # the preprocess receives every extracted file, followed by the output file
            result = preprocess(*(extract_files + [file_name]))
            if result is not None:
                return result
        elif extract_files and os.path.isfile(extract_files[0]):
            logger.info("renaming %s to %s", self.relpath(extract_files[0]), self.relpath(file_name))
            shutil.move(extract_files[0], file_name)
        else:
            # no file at all, or a directory: there is nothing sensible to rename
            raise AttributeError(
                "There are non-trivial number of files, but function `%s` is not found" % preprocess_name)

        return file_name

    def local_files(self):
        """Return the set of paths directly inside the dataset directory, creating it if needed."""
        if not os.path.exists(self.path):
            # makedirs also creates missing parents, e.g. on the very first use
            os.makedirs(self.path)
        return set(glob.glob(os.path.join(self.path, "*")))

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        if key in self.__dict__:
            return self.__dict__[key]
        # Read ``urls`` straight from the instance dict: ``self.urls`` would re-enter
        # this method and recurse forever when ``urls`` has not been assigned yet
        # (e.g. on an instance created without running __init__).
        urls = self.__dict__.get("urls", {})
        if key in urls:
            return self.get_file(key)
        raise AttributeError("Can't resolve split `%s`" % key)

    def csv2txt(self, csv_file, txt_file):
        """
        Convert ``csv`` to ``txt``.

        Every comma is replaced by a tab. Quoting is not interpreted.

        Parameters:
            csv_file: csv file
            txt_file: txt file
        """
        logger.info("converting %s to %s", self.relpath(csv_file), self.relpath(txt_file))
        with open(csv_file, "r") as fin, open(txt_file, "w") as fout:
            for line in fin:
                fout.write(line.replace(",", "\t"))

    def top_k_label(self, label_file, save_file, k, format="node-label"):
        """
        Extract top-k labels.

        The k labels with the most nodes are kept and written as ``[node]\\t[label]`` lines.

        Parameters:
            label_file (str): label file
            save_file (str): save file
            k (int): top-k labels will be extracted
            format (str, optional): format of label file,
                can be 'node-label' or '(label)-nodes':

                - **node-label**: each line is [node] [label]
                - **(label)-nodes**: each line is [node]..., no explicit label

        Raises:
            ValueError: if ``format`` is unknown
        """
        logger.info("extracting top-%d labels of %s to %s",
                    k, self.relpath(label_file), self.relpath(save_file))
        if format == "node-label":
            label2nodes = defaultdict(list)
            with open(label_file, "r") as fin:
                for line in fin:
                    node, label = line.split()
                    label2nodes[label].append(node)
        elif format == "(label)-nodes":
            # the line number serves as the label
            label2nodes = {}
            with open(label_file, "r") as fin:
                for i, line in enumerate(fin):
                    label2nodes[i] = line.split()
        else:
            raise ValueError("Unknown file format `%s`" % format)

        labels = sorted(label2nodes, key=lambda x: len(label2nodes[x]), reverse=True)[:k]
        with open(save_file, "w") as fout:
            for label in sorted(labels):
                for node in sorted(label2nodes[label]):
                    fout.write("%s\t%s\n" % (node, label))

    def induced_graph(self, graph_file, label_file, save_file):
        """
        Induce a subgraph from labeled nodes. Only edges whose both endpoints are labeled are kept.

        Lines of the graph file starting with ``#`` are skipped.

        Parameters:
            graph_file (str): graph file
            label_file (str): label file
            save_file (str): save file
        """
        logger.info("extracting subgraph of %s induced by %s to %s",
                    self.relpath(graph_file), self.relpath(label_file), self.relpath(save_file))
        # every token of the label file counts as a labeled node
        nodes = set()
        with open(label_file, "r") as fin:
            for line in fin:
                nodes.update(line.split())
        with open(graph_file, "r") as fin, open(save_file, "w") as fout:
            for line in fin:
                if not line.startswith("#"):
                    u, v = line.split()
                    if u not in nodes or v not in nodes:
                        continue
                    fout.write("%s\t%s\n" % (u, v))

    def image_feature_data(self, dataset, model="resnet50", batch_size=128):
        """
        Infer feature vectors on a dataset using a neural network.

        The model is moved to CUDA, so a GPU is required.

        Parameters:
            dataset (torch.utils.data.Dataset): dataset
            model (str or torch.nn.Module, optional): pretrained model.
                If it is a str, use the last hidden layer of that model.
            batch_size (int, optional): batch size

        Returns:
            numpy.ndarray: one flattened feature vector per sample, in dataset order
        """
        import torch
        import torchvision
        from torch import nn

        logger.info("computing %s feature", model)
        if isinstance(model, str):
            full_model = getattr(torchvision.models, model)(pretrained=True)
            # drop the last child module of the torchvision model
            model = nn.Sequential(*list(full_model.children())[:-1])
        num_worker = multiprocessing.cpu_count()
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                  num_workers=num_worker, shuffle=False)
        model = model.cuda()
        model.eval()

        features = []
        with torch.no_grad():
            for i, (batch_images, batch_labels) in enumerate(data_loader):
                if i % 100 == 0:
                    logger.info("%g%%", 100.0 * i * batch_size / len(dataset))
                batch_images = batch_images.cuda()
                batch_features = model(batch_images).view(batch_images.size(0), -1).cpu().numpy()
                features.append(batch_features)
        features = np.concatenate(features)

        return features
class BlogCatalog(Dataset):
    """
    BlogCatalog social network dataset.

    Splits:
        train, label
    """

    def __init__(self):
        # both splits are members of the same archive
        archive = "http://socialcomputing.asu.edu/uploads/1283153973/BlogCatalog-dataset.zip"
        super(BlogCatalog, self).__init__(
            "blogcatalog",
            urls=dict(train=archive, label=archive),
            members=dict(
                train="BlogCatalog-dataset/data/edges.csv",
                label="BlogCatalog-dataset/data/group-edges.csv",
            ),
        )

    def train_preprocess(self, raw_file, save_file):
        """Turn the comma-separated ``raw_file`` into the tab-separated ``save_file``."""
        self.csv2txt(raw_file, save_file)

    # the label split goes through exactly the same conversion
    label_preprocess = train_preprocess
class Youtube(Dataset):
    """
    Youtube social network dataset.

    Splits:
        train, label
    """

    def __init__(self):
        host = "http://socialnetworks.mpi-sws.mpg.de/data/"
        split_urls = {
            "train": host + "youtube-links.txt.gz",
            "label": host + "youtube-groupmemberships.txt.gz",
        }
        super(Youtube, self).__init__("youtube", urls=split_urls)

    def label_preprocess(self, raw_file, save_file):
        """Keep only the 47 labels with the most nodes."""
        self.top_k_label(raw_file, save_file, k=47)
class Flickr(Dataset):
    """
    Flickr social network dataset.

    Splits:
        train, label
    """

    def __init__(self):
        host = "http://socialnetworks.mpi-sws.mpg.de/data/"
        split_urls = {
            "train": host + "flickr-links.txt.gz",
            "label": host + "flickr-groupmemberships.txt.gz",
        }
        super(Flickr, self).__init__("flickr", urls=split_urls)

    def label_preprocess(self, label_file, save_file):
        """Keep only the 5 labels with the most nodes."""
        self.top_k_label(label_file, save_file, k=5)
class Hyperlink2012(Dataset):
    """
    Hyperlink 2012 graph dataset.

    Splits:
        pld_train, pld_test
    """
    # NOTE(review): both preprocess functions rely on ``self.link_prediction_split``,
    # which is not defined in the part of the module visible here -- confirm it exists.

    def __init__(self):
        # train and test are both derived from the same graph file
        pld_arc = "http://data.dws.informatik.uni-mannheim.de/hyperlinkgraph/2012-08/pld-arc.gz"
        super(Hyperlink2012, self).__init__(
            "hyperlink2012",
            urls={"pld_train": pld_arc, "pld_test": pld_arc},
        )

    def pld_train_preprocess(self, graph_file, train_file):
        """Split ``graph_file`` into ``train_file`` and its sibling test file."""
        # the test file shares everything in front of the split suffix
        suffix_start = train_file.rfind("pld_train.txt")
        prefix = train_file[:suffix_start]
        test_file = prefix + "pld_test.txt"
        self.link_prediction_split(graph_file, train_file, test_file, portion=1e-4)

    def pld_test_preprocess(self, graph_file, test_file):
        """Split ``graph_file`` into ``test_file`` and its sibling train file."""
        # the train file shares everything in front of the split suffix
        suffix_start = test_file.rfind("pld_test.txt")
        prefix = test_file[:suffix_start]
        train_file = prefix + "pld_train.txt"
        self.link_prediction_split(graph_file, train_file, test_file, portion=1e-4)
class Friendster(Dataset):
    """
    Friendster social network dataset.

    Splits:
        train, small_train, label
    """

    def __init__(self):
        host = "https://snap.stanford.edu/data/bigdata/communities/"
        graph_url = host + "com-friendster.ungraph.txt.gz"
        all_community_url = host + "com-friendster.all.cmty.txt.gz"
        top_community_url = host + "com-friendster.top5000.cmty.txt.gz"
        super(Friendster, self).__init__(
            "friendster",
            urls={
                "train": graph_url,
                # the small graph needs the full graph and the community file
                "small_train": [graph_url, all_community_url],
                "label": top_community_url,
            },
        )

    def small_train_preprocess(self, graph_file, label_file, save_file):
        """Restrict the graph to the nodes listed in ``label_file``."""
        self.induced_graph(graph_file, label_file, save_file)

    def label_preprocess(self, label_file, save_file):
        """Keep only the 100 largest communities, one community per input line."""
        self.top_k_label(label_file, save_file, k=100, format="(label)-nodes")
class Wikipedia(Dataset):
    """
    Wikipedia dump for word embedding.

    Splits:
        train
    """

    def __init__(self):
        sentences_url = "https://www.dropbox.com/s/mwt4uu1qu9fflfk/enwiki-latest-pages-articles-sentences.txt.gz"
        super(Wikipedia, self).__init__("wikipedia", urls={"train": sentences_url})
class FB15k(Dataset):
    """
    FB15k knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        # every split is a plain text file in the same remote folder
        folder = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k"
        split_urls = {split: "%s/%s.txt" % (folder, split) for split in ("train", "valid", "test")}
        super(FB15k, self).__init__("fb15k", urls=split_urls)
class FB15k237(Dataset):
    """
    FB15k-237 knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        # every split is a plain text file in the same remote folder
        folder = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k-237"
        split_urls = {split: "%s/%s.txt" % (folder, split) for split in ("train", "valid", "test")}
        super(FB15k237, self).__init__("fb15k-237", urls=split_urls)
class WN18(Dataset):
    """
    WN18 knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        # every split is a plain text file in the same remote folder
        folder = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18"
        split_urls = {split: "%s/%s.txt" % (folder, split) for split in ("train", "valid", "test")}
        super(WN18, self).__init__("wn18", urls=split_urls)
class WN18RR(Dataset):
    """
    WN18RR knowledge graph dataset.

    Splits:
        train, valid, test
    """

    def __init__(self):
        # every split is a plain text file in the same remote folder
        folder = "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr"
        split_urls = {split: "%s/%s.txt" % (folder, split) for split in ("train", "valid", "test")}
        super(WN18RR, self).__init__("wn18rr", urls=split_urls)
class Freebase(Dataset):
    """
    Freebase knowledge graph dataset.

    Splits:
        train
    """

    def __init__(self):
        dump_url = "http://commondatastorage.googleapis.com/freebase-public/rdf/freebase-rdf-latest.gz"
        super(Freebase, self).__init__("freebase", urls={"train": dump_url})
class MNIST(Dataset):
    """
    MNIST dataset for visualization.

    Splits:
        train_image_data, train_label_data, test_image_data, test_label_data, image_data, label_data
    """

    def __init__(self):
        host = "http://yann.lecun.com/exdb/mnist/"
        split_urls = {
            "train_image_data": host + "train-images-idx3-ubyte.gz",
            "train_label_data": host + "train-labels-idx1-ubyte.gz",
            "test_image_data": host + "t10k-images-idx3-ubyte.gz",
            "test_label_data": host + "t10k-labels-idx1-ubyte.gz",
            "image_data": [],  # depends on `train_image_data` & `test_image_data`
            "label_data": [],  # depends on `train_label_data` & `test_label_data`
        }
        super(MNIST, self).__init__("mnist", urls=split_urls)

    def train_image_data_preprocess(self, raw_file, save_file):
        """Return the images of ``raw_file`` as an array with one row of 28*28 bytes per image."""
        raw_bytes = np.fromfile(raw_file, dtype=np.uint8)
        # the first 16 bytes are skipped (presumably the file header -- confirm against the format)
        pixels = raw_bytes[16:]
        return pixels.reshape(-1, 28*28)

    def train_label_data_preprocess(self, raw_file, save_file):
        """Return the labels of ``raw_file`` as a flat byte array."""
        raw_bytes = np.fromfile(raw_file, dtype=np.uint8)
        # the first 8 bytes are skipped (presumably the file header -- confirm against the format)
        return raw_bytes[8:]

    # the test files are parsed exactly like the train files
    test_image_data_preprocess = train_image_data_preprocess
    test_label_data_preprocess = train_label_data_preprocess

    def image_data_preprocess(self, save_file):
        """Return train images followed by test images."""
        parts = [self.train_image_data, self.test_image_data]
        return np.concatenate(parts)

    def label_data_preprocess(self, save_file):
        """Return train labels followed by test labels."""
        parts = [self.train_label_data, self.test_label_data]
        return np.concatenate(parts)
class CIFAR10(Dataset):
    """
    CIFAR10 dataset for visualization.

    Splits:
        train_image_data, train_label_data, test_image_data, test_label_data, image_data, label_data
    """

    def __init__(self):
        super(CIFAR10, self).__init__(
            "cifar10",
            urls={
                "train_image_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "train_label_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "test_image_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "test_label_data": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
                "image_data": [],  # depends on `train_image_data` & `test_image_data`
                "label_data": []  # depends on `train_label_data` & `test_label_data`
            },
        )

    def load_images(self, *batch_files):
        """
        Load the images of one or more batch files.

        Each file is read as records of ``32*32*3 + 1`` bytes.
        The first byte of a record is dropped here (it is used as the label in :meth:`load_labels`).

        Parameters:
            *batch_files (str): batch files, loaded in the given order

        Returns:
            numpy.ndarray: array of shape (number of records, 32*32*3)
        """
        images = []
        for batch_file in batch_files:
            batch = np.fromfile(batch_file, dtype=np.uint8)
            batch = batch.reshape(-1, 32*32*3 + 1)
            images.append(batch[:, 1:])
        return np.concatenate(images)

    def load_labels(self, meta_file, *batch_files):
        """
        Load the class names of one or more batch files.

        Parameters:
            meta_file (str): text file with one class name per non-empty line
            *batch_files (str): batch files, loaded in the given order

        Returns:
            numpy.ndarray: class name of every record
        """
        classes = []
        with open(meta_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if line:
                    classes.append(line)
        classes = np.asarray(classes)
        labels = []
        for batch_file in batch_files:
            batch = np.fromfile(batch_file, dtype=np.uint8)
            batch = batch.reshape(-1, 32*32*3 + 1)
            # the first byte of a record indexes into the class names
            labels.append(batch[:, 0])
        return classes[np.concatenate(labels)]

    def _train_batch_files(self, raw_path):
        """List the training batch files in a deterministic order."""
        # glob returns matches in arbitrary order; sort so that images and labels
        # are always produced in the same, reproducible order
        return sorted(glob.glob(os.path.join(raw_path, "cifar-10-batches-bin/data_batch_*.bin")))

    def train_image_data_preprocess(self, raw_path, save_file):
        """Return the images of all training batches."""
        return self.load_images(*self._train_batch_files(raw_path))

    def train_label_data_preprocess(self, raw_path, save_file):
        """Return the class names of all training batches."""
        meta_file = os.path.join(raw_path, "cifar-10-batches-bin/batches.meta.txt")
        return self.load_labels(meta_file, *self._train_batch_files(raw_path))

    def test_image_data_preprocess(self, raw_path, save_path):
        """Return the images of the test batch."""
        batch_file = os.path.join(raw_path, "cifar-10-batches-bin/test_batch.bin")
        return self.load_images(batch_file)

    def test_label_data_preprocess(self, raw_path, save_path):
        """Return the class names of the test batch."""
        meta_file = os.path.join(raw_path, "cifar-10-batches-bin/batches.meta.txt")
        batch_file = os.path.join(raw_path, "cifar-10-batches-bin/test_batch.bin")
        return self.load_labels(meta_file, batch_file)

    def image_data_preprocess(self, save_path):
        """Return train images followed by test images."""
        return np.concatenate([self.train_image_data, self.test_image_data])

    def label_data_preprocess(self, save_path):
        """Return train class names followed by test class names."""
        return np.concatenate([self.train_label_data, self.test_label_data])
class ImageNet(Dataset):
    """
    ImageNet dataset for visualization.

    Splits:
        train_image, train_feature_data, train_label, train_hierarchical_label,
        valid_image, valid_feature_data, valid_label, valid_hierarchical_label,
        feature_data, label, hierarchical_label
    """

    def __init__(self):
        super(ImageNet, self).__init__(
            "imagenet",
            urls={
                "train_image": "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar",
                "train_feature_data": [],  # depends on `train_image`
                "train_label": [],  # depends on `train_image`
                "train_hierarchical_label": [],  # depends on `train_image`
                "valid_image": ["http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar",
                                "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz"],
                "valid_feature_data": [],  # depends on `valid_image`
                "valid_label": "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz",
                "valid_hierarchical_label":
                    "http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz",
                "feature_data": [],  # depends on `train_feature_data` & `valid_feature_data`
                "label": [],  # depends on `train_label` & `valid_label`
                "hierarchical_label": [],  # depends on `train_hierarchical_label` & `valid_hierarchical_label`
            }
        )

    def import_wordnet(self):
        """Return the nltk wordnet corpus, downloading it on first use."""
        import nltk
        try:
            nltk.data.find("corpora/wordnet")
        except LookupError:
            nltk.download("wordnet")
        from nltk.corpus import wordnet

        # fall back to the private name when the public one is missing
        try:
            wordnet.synset_from_pos_and_offset
        except AttributeError:
            wordnet.synset_from_pos_and_offset = wordnet._synset_from_pos_and_offset
        return wordnet

    def get_name(self, synset):
        """Return the part of the synset name in front of the first ``.``."""
        name = synset.name()
        return name[:name.find(".")]

    def readable_label(self, labels, save_file, hierarchy=False):
        """
        Convert class ids to human-readable names and write one sample per line.

        Parameters:
            labels (list of str): class id of each sample,
                a part-of-speech character followed by a numeric offset
            save_file (str): save file
            hierarchy (bool, optional): if true, also write the names of the ancestors
                of each class, tab-separated, from the coarsest to the finest level
        """
        wordnet = self.import_wordnet()

        if hierarchy:
            logger.info("generating human-readable hierarchical labels")
        else:
            logger.info("generating human-readable labels")
        synsets = []
        for label in labels:
            pos = label[0]
            offset = int(label[1:])
            synsets.append(wordnet.synset_from_pos_and_offset(pos, offset))
        depth = max([synset.max_depth() for synset in synsets])
        num_sample = len(synsets)
        labels = [self.get_name(synset) for synset in synsets]
        num_class = len(set(labels))
        hierarchies = [labels]
        # climb one level at a time until all samples fall into a single class
        while hierarchy and num_class > 1:
            depth -= 1
            for i in range(num_sample):
                if synsets[i].max_depth() > depth:
                    # a synset may have several hypernyms, only the first one is followed
                    synsets[i] = synsets[i].hypernyms()[0]
            labels = [self.get_name(synset) for synset in synsets]
            hierarchies.append(labels)
            num_class = len(set(labels))
        # coarsest level first
        hierarchies = hierarchies[::-1]

        with open(save_file, "w") as fout:
            for levels in zip(*hierarchies):
                fout.write("%s\n" % "\t".join(levels))

    def cached_feature_data(self, image_path, save_file):
        """
        Compute image features of a folder, or load them from the ``.npy`` cache next to ``save_file``.

        Parameters:
            image_path (str): folder of images, one sub-folder per class
            save_file (str): file whose name (with extension ``.npy``) is used for the cache

        Returns:
            numpy.ndarray: feature vectors
        """
        numpy_file = os.path.splitext(save_file)[0] + ".npy"
        if os.path.exists(numpy_file):
            return np.load(numpy_file)

        import torchvision
        from torchvision import transforms

        augmentation = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        dataset = torchvision.datasets.ImageFolder(image_path, augmentation)
        features = self.image_feature_data(dataset)
        np.save(numpy_file, features)
        return features

    def _train_classes(self):
        """Return the class id (folder name) of every training image, sorted."""
        image_files = glob.glob(os.path.join(self.train_image, "*/*.JPEG"))
        labels = [os.path.basename(os.path.dirname(image_file)) for image_file in image_files]
        # be consistent with the order in torch.utils.data.DataLoader
        return sorted(labels)

    def _valid_classes(self, meta_path):
        """Return the class id of every validation image, in the order of the ground truth file."""
        from scipy.io import loadmat

        meta_file = os.path.join(meta_path, "ILSVRC2012_devkit_t12/data/meta.mat")
        id_file = os.path.join(meta_path, "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
        metas = loadmat(meta_file, squeeze_me=True)["synsets"][:1000]
        id2class = {meta[0]: meta[1] for meta in metas}
        # read the ids as integers so that they match the integer keys of `id2class`
        ids = np.loadtxt(id_file, dtype=np.int32)
        return [id2class[class_id] for class_id in ids]

    def _concatenate(self, input_files, save_file):
        """Write the content of ``input_files`` to ``save_file``, one after another."""
        with open(save_file, "w") as fout:
            for input_file in input_files:
                with open(input_file, "r") as fin:
                    shutil.copyfileobj(fin, fout)

    def train_image_preprocess(self, image_path, save_file):
        """Extract the tar files inside ``image_path`` and delete them. Return ``image_path``."""
        for tar_file in glob.glob(os.path.join(image_path, "*.tar")):
            self.extract(tar_file)
            os.remove(tar_file)
        return image_path

    def train_feature_data_preprocess(self, save_file):
        """Return the feature vectors of the training images."""
        return self.cached_feature_data(self.train_image, save_file)

    def train_label_preprocess(self, save_file):
        """Write the human-readable label of every training image."""
        self.readable_label(self._train_classes(), save_file)

    def train_hierarchical_label_preprocess(self, save_file):
        """Write the human-readable hierarchical label of every training image."""
        self.readable_label(self._train_classes(), save_file, hierarchy=True)

    def valid_image_preprocess(self, image_path, meta_path, save_file):
        """
        Move the validation images into one sub-folder per class. Return ``image_path``.

        Nothing is done if there is no image directly inside ``image_path``.

        Raises:
            ValueError: if the number of images differs from the number of ground truth labels
        """
        image_files = glob.glob(os.path.join(image_path, "*.JPEG"))
        if len(image_files) == 0:
            return image_path
        logger.info("re-arranging images into sub-folders")

        # images are paired with ground truth labels by sorted position
        image_files = sorted(image_files)
        labels = self._valid_classes(meta_path)
        if len(image_files) != len(labels):
            # e.g. a previous run was interrupted; pairing by position would mislabel images
            raise ValueError("Found %d images in `%s`, but %d ground truth labels"
                             % (len(image_files), image_path, len(labels)))
        for image_file, label in zip(image_files, labels):
            class_path = os.path.join(image_path, label)
            if not os.path.exists(class_path):
                os.mkdir(class_path)
            shutil.move(image_file, class_path)
        return image_path

    def valid_feature_data_preprocess(self, save_file):
        """Return the feature vectors of the validation images."""
        return self.cached_feature_data(self.valid_image, save_file)

    def valid_label_preprocess(self, meta_path, save_file):
        """Write the human-readable label of every validation image."""
        # be consistent with the order in torch.utils.data.DataLoader
        labels = sorted(self._valid_classes(meta_path))
        self.readable_label(labels, save_file)

    def valid_hierarchical_label_preprocess(self, meta_path, save_file):
        """Write the human-readable hierarchical label of every validation image."""
        # be consistent with the order in torch.utils.data.DataLoader
        labels = sorted(self._valid_classes(meta_path))
        self.readable_label(labels, save_file, hierarchy=True)

    def feature_data_preprocess(self, save_file):
        """Return train features followed by validation features."""
        return np.concatenate([self.train_feature_data, self.valid_feature_data])

    def label_preprocess(self, save_file):
        """Write train labels followed by validation labels."""
        # Resolve both splits before `save_file` is created: if one of them fails,
        # no empty output is left behind to be mistaken for a finished split.
        input_files = [self.train_label, self.valid_label]
        self._concatenate(input_files, save_file)

    def hierarchical_label_preprocess(self, save_file):
        """Write train hierarchical labels followed by validation hierarchical labels."""
        # Resolve both splits before `save_file` is created: if one of them fails,
        # no empty output is left behind to be mistaken for a finished split.
        input_files = [self.train_hierarchical_label, self.valid_hierarchical_label]
        self._concatenate(input_files, save_file)
# public classes of this module
__all__ = [
    "Dataset",
    "BlogCatalog", "Youtube", "Flickr", "Hyperlink2012", "Friendster", "Wikipedia",
    "FB15k", "FB15k237", "WN18", "WN18RR", "Freebase",
    "MNIST", "CIFAR10", "ImageNet"
]

# module-level instances, one per pre-defined dataset

# graph
blogcatalog = BlogCatalog()
youtube = Youtube()
flickr = Flickr()
hyperlink2012 = Hyperlink2012()
friendster = Friendster()
wikipedia = Wikipedia()

# knowledge graph
fb15k = FB15k()
fb15k237 = FB15k237()
wn18 = WN18()
wn18rr = WN18RR()
freebase = Freebase()

# visualization
mnist = MNIST()
cifar10 = CIFAR10()
imagenet = ImageNet()