

Finding Task-Relevant Features for Few-Shot Learning by Category Traversal

Hongyang Li¹,²,* David Eigen² Samuel Dodge² Matthew Zeiler² Xiaogang Wang¹,³

¹The Chinese University of Hong Kong ²Clarifai Inc. ³SenseTime Research
{yangli,xgwang}@ee.cuhk.edu.hk {deigen,samuel,zeiler}@clarifai.com

Abstract

Few-shot learning is an important area of research. Conceptually, humans are readily able to understand new concepts given just a few examples, while in more pragmatic terms, limited-example training situations are common in practice. Recent effective approaches to few-shot learning employ a metric-learning framework to learn a feature similarity comparison between a query (test) example and the few support (training) examples. However, these approaches treat each support class independently from one another, never looking at the entire task as a whole. Because of this, they are constrained to use a single set of features for all possible test-time tasks, which hinders the ability to distinguish the most relevant dimensions for the task at hand. In this work, we introduce a Category Traversal Module that can be inserted as a plug-and-play module into most metric-learning based few-shot learners. This component traverses across the entire support set at once, identifying task-relevant features based on both intra-class commonality and inter-class uniqueness in the feature space. Incorporating our module improves performance considerably (5%-10% relative) over baseline systems on both miniImageNet and tieredImageNet benchmarks, with overall performance competitive with recent state-of-the-art systems.

1. Introduction

The goal of few-shot learning [38, 35, 30, 36, 33, 7, 25, 26] is to classify unseen data instances (query examples) into a set of new categories, given just a small number of labeled instances in each class (support examples). Typically, there are between 1 and 10 labeled examples per class in the support set; this stands in contrast to the standard classification problem [19, 16, 20], in which there are often thousands per class. Moreover, in the traditional problem the training and test sets share the same classes, whereas in few-shot learning the two sets of classes are disjoint. A key challenge in few-shot learning, therefore, is to make the best use of the limited data available in the support set in order to find the “right” generalizations as suggested by the task.

* This paper is the product of work during an internship at Clarifai Inc.

[Figure 1: panel (a) ISOLATED; panel (b) TRAVERSED, with the Category Traversal mask.]

Figure 1. (a) A high-level illustration of the metric-based algorithms for few-shot learning. Both support and query are first fed into a feature extractor $f_\theta$; in previous methods, the query is compared with support based on the feature similarity individually, without associating the most relevant information across classes. (b) The proposed Category Traversal Module (CTM) looks at all categories in the support set to find task-relevant features.

A recent effective approach to this problem is to train a neural network to produce feature embeddings used to compare the query to each of the support samples. This is similar to metric learning using Siamese Networks [4, 6], trained explicitly for the few-shot classification problem by iteratively sampling a query and support set from a larger labeled training set, and optimizing to maximize a similarity score between the query and same-labeled examples. Optimizing for similarity between query and support has been very effective [35, 38, 31, 14, 21, 2, 13]. Fig. 1 (a) illustrates such a mechanism at a high level.

However, while these approaches are able to learn rich features, the features are generated for each class in the support set independently. In this paper, we extend the effective metric-learning based approaches to incorporate the context of the entire support set, viewed as a whole. By including such a view, our model finds the dimensions most relevant to each task. This is particularly important in few-shot learning: since very little labeled data is available for each task, it is imperative to make the best use of all of the information available from the full set of examples, taken together.

[Figure 2: panel (a) shows a 1-shot support set with classes (i)-(v), a query, and per-class distances; panel (b) shows a k-shot support set; the feature dimensions are Color and Shape.]

Figure 2. Toy example illustrating the motivation for task-relevant features. (a) A task defines five classes (i)..(v) with two feature dimensions, color and shape. The distance between query and support is the same for classes (i, iii, iv, v). However, by taking the context of all classes into account, we see the relevant feature is color, so the class of the query image should be that of (iii): green. (b) In a k-shot (k > 1) case, most instances in one class share the property of color blue whilst their shapes differ, making the color feature more representative.

arXiv:1905.11116v1 [cs.CV] 27 May 2019

As a motivating example, consider Fig. 2 (a), which depicts a 5-way 1-shot task with two simple feature dimensions, color and shape. The goal of classification is to determine the answer to a question:

To which category in the support set does the query belong?

Here, the query is a green circle. We argue that because each example in the support set has a unique color, but shares its shape with other support examples, the relevant feature in this task is color. Therefore, the correct label assignment to the query should be class (iii): green. However, if feature similarity to the query is calculated independently for each class, as is done in [38, 35, 36, 34, 21, 30, 23, 10], it is impossible to know which dimension (color or shape) is more relevant, hence categories (i, iii, iv, v) would all have equal score based on distance.¹ Only by looking at the context of the entire support set can we determine that color is the discriminating dimension for this task. Such an observation motivates us to traverse all support classes to find the inter-class uniqueness along the feature dimension.

¹ Note that if the feature extractor were to learn that color is more important than shape in order to succeed in this particular task, it would fail on the task where color is shared and shape is unique: it is impossible for static feature comparison to succeed in both tasks.

Moreover, under the multi-shot setting shown in Fig. 2 (b), it can be even clearer that the color dimension is most relevant, since most instances have the same color of blue, while their shape varies. Thus, in addition to inter-class uniqueness, relevant dimensions can also be found using the intra-class commonality. Note that in this k > 1 case, feature averaging within each class is an effective way to reduce intra-class variations and expose shared feature components; this averaging is performed in [35, 36]. While both inter- and intra-class comparisons are fundamental to classification, and have long been used in machine learning and statistics [8], metric based few-shot learning methods have not yet incorporated any of the context available by looking between classes.

To incorporate both inter-class as well as intra-class views of the support set, we introduce a category traversal module (CTM). Such a module selects the most relevant feature dimensions after traversing both across and within categories. The output of CTM is bundled onto the feature embeddings of the support and query set, making metric learning in the subsequent feature space more effective. CTM consists of a concentrator unit to extract the embeddings within a category for commonality, and a projector unit to consider the output of the concentrator across categories for uniqueness. The concentrator and projector can be implemented as convolutional layer(s). Fig. 1 (b) gives a description of how CTM is applied to existing metric-based few-shot learning algorithms. It can be viewed as a plug-and-play module to provide more discriminative and representative features by considering the global feature distribution in the support set, making metric learning in high-dimensional space more effective.

We demonstrate the effectiveness of our category traversal module on the few-shot learning benchmarks. CTM is on par with or exceeds the previous state of the art. Incorporating CTM into existing algorithms [36, 38, 35], we witness consistent relative gains of around 5%-10% on both miniImageNet and tieredImageNet. The code suite is at: https://github.com/Clarifai/few-shot-ctm.

2. Related Work

Recent years have witnessed a vast amount of work on the few-shot learning task. These works can be roughly categorized into three branches: (i) metric based, (ii) optimization based, and (iii) large corpus based.

The first branch consists of metric based approaches [35, 38, 36, 14, 34, 21, 30, 23, 2, 10]. Vinyals et al. [38] introduced the concept of episode training in few-shot learning, where the training procedure mimics the test scenario based on support-query metric learning. The idea is intuitive and simple: these methods compare feature similarity after embedding both support and query samples into a shared feature space. The prototypical network [35] is built upon [38] by comparing the query with class prototypes in the support set, making use of class centroids to eliminate the outliers in the support set and find dimensions common to all class samples. Such a practice is similar in spirit to our concentrator module, which we design to focus on intra-class commonality. Our work goes beyond this by also looking at all classes in the support set together to find dimensions relevant for each task. In [14], a kernel generator is introduced to modify feature embeddings, conditioned on the query image. This is a complementary approach to ours: while [14] looks to the query to determine what may be relevant for its classification, we look at the whole of the support set to enable our network to better determine which features most pertain to the task. In [34], the feature embedding and classifier weight creation networks are broken up, to enable zero-shot and few-shot tasks to both be performed within the same framework.

There are also interesting works that explore the relationship between support and query to enable more complex comparisons between support and query features. The relation network [36] proposes evaluating the relationship of each query-support pair using a neural network with concatenated feature embeddings. It can be viewed as a further extension to [35, 38] with a learned metric defined by a neural network. Liu et al. [23] propose a transductive propagation network to propagate labels from known labeled instances to unlabeled test instances, by learning a graph construction module that exploits the manifold structure in the data. Garcia et al. [10] introduced the concept of a graph neural network to explicitly learn feature embeddings by assimilating message-passing inference algorithms into neural-network counterparts. Oreshkin et al. [26] also learn a task-dependent metric, but condition it on the mean of class prototypes, which can reduce inter-class variations available to their task conditioning network, and require an auxiliary task co-training loss not needed by our method to realize performance gains. Gao et al. [9] applied masks to features in a prototypical network applied to an NLP few-shot sentence classification task, but base their masks only on examples within each class, not between classes as our method does.

All the approaches mentioned above base their algorithms on a metric learning framework that compares the query to each support class, taken separately. However, none of them incorporate information available across categories for the task, beyond the final comparison of individually-created distance scores. This can lead to problems mentioned in Section 1, where feature dimensions irrelevant to the current task can end up dominating the similarity comparison. In this work, we extend metric-based approaches by introducing a category traversal module to find relevant feature dimensions for the task by looking at all categories simultaneously.

The second branch of literature consists of optimization based solutions [28, 22, 7, 25, 33]. For each task (episode), the learner samples from a distribution and performs SGD or unrolled weight updates for a few iterations to adapt a parameterized model for the particular task at hand. In [28], a learner model is adapted to a new episodic task by a recurrent meta-learner producing efficient parameter updates. MAML [7] and its variants [25, 33] have demonstrated impressive results; in these works, the parameters of a learner model are optimized so that they can be quickly adapted to a particular task.

At a high level, these approaches incorporate the idea of traversing all support classes, by performing a few weight update iterations for the few-shot task. However, as pointed out by [33, 21], while these approaches iterate over samples from all classes in their task updates, they often have trouble learning effective embeddings. [33] address this by applying the weight update “inner-loop” only to top layer weights, which are initialized by sampling from a generative distribution conditioned on the task samples, and pre-training visual features using an initial supervised phase. By contrast, metric learning based methods achieve considerable success in learning good features, but have not made use of inter-class views to determine the most relevant dimensions for each task. We incorporate an all-class view into a metric learning framework, and obtain competitive performance. Our proposed method learns both the feature embeddings and classification dimensions, and is trained in an entirely from-scratch manner.

The third branch consists of large-training-corpus based methods [11, 15, 12, 27, 29]. In these, a base network is trained with a large amount of data, but also must be able to adapt to a few-shot learning task without forgetting the original base model concepts. These methods provide stronger feature representations for base model classes that are still “compatible” with novel few-class concept representations, so that novel classes with few examples can be readily mixed with classes from the base classifier.

3. Algorithm

3.1. Description on Few-Shot Learning

In a few-shot classification task, we are given a small support set of N distinct, previously unseen classes, with K examples each.² Given a query sample, the goal is to classify it into one of the N support categories.

² Typically, N is between 5 and 20, and K between 1 and 20.

Figure 3. Detailed breakdown of components in CTM. It extracts features common to elements in each class via a concentrator o, and allows the metric learner to concentrate on more discriminative dimensions via a projector p, constructed by traversing all categories in the support set.

Training. The model is trained using a large training corpus $C_{\text{train}}$ of labeled examples (of categories different from any that we will see in the eventual few-shot tasks during evaluation). The model is trained using episodes. In each episode, we construct a support set S and query set Q:

$$S = \{s^{(1)}, \cdots, s^{(c)}, \cdots, s^{(N)}\} \subset C_{\text{train}}, \quad |s^{(c)}| = K,$$

$$Q = \{q^{(1)}, \cdots, q^{(c)}, \cdots, q^{(N)}\} \subset C_{\text{train}}, \quad |q^{(c)}| = K,$$

where c is the class index and K is the number of samples in class $s^{(c)}$; the support set has a total of NK samples and corresponds to an N-way K-shot problem. Let $s_i$ be a single sample, where i is the index among all samples in S. We define the label of sample i to be:

$$l^*_i \triangleq l(s_i) = c, \quad s_i \in s^{(c)}.$$

Similar notation applies for the query set Q.

As illustrated in Fig. 1, the samples $s_i$, $q_j$ are first fed into a feature extractor $f_\theta(\cdot)$. We use a CNN or ResNet [16] as the backbone for $f_\theta$. These features are used as input to a comparison module $M(\cdot, \cdot)$. In practice, M could be a direct pair-wise feature distance [35, 38] or a further relation unit [36, 10] consisting of additional CNN layers to measure the relationship between two samples. Denote the output score from M as $Y = \{y_{ij}\}$. The loss L for this training episode is defined to be a cross-entropy classification loss averaged across all query-support pairs:

$$y_{ij} = M\big(f_\theta(s_i), f_\theta(q_j)\big), \tag{1}$$

$$L = -\frac{1}{(NK)^2} \sum_i \sum_j \mathbf{1}[l^*_i = l^*_j] \log y_{ij}. \tag{2}$$

Training proceeds by iteratively sampling episodes and performing an SGD update using the loss for each episode.
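The episodic procedure and the loss of Eq. (2) can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the authors' implementation: the names `sample_episode`, `episode_loss`, and `corpus` are hypothetical, the support and query samples are drawn disjointly (the text does not say whether they must be), and the scores $y_{ij}$ are assumed to lie in (0, 1] so that the logarithm is defined.

```python
import math
import random

def sample_episode(corpus, n_way, k_shot, rng):
    """Draw one N-way K-shot episode (hypothetical helper).

    corpus maps a class label to a list of samples. Returns (support, query),
    each a list of (sample, label) pairs with k_shot samples per class.
    """
    classes = rng.sample(sorted(corpus), n_way)
    support, query = [], []
    for c in classes:
        # Assumption: support and query samples of a class do not overlap.
        picked = rng.sample(corpus[c], 2 * k_shot)
        support += [(x, c) for x in picked[:k_shot]]
        query += [(x, c) for x in picked[k_shot:]]
    return support, query

def episode_loss(scores, support_labels, query_labels):
    """Eq. (2): scores[i][j] is y_ij for support sample i and query sample j."""
    total = 0.0
    for i, label_i in enumerate(support_labels):
        for j, label_j in enumerate(query_labels):
            if label_i == label_j:  # indicator 1[l*_i = l*_j]
                total += math.log(scores[i][j])
    # Equals (NK)^2 when support and query both hold NK samples.
    return -total / (len(support_labels) * len(query_labels))
```

Only same-label pairs contribute to the sum, so the loss is driven down by pushing $y_{ij}$ toward 1 for matching query-support pairs.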

Inference. Generalization performance is measured on test set episodes, where S, Q are now sampled from a corpus $C_{\text{test}}$ containing classes distinct from those used in $C_{\text{train}}$. Labels in the support set are known, whereas those in the query set are unknown and used only for evaluation. The label prediction for the query is found by taking the class with the highest comparison score:

$$l_j = \arg\max_c \, y_{cj}, \tag{3}$$

where $y_{cj} = \frac{1}{K} \sum_i y_{ij}$, with the sum taken over the support samples i for which $l^*_i = c$. The mean accuracy is therefore obtained by comparing $l_j$ with the query labels over a number of test episodes (usually 600).
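The prediction rule of Eq. (3) can be written out as a short sketch. The function names `predict_labels` and `accuracy` are illustrative choices rather than anything defined in the paper, and the score matrix is assumed to be a nested list indexed as `scores[i][j]` for support sample i and query sample j.

```python
def predict_labels(scores, support_labels):
    """Eq. (3): pick, for each query j, the class with the highest mean score."""
    classes = sorted(set(support_labels))
    n_query = len(scores[0])
    predictions = []
    for j in range(n_query):
        best_class, best_score = None, float("-inf")
        for c in classes:
            # Support samples i whose label l*_i equals c.
            members = [i for i, label in enumerate(support_labels) if label == c]
            y_cj = sum(scores[i][j] for i in members) / len(members)
            if y_cj > best_score:
                best_class, best_score = c, y_cj
        predictions.append(best_class)
    return predictions

def accuracy(predictions, query_labels):
    """Fraction of queries whose predicted label matches the held-out label."""
    hits = sum(p == t for p, t in zip(predictions, query_labels))
    return hits / len(query_labels)
```

In the 1-shot case each class has a single member, so the class score reduces to the single pairwise score.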

3.2. Category Traversal Module (CTM)

Fig. 3 shows the overall design of our model. The category traversal module takes support set features $f_\theta(S)$ as input, and produces a mask p via a concentrator and projector that make use of intra- and inter-class views, respectively. The mask p is applied to reduced-dimension features of both the support and query, producing improved features I with dimensions relevant for the current task. These improved feature embeddings are finally fed into a metric learner.

3.2.1 Concentrator: Intra-class Commonality

The first component in CTM is a concentrator to find universal features shared by all instances of one class. Denote the output shape from the feature extractor $f_\theta$ as $(NK, m_1, d_1, d_1)$, where $m_1$, $d_1$ indicate the number of channels and the spatial size, respectively. We define the concentrator as follows:

$$f_\theta(S) : (NK, m_1, d_1, d_1) \xrightarrow{\text{Concentrator}} o : (N, m_2, d_2, d_2), \tag{4}$$

where $m_2$, $d_2$ denote the output number of channels and spatial size. Note that the input is first fed to a CNN module to perform dimension reduction; then samples within each class are averaged to produce the final output o. In the 1-shot setting, there is no averaging operation, as there is only one example for each class.

In practice the CNN module could be either a simple CNN layer or a ResNet block [16]. The purpose is to remove the differences among instances and extract the commonality among instances within one category. This is achieved by way of an appropriate down-sampling from $m_1, d_1$ to $m_2, d_2$. Such a learned component proves to be better than the averaging alternative [35], where the latter can be regarded as a special case of our concentrator when $m_1 = m_2$, $d_1 = d_2$ without the learned parameters.
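The shape bookkeeping of the concentrator in Eq. (4) can be illustrated with NumPy. This is a sketch under loud assumptions: the learned CNN is replaced by a stand-in (2x2 average pooling followed by a 1x1 convolution expressed as a matrix `weight`), samples are assumed to be ordered class by class, and the function name `concentrator` is hypothetical. It reproduces the shapes and the per-class averaging only, not the learned behavior.

```python
import numpy as np

def concentrator(features, n_way, k_shot, weight):
    """Map (N*K, m1, d1, d1) features to class-level output o: (N, m2, d2, d2).

    weight: (m2, m1) channel-mixing matrix, a stand-in for the learned CNN.
    The stand-in halves the spatial size, so d2 = d1 // 2.
    """
    nk, m1, d1, _ = features.shape
    assert nk == n_way * k_shot
    d2 = d1 // 2
    # Stand-in spatial reduction: 2x2 average pooling with stride 2.
    cropped = features[:, :, : 2 * d2, : 2 * d2]
    pooled = cropped.reshape(nk, m1, d2, 2, d2, 2).mean(axis=(3, 5))
    # Stand-in channel reduction: a 1x1 convolution written as an einsum.
    reduced = np.einsum("oc,nchw->nohw", weight, pooled)
    # Average the K samples of each class (a no-op when K = 1).
    return reduced.reshape(n_way, k_shot, *reduced.shape[1:]).mean(axis=1)
```

With `features` of shape (10, 64, 21, 21), `n_way=5`, `k_shot=2`, and a (32, 64) weight, the output has shape (5, 32, 10, 10), matching the $m_2 = 32$, $d_2 = 10$ configuration used in the shallow-network experiments.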

3.2.2 Projector: Inter-class Uniqueness

Table 1. Design choice of I(S) in the category traversal module (CTM) and comparison with baselines. We see a substantial improvement using CTM over the same-capacity baseline (ii, iii). The sample-wise choice (ii) performs better, with marginal extra computation cost compared with (v).

Model                                5-way 1-shot   5-way 5-shot   Model size (Mb)   Training time (sec./episode)   20-way 1-shot   20-way 5-shot
(i) sample-wise style baseline       37.20%         53.35%         0.47              0.0545                         17.96%          28.47%
(ii) sample-wise, I1                 41.62%         58.77%         0.55              0.0688                         21.75%          32.26%
(iii) baseline same size             37.82%         53.46%         0.54              0.0561                         18.11%          28.54%
(iv) cluster-wise style baseline     34.81%         50.93%         0.47              0.0531                         16.95%          27.37%
(v) cluster-wise, I2                 39.55%         56.95%         0.55              0.0632                         19.96%          30.17%

The second component is a projector to mask out irrelevant features and select the ones most discriminative for the current few-shot task by looking at concentrator features from all support categories simultaneously:

$$\hat{o} : (1, N m_2, d_2, d_2) \xrightarrow{\text{Projector}} p : (1, m_3, d_3, d_3), \tag{5}$$

where $\hat{o}$ is just a reshaped version of o; $m_3$, $d_3$ follow similar meaning as in the concentrator. We achieve the goal of traversing across classes by concatenating the class prototypes in the first dimension (N) to the channel dimension ($m_2$), applying a small CNN to the concatenated features to produce a map of size $(1, m_3, d_3, d_3)$, and finally applying a softmax over the channel dimension $m_3$ (applied separately for each of the $d_3 \times d_3$ spatial locations) to produce a mask p. This is used to mask the relevant feature dimensions for the task in the query and support set.
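The projector steps (reshape classes into channels, apply a small network, softmax over channels at every spatial location) can be sketched with NumPy. As an assumption for illustration, the small CNN is replaced by a single 1x1 convolution given as a matrix `weight`, which keeps the spatial size unchanged ($d_3 = d_2$); the function name `projector` is hypothetical.

```python
import numpy as np

def projector(o, weight):
    """Map concentrator output o: (N, m2, d2, d2) to a mask p: (1, m3, d2, d2).

    weight: (m3, N*m2) matrix, a stand-in for the small learned CNN.
    """
    n_way, m2, d2, _ = o.shape
    # Concatenate the N class prototypes along the channel axis.
    o_hat = o.reshape(1, n_way * m2, d2, d2)
    logits = np.einsum("oc,nchw->nohw", weight, o_hat)
    # Softmax over the channel axis, separately at each spatial location.
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)
```

Because the softmax runs over channels only, the mask values at every spatial location sum to 1, so the mask redistributes emphasis among feature channels rather than among image positions.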

3.2.3 Reshaper

In order to make the projector output p influence the feature embeddings $f_\theta(\cdot)$, we need to match the shape between these modules in the network. This is achieved using a reshaper network, applied separately to each of the NK samples:

$$f_\theta(\cdot) \xrightarrow{\text{Reshaper}} r(\cdot) : (NK, m_3, d_3, d_3).$$

It is designed in a light-weight manner with one CNN layer.

3.2.4 Design choice in CTM

Armed with the components stated above, we can generate a mask output by traversing all categories: $f_\theta(S) \to p$. The effect of CTM is achieved by bundling the projector output onto the feature embeddings of both support and query, denoted as $I(\cdot)$. The improved feature representations are thus expected to be more discriminative.

For the query, the choice of I is simple since we do not have labels for the query; the combination is an element-wise multiplication of embeddings and the projector output: $I(Q) = r(Q) \odot p$, where $\odot$ stands for broadcasting the value of p along the sample dimension (NK) in Q.

For the support, however, since we know the support labels, we can choose to mask p directly onto the embeddings (sample-wise), or if we keep $(m_2, d_2, d_2) = (m_3, d_3, d_3)$, we can use it to mask the concentrator output o (cluster-wise). Mathematically, these two options are:

option 1: $I_1(S) = r(S) \odot p : (NK, m_3, d_3, d_3)$,

option 2: $I_2(S) = o \odot p : (N, m_3, d_3, d_3)$.

We found that option 1 results in better performance, for a marginal increase in execution time due to its larger number of comparisons; details are provided in Sec. 4.2.1.
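The difference between the two options is mostly a matter of shapes, which a small NumPy sketch makes concrete. The function names are hypothetical, and the arrays stand for the reshaper output r(S), the concentrator output o, and the mask p described above.

```python
import numpy as np

def mask_samplewise(r, p):
    """Option 1 (also used for the query): r is (NK, m3, d3, d3)."""
    # p has shape (1, m3, d3, d3); NumPy broadcasts it along the sample axis.
    return r * p

def mask_clusterwise(o, p):
    """Option 2: o is (N, m3, d3, d3); requires (m2, d2) == (m3, d3)."""
    return o * p
```

Option 1 keeps one masked embedding per support sample (NK of them), while option 2 keeps one per class (N of them), which is why the sample-wise variant involves more comparisons downstream.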

3.3. CTM in Action

The proposed category traversal module is a simple plug-and-play module and can be embedded into any metric-based few-shot learning approach. In this paper, we consider three metric-based methods and apply CTM to them, namely the matching network [38], the prototypical network [35] and the relation network [36]. As discussed in Sec. 1, all three methods are limited by not considering the entire support set simultaneously. Since features are created independently for each class, embeddings irrelevant to the current task can end up dominating the metric comparison. These existing methods define their similarity metric following Eqn. (1); we modify them to use our CTM as follows:

$$Y = M\big(r(S) \odot p, \; r(Q) \odot p\big), \quad Y = \{y_{ij}\}. \tag{6}$$

As we show later (see Sec. 4.3.1), after integrating the proposed CTM unit, these methods improve by a large margin (2%-4%) under different settings.
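To see how the mask in Eq. (6) can change the outcome of a comparison, consider a toy sketch in the spirit of the color/shape example from the introduction. The metric M is taken here to be negative squared Euclidean distance, which is an assumption made for illustration (the paper plugs CTM into several different metrics), and `ctm_scores` is a hypothetical name.

```python
import numpy as np

def ctm_scores(r_support, r_query, p):
    """Eq. (6) with M chosen as negative squared Euclidean distance.

    r_support: (NK, m3, d3, d3); r_query: (Nq, m3, d3, d3); p: (1, m3, d3, d3).
    Returns Y of shape (NK, Nq); Y[i, j] scores support i against query j.
    """
    s = (r_support * p).reshape(len(r_support), -1)
    q = (r_query * p).reshape(len(r_query), -1)
    diff = s[:, None, :] - q[None, :, :]
    return -(diff ** 2).sum(axis=-1)

# Two channels standing for [color, shape], with a 1x1 spatial size.
support = np.array([[0.0, 1.0],    # differs from the query in color
                    [1.0, 0.0]])   # differs from the query in shape
support = support.reshape(2, 2, 1, 1)
query = np.array([[1.0, 1.0]]).reshape(1, 2, 1, 1)

uniform_mask = np.full((1, 2, 1, 1), 0.5)
color_mask = np.array([0.9, 0.1]).reshape(1, 2, 1, 1)

tied = ctm_scores(support, query, uniform_mask)
resolved = ctm_scores(support, query, color_mask)
```

With the uniform mask the two support samples score identically against the query, whereas the color-weighted mask favors the sample that shares the query's color.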

4. Evaluation

The experiments are designed to answer the following key questions: (1) Is CTM competitive with other state-of-the-art methods on large-scale few-shot learning benchmarks? (2) Can CTM be utilized as a simple plug-and-play module and bring gains to existing methods? What are the essential components and factors that make CTM work? (3) How does CTM modify the feature space to make features more discriminative and representative?

4.1. Datasets and Setup

Datasets. The miniImageNet dataset [38] is a subset of 100 classes selected from the ILSVRC-12 dataset [32] with 600 images in each class. It is divided into training, validation, and test meta-sets, with 64, 16, and 20 classes respectively. The tieredImageNet dataset [30] is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes based on the WordNet hierarchy [5]. This set of nodes is partitioned into 20, 6, and 8 disjoint sets of training, validation, and testing nodes, and the corresponding classes form the respective meta-sets. As argued in [30], the split in tieredImageNet is more challenging, with a realistic regime of test classes that are less similar to training ones. Note that the validation set is only used for tuning model parameters.

Evaluation metric. We report the mean accuracy (%) of 600 randomly generated episodes as well as the 95% confidence intervals on the test set. In every episode during testing, each class has 15 queries, following most methods [35, 36, 33].
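A mean accuracy with a 95% confidence interval over test episodes can be computed as in the following sketch. The paper does not state which interval formula is used; the normal approximation below (half-width of 1.96 standard errors) is an assumption, and `mean_and_ci95` is a hypothetical name.

```python
import math

def mean_and_ci95(episode_accuracies):
    """Return (mean, half_width) of a normal-approximation 95% interval."""
    n = len(episode_accuracies)
    mean = sum(episode_accuracies) / n
    # Unbiased sample variance across episodes.
    variance = sum((a - mean) ** 2 for a in episode_accuracies) / (n - 1)
    half_width = 1.96 * math.sqrt(variance / n)
    return mean, half_width
```

A result would then be reported as the mean plus or minus the half-width, computed over the 600 per-episode accuracies.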

Implementation details. For training, the 5-way problem has 15 query images while the 20-way problem has 8 query images. The smaller number of query samples in the 20-way setting is mainly due to GPU memory considerations. The input image is resized to 84 × 84.

We use the Adam [18] optimizer with an initial learning rate of 0.001. The total training episodes on miniImageNet and tieredImageNet are 600,000 and 1,000,000 respectively. The learning rate is dropped by 10% every 200,000 episodes or when the loss enters a plateau. The weight decay is set to 0.0005. Gradient clipping is also applied.

4.2. Ablation Study

4.2.1 Shallow Network Verification

We first validate the effectiveness of category traversal by comparing against same-capacity baselines using a simple backbone network. Specifically, a 4-layer neural network is adopted as the backbone; we directly compute feature similarity between I(S) and I(Q). The mean accuracy on miniImageNet is reported. After feature embedding, $m_1 = 64$, $d_1 = 21$; the concentrator is a CNN layer with stride of 2, i.e., $m_2 = 32$, $d_2 = 10$. To compare between choices ($I_1$ or $I_2$), the projector leaves dimensions unchanged, i.e., $m_3 = m_2$, $d_3 = d_2$.

Baseline comparison. Results are reported in Tab. 1. Model size and training time are measured under the 5-way 5-shot setting. The “baseline” in rows (i) and (iv) evaluates a model with the reshaper network and metric comparisons only, omitting the CTM concentrator and projector. Row (ii) shows a model that includes our CTM. Since adding CTM increases the model capacity compared to the baseline (i), we also include a same-size model baseline for comparison, shown as “baseline same size” (iii), by adding additional layers to the backbone such that its model size is similar to (ii). Note that the only difference between (i) and (iv) is that the latter takes the average of samples within each category.

We can see on average there is a 10% relative improvement using CTM in both 5-way and 20-way settings, compared to the baselines. Notably, the larger-capacity baseline improves only marginally over the original baseline, while the improvements using CTM are substantial. This shows that the performance increase obtained by CTM is indeed due to its ability to find relevant features for each task.
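The relative improvements can be checked directly from the Table 1 figures. The sketch below copies the accuracies of rows (i) and (ii) and computes the relative gain for each setting; the helper name `relative_gain` and the dictionary layout are illustrative choices.

```python
def relative_gain(baseline, improved):
    """Relative improvement in percent: 100 * (improved - baseline) / baseline."""
    return 100.0 * (improved - baseline) / baseline

# Accuracies (%) from Table 1: (sample-wise baseline (i), sample-wise CTM (ii)).
table1 = {
    "5-way 1-shot": (37.20, 41.62),
    "5-way 5-shot": (53.35, 58.77),
    "20-way 1-shot": (17.96, 21.75),
    "20-way 5-shot": (28.47, 32.26),
}

gains = {setting: relative_gain(b, c) for setting, (b, c) in table1.items()}
for setting, gain in gains.items():
    print(f"{setting}: {gain:.2f}% relative gain")
```

For these rows every setting shows a relative gain of at least roughly 10%, with the 20-way 1-shot setting showing the largest relative change because its baseline accuracy is lowest.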

Table 2. Ablation study on category traversal module.

Factor                                       miniImageNet accuracy
                                             1-shot    5-shot
CTM with shallow (4 layer) backbone          41.62     58.77

CTM with ResNet-18 backbone                  59.34     77.95
(i) w/o concentrator network o               55.41     73.29
(ii) w/o projector p                         57.18     74.25
(iii) softmax all in p                       57.77     75.03

relation net baseline without CTM            58.21     74.29
relation net M, CTM, MSE loss                61.37     78.54
relation net M, CTM, cross entropy loss      62.05     78.63

Which option for I(S) is better? Table 1 (ii, v) shows the comparison between I1 and I2. In general, the sample-wise choice I1 is 2% better than I2. Note that the model size of the two is exactly the same; the only difference is how p is multiplied. However, a minor drawback of I1 is its slightly slower running time (0.0688 vs 0.0632), since it needs to broadcast p across all samples. Despite the lower efficiency, we nonetheless choose the first option, I(S) = I1, as our preference.

4.2.2 CTM with Deeper Network

Table 2 reports the ablation analysis on different components of CTM. Using a deeper backbone for the feature extractor increases performance by a large margin. Experiments in the second block investigate the effect of the concentrator and projector, respectively. Removing each component alone results in a performance decrease (cases i, ii, iii)³. The accuracy is inferior (-3.93%, 1-shot case) if we remove the network part of the concentrator, implying that its dimension reduction and spatial downsampling are important to the final comparisons. Removing the projector p also results in a significant drop (-2.16%, 1-shot), confirming that this step is necessary to find task-specific discriminative dimensions. An interesting result is that if we perform the softmax operation across all the locations (m3, d3, d3) in p, the accuracy (57.77%) is inferior to performing softmax along the channel dimension (m3) for each location separately (59.34%); this is consistent with the data, where absolute position in the image is only modestly relevant to any class difference.
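The two softmax variants compared in case (iii) differ only in the set of entries that is normalized together. The following minimal sketch illustrates this on a tiny mask stored as nested lists of shape (m3, d3, d3); the function names and the toy values are illustrative assumptions, not the implementation used in the experiments.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_per_location(p):
    """Normalize over the channel axis separately at each spatial location."""
    m, d = len(p), len(p[0])
    out = [[[0.0] * d for _ in range(d)] for _ in range(m)]
    for i in range(d):
        for j in range(d):
            column = softmax([p[c][i][j] for c in range(m)])
            for c in range(m):
                out[c][i][j] = column[c]
    return out


def softmax_all(p):
    """Normalize jointly over all m * d * d entries of the mask."""
    m, d = len(p), len(p[0])
    flat = softmax([p[c][i][j] for c in range(m) for i in range(d) for j in range(d)])
    it = iter(flat)
    return [[[next(it) for _ in range(d)] for _ in range(d)] for _ in range(m)]


# Toy mask with m3 = 3 channels and d3 = 2 (hypothetical values).
p = [
    [[0.0, 1.0], [2.0, 3.0]],
    [[1.0, 0.0], [3.0, 2.0]],
    [[0.5, 0.5], [0.5, 0.5]],
]
per_location = softmax_per_location(p)
joint = softmax_all(p)
```

With per-location normalization, the channel weights at every spatial position sum to one, so no position is suppressed relative to another. With joint normalization, the whole mask sums to one and positions compete with each other, which is one plausible reading of why case (iii) performs worse.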

³ Implementation details: in case (i), without the concentrator, support samples are still averaged to generate an output of (N, m, d, d) for the projector; in case (ii), without the projector, the improved feature representations for support and query are o(S) and r(Q), respectively.

Table 3. Improvement after incorporating CTM into existing methods on miniImageNet (left four columns) and tieredImageNet (right four columns). Rows marked "gain" give the difference between the CTM row and our implementation of the same method.

| Method | mini 5-way 1-shot | mini 5-way 5-shot | mini 20-way 1-shot | mini 20-way 5-shot | tiered 5-way 1-shot | tiered 5-way 5-shot | tiered 20-way 1-shot | tiered 20-way 5-shot |
|---|---|---|---|---|---|---|---|---|
| Matching Net [38], paper | 43.56 | 55.31 | - | - | - | - | - | - |
| Matching Net [38], our implementation | 48.89 | 66.35 | 23.18 | 36.73 | 54.02 | 70.11 | 23.46 | 41.65 |
| Matching Net [38], CTM | 52.43 | 70.09 | 25.84 | 40.98 | 57.01 | 73.45 | 25.69 | 45.07 |
| gain | +3.54 | +3.74 | +2.66 | +4.25 | +2.99 | +3.34 | +2.23 | +3.42 |
| Prototypical Net [35], paper | 49.42 | 68.20 | - | - | 53.31 | 72.69 | - | - |
| Prototypical Net [35], our implementation | 56.11 | 74.16 | 28.53 | 42.36 | 60.27 | 75.80 | 28.56 | 49.34 |
| Prototypical Net [35], CTM | 59.34 | 77.95 | 32.08 | 47.11 | 63.77 | 79.24 | 31.02 | 51.44 |
| gain | +3.23 | +3.79 | +3.55 | +4.75 | +3.50 | +3.44 | +2.46 | +2.10 |
| Relation Net [36], paper | 50.44 | 65.32 | - | - | 54.48 | 71.32 | - | - |
| Relation Net [36], our implementation | 58.21 | 74.29 | 31.35 | 45.19 | 61.11 | 77.39 | 26.77 | 47.82 |
| Relation Net [36], CTM | 62.05 | 78.63 | 35.11 | 48.72 | 64.78 | 81.05 | 31.53 | 52.18 |
| gain | +3.84 | +4.34 | +3.76 | +3.53 | +3.67 | +3.66 | +4.76 | +4.36 |

Moreover, we incorporate the relation module [36] as the metric learner for the last module M. It consists of two CNN blocks with two subsequent fc layers generating the relationship score for one query-support pair. The baseline relation net model without CTM has an accuracy of 58.21%. After including our proposed module, the performance increases by 3.84%, to 62.05%. Note that the original paper [36] uses mean squared error (MSE); we find cross-entropy is slightly better (0.68% and 0.09% for 1-shot and 5-shot, respectively), as defined in Eqn. (2).
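The two supervision losses compared for the relation module can be sketched for a single query as follows. Eqn. (2) is not reproduced in this section, so the cross-entropy form below (softmax over the N relation scores) is an assumption of the standard definition, and the MSE form assumes one-hot regression targets (1 for the matching class, 0 otherwise); both function names are hypothetical.

```python
import math


def mse_loss(scores, target):
    """Mean squared error between relation scores and a one-hot target."""
    errors = [(s - (1.0 if i == target else 0.0)) ** 2 for i, s in enumerate(scores)]
    return sum(errors) / len(scores)


def cross_entropy_loss(scores, target):
    """Softmax cross-entropy of the relation scores against the true class."""
    peak = max(scores)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return log_norm - scores[target]


# One query compared against N = 5 support classes (toy scores).
scores = [0.9, 0.2, 0.1, 0.3, 0.1]
print(mse_loss(scores, target=0), cross_entropy_loss(scores, target=0))
```

MSE treats each query-class score as an independent regression problem, whereas cross-entropy couples the N scores through the softmax, so only their relative ordering and margins matter.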

4.3. Comparison with State-of-the-Art

4.3.1 Adapting CTM into Existing Frameworks

To verify the effectiveness of our proposed category traversal module, we embed it into three metric-based algorithms that are closely related to ours. It is worth noting that the comparison should be conducted in a fair setting; however, different sources report different results⁴. Here we describe the implementation we use.

Matching Net [38] and Prototypical Net [35]. In these cases, the metric module M is the pair-wise feature distance. Note that a main source of improvement between [38] and [35] is that the query is compared to the average feature for each class; this has the effect of including intra-class commonality, which we make use of in our concentrator module. As for the improvement from the original paper to our baseline, we use the ResNet-18 model with a Euclidean distance for the similarity comparison, instead of the shallow CNN with cosine distance used originally.
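The comparison against a per-class average feature with a Euclidean distance can be illustrated with a minimal sketch. It operates on already-extracted feature vectors (plain lists of floats); the backbone is omitted, and the function names and toy features are illustrative assumptions rather than the code used for the reported baselines.

```python
def class_prototype(features):
    """Element-wise mean of the support feature vectors of one class."""
    n = len(features)
    return [sum(column) / n for column in zip(*features)]


def squared_euclidean(a, b):
    """Squared Euclidean distance between two feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def classify(query, support):
    """Return the label whose prototype is nearest to the query feature."""
    prototypes = {label: class_prototype(feats) for label, feats in support.items()}
    return min(prototypes, key=lambda label: squared_euclidean(query, prototypes[label]))


# Toy 2-way 2-shot task with 2-dimensional features (hypothetical values).
support = {
    "cat": [[0.0, 0.0], [0.0, 2.0]],
    "dog": [[4.0, 4.0], [6.0, 4.0]],
}
prediction = classify([1.0, 1.0], support)
```

Averaging within a class is what injects intra-class commonality into the comparison; each class is still summarized independently of the others.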

Relation Net [36]. As for the improvement from the original paper to our baseline, the backbone structure is switched from 4-conv to the ResNet-18 model; the relation unit M adopts ResNet blocks instead of CNN layers; the supervision loss is changed to cross entropy.

⁴ For example, the relation network [36] reports a 65.32% accuracy for the 5-way 5-shot setting on miniImageNet. [39] gives 61.1%; [2] has 66.6%; [21] obtains 71.07% with a larger network.

Table 4. Test accuracies for 5-way tasks, both 1-shot and 5-shot. We provide two versions of our model. See Sec. 4.3.2 for details. A "-" marks an entry that is not reported; values without "±" have no reported confidence interval.

| Model | miniImageNet 1-shot | miniImageNet 5-shot |
|---|---|---|
| Meta-learner LSTM [28] | 43.44 ± 0.77 | 60.60 ± 0.71 |
| MAML [7] | 48.70 ± 1.84 | 63.11 ± 0.92 |
| REPTILE [25] | 49.97 ± 0.32 | 65.99 ± 0.58 |
| Meta-SGD [22] | 54.24 ± 0.03 | 70.86 ± 0.04 |
| SNAIL [24] | 55.71 ± 0.99 | 68.88 ± 0.92 |
| CAML [17] | 59.23 ± 0.99 | 72.35 ± 0.18 |
| LEO [33] | 61.76 ± 0.08 | 77.59 ± 0.12 |
| Incremental [29] | 55.72 ± 0.41 | 70.50 ± 0.36 |
| Dynamic [12] | 56.20 ± 0.86 | 73.00 ± 0.64 |
| Predict Params [27] | 59.60 ± 0.41 | 73.74 ± 0.19 |
| Matching Net [38] | 43.56 ± 0.84 | 55.31 ± 0.73 |
| BANDE [1] | 48.90 ± 0.70 | 68.30 ± 0.60 |
| Prototypical Net [35] | 49.42 ± 0.78 | 68.20 ± 0.66 |
| Relation Net [36] | 50.44 ± 0.82 | 65.32 ± 0.70 |
| Projective Subspace [34] | - | 68.12 ± 0.67 |
| Individual Feature [13] | 56.89 | 70.51 |
| IDeMe-Net [3] | 57.71 | 74.34 |
| TADAM [26] | 58.50 ± 0.30 | 76.70 ± 0.30 |
| CTM (ours) | 62.05 ± 0.55 | 78.63 ± 0.06 |
| CTM (ours), data augment | 64.12 ± 0.82 | 80.51 ± 0.13 |

| Model | tieredImageNet 1-shot | tieredImageNet 5-shot |
|---|---|---|
| MAML [7] | 51.67 ± 1.81 | 70.30 ± 0.08 |
| Meta-SGD [22], reported by [33] | 62.95 ± 0.03 | 79.34 ± 0.06 |
| LEO [33] | 66.33 ± 0.05 | 81.44 ± 0.09 |
| Dynamic [12], reported by [29] | 50.90 ± 0.46 | 66.69 ± 0.36 |
| Incremental [29] | 51.12 ± 0.45 | 66.40 ± 0.36 |
| Soft k-means [30] | 52.39 ± 0.44 | 69.88 ± 0.20 |
| Prototypical Net [35] | 53.31 ± 0.89 | 72.69 ± 0.74 |
| Projective Subspace [34] | - | 71.15 ± 0.67 |
| Relation Net [36] | 54.48 ± 0.93 | 71.32 ± 0.78 |
| Transductive Prop. [23] | 59.91 | 73.30 |
| CTM (ours) | 64.78 ± 0.11 | 81.05 ± 0.52 |
| CTM (ours), data augment | 68.41 ± 0.39 | 84.28 ± 1.73 |

Table 3 shows the gains obtained by including CTM into each method. We observe that on average, there is an approximately 3% absolute increase after adopting CTM. This shows the ability of our module to plug-and-play into multiple metric-based systems. Moreover, the gains remain consistent for each method, regardless of the starting performance level. This supports the hypothesis that our method is able to incorporate signals previously unavailable to any of these approaches, i.e., the inter-class relations in each task.
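The average gain quoted for Table 3 can be re-derived from the "our implementation" and "CTM" rows. The short script below is only a bookkeeping sketch: the numbers are copied from Table 3 (columns ordered as miniImageNet 5-way 1-shot, 5-way 5-shot, 20-way 1-shot, 20-way 5-shot, followed by the same four settings on tieredImageNet), and the variable names are our own.

```python
# Accuracies (%) copied from Table 3.
baseline = {
    "Matching Net": [48.89, 66.35, 23.18, 36.73, 54.02, 70.11, 23.46, 41.65],
    "Prototypical Net": [56.11, 74.16, 28.53, 42.36, 60.27, 75.80, 28.56, 49.34],
    "Relation Net": [58.21, 74.29, 31.35, 45.19, 61.11, 77.39, 26.77, 47.82],
}
with_ctm = {
    "Matching Net": [52.43, 70.09, 25.84, 40.98, 57.01, 73.45, 25.69, 45.07],
    "Prototypical Net": [59.34, 77.95, 32.08, 47.11, 63.77, 79.24, 31.02, 51.44],
    "Relation Net": [62.05, 78.63, 35.11, 48.72, 64.78, 81.05, 31.53, 52.18],
}

# Absolute gain (percentage points) of CTM over each baseline, per setting.
gains = {
    name: [round(c - b, 2) for b, c in zip(baseline[name], with_ctm[name])]
    for name in baseline
}

all_gains = [g for row in gains.values() for g in row]
mean_gain = sum(all_gains) / len(all_gains)
print(f"mean gain: {mean_gain:.2f}, min: {min(all_gains)}, max: {max(all_gains)}")
```

Over the 24 settings the mean absolute gain is about 3.5 points, with individual gains between 2.10 and 4.76 points, which is in line with the roughly 3% figure stated above.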


[Figure 4 panels: (a) Relation Net, 47.82% accuracy; (b) Relation Net with CTM, 52.18% accuracy]

Figure 4. The t-SNE visualization [37] of the improved feature embeddings I(·) learned by our CTM approach. (a) corresponds to the 20-way 5-shot setting of the relation network without CTM in Table 3 and (b) corresponds to the improved version with CTM. Only 10 classes are shown for a better view. We can see that after traversing across categories, the effect of projector p on the features is obvious, making clusters more compact and discriminative from each other.

4.3.2 Comparison beyond Metric-based Approaches

We compare our proposed CTM approach with other state-of-the-art methods in Table 4. For each dataset, the first block of methods are optimization-based, the second are base-class-corpus algorithms, and the third are metric-based approaches. We use a ResNet-18 backbone for the feature extractor to compare with other approaches. The model is trained from scratch with standard initialization, and no additional training data (e.g., distractors [30, 23]) are utilized. We believe such a design aligns with most of the compared algorithms in a fair spirit.

We observe that our CTM method compares favorably against most methods by a large margin, including not only the metric-based methods but also the optimization-based methods. For example, under the 5-way 1-shot setting, the performance is 62.05% vs 59.60% [27], and 64.78% vs 59.91% [23] on the two benchmarks miniImageNet and tieredImageNet, respectively.

LEO [33] is slightly better than ours (without data augmentation) on tieredImageNet. It uses wide residual networks [40] with 28 layers; they also pretrain the model using a supervised task on the entire training set and finetune the network based on these pre-trained features. For practical interest, we also train a version of our model with supervised pretraining (using only the mini- or tieredImageNet training sets), basic data augmentation (including random crop, color jittering and horizontal flip), and a higher weight decay (0.005). The result is shown in the last row for each dataset. Note that our network structure is still ResNet-18, whereas LEO uses a WideResNet-28.

4.4. Feature Visualization Learned by CTM

Fig. 4 visualizes the feature distribution using t-SNE [37]. The features are computed in a 20-way 5-shot setting, but only 10 classes are displayed for easier comparison. Model (a) achieves an accuracy of 47.82% without CTM, and the improved version, Model (b), equipped with CTM, has a better performance of 52.18%. When sampling features for t-SNE for our model, we use I(S), i.e., after the mask p is applied. Since this depends on the support sample, features will be vastly different depending on the chosen task. Therefore, when sampling tasks to create these visualization features, we first chose 20 classes and kept these fixed while drawing different random support samples from this class set. We draw a total of 50 episodes on the test set.
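The sampling protocol described above (fix the class set once, then redraw the support samples for every episode) can be sketched as follows. The function name, the seed, and the toy dataset layout (a dictionary from class name to a list of sample identifiers) are assumptions made for illustration only.

```python
import random


def sample_tsne_episodes(samples_by_class, n_way=20, k_shot=5, n_episodes=50, seed=0):
    """Fix n_way classes once, then draw fresh support samples per episode."""
    rng = random.Random(seed)
    # The class set is chosen a single time and reused for all episodes.
    fixed_classes = rng.sample(sorted(samples_by_class), n_way)
    episodes = []
    for _ in range(n_episodes):
        # Only the support samples change from one episode to the next.
        support = {c: rng.sample(samples_by_class[c], k_shot) for c in fixed_classes}
        episodes.append(support)
    return fixed_classes, episodes


# Toy test split: 40 classes with 30 sample ids each (hypothetical layout).
toy_split = {f"class_{i}": list(range(i * 100, i * 100 + 30)) for i in range(40)}
fixed_classes, episodes = sample_tsne_episodes(toy_split)
```

Keeping the class set fixed means the mask p is always computed for the same 20-way task, so that features from different episodes remain comparable in a single t-SNE plot.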

As can be clearly observed, the CTM model has more compact and separable clusters, indicating that features are more discriminative for the task. This follows from the design of the category traversal module. Without CTM, some clusters overlap with each other (e.g., light green with orange), making them difficult for the metric learner to compare.

5. Conclusion

In this paper, we propose a category traversal module (CTM) to extract feature dimensions most relevant to each task, by looking at the context of the entire support set. By doing so, it is able to make use of both inter-class uniqueness and intra-class commonality properties, both of which are fundamental to classification. By looking at all support classes together, our method is able to identify discriminative feature dimensions for each task, while still learning effective comparison features entirely from scratch. We devise a concentrator to first extract the feature commonality among instances within the class by effectively down-sampling the input features and averaging. A projector is introduced to traverse feature dimensions across all categories in the support set. The projector uses inter-class relations to focus on the relevant feature dimensions for the task at hand. The output of CTM is then combined onto the feature embeddings for both support and query; the enhanced feature representations are more unique and discriminative for the task. We have demonstrated that it improves upon previous methods by a large margin, and has highly competitive performance compared with the state of the art.

Acknowledgment

We thank Nand Dalal, Michael Gormish, Yanan Jian and reviewers for helpful discussions and comments. H. Li is supported by the Hong Kong Ph.D. Fellowship Scheme.


References

[1] Kelsey R Allen, Hanul Shin, Evan Shelhamer, and Josh B. Tenenbaum. Variadic learning by bayesian nonparametric deep embedding. In OpenReview, 2019.

[2] Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, and Jia-Bin Huang. A closer look at few-shot classification. In ICLR, 2019.

[3] Zitian Chen, Yanwei Fu, Yu-Xiong Wang, Lin Ma, Wei Liu, and Martial Hebert. Image deformation meta-networks for one-shot learning. In OpenReview, 2019.

[4] Sumit Chopra, Raia Hadsell, and Yann LeCun. Learning a similarity metric discriminatively, with application to face verification. In CVPR, 2005.

[5] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei. ImageNet: A Large-Scale Hierarchical Image Database. In CVPR, 2009.

[6] Elad Hoffer and Nir Ailon. Deep Metric Learning using Triplet Network. In ICLR, 2015.

[7] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In arXiv preprint:1703.03400, 2017.

[8] Ronald A Fisher. The use of multiple measurements in taxonomic problems. Annals of eugenics, 7(2):179–188, 1936.

[9] Tianyu Gao, Xu Han, Zhiyuan Liu, and Maosong Sun. Hybrid Attention-Based Prototypical Networks for Noisy Few-Shot Relation Classification. In AAAI, 2019.

[10] Victor Garcia and Joan Bruna. Few-shot Learning with Graph Neural Networks. In ICLR, 2018.

[11] Mohammad Ghasemzadeh, Fang Lin, Bita Darvish Rouhani, Farinaz Koushanfar, and Ke Huang. Agilenet: Lightweight dictionary-based few-shot learning. In arXiv preprint:1805.08311, 2018.

[12] Spyros Gidaris and Nikos Komodakis. Dynamic Few-Shot Visual Learning without Forgetting. In CVPR, 2018.

[13] Jonathan Gordon, John Bronskill, Matthias Bauer, Sebastian Nowozin, and Richard Turner. Meta-learning probabilistic inference for prediction. In ICLR, 2019.

[14] Chunrui Han, Shiguang Shan, Meina Kan, Shuzhe Wu, and Xilin Chen. Meta-learning with individualized feature space for few-shot classification. 2019.

[15] Bharath Hariharan and Ross Girshick. Low-shot Visual Recognition by Shrinking and Hallucinating Features. In ICCV, 2017.

[16] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, 2016.

[17] Xiang Jiang, Mohammad Havaei, Farshid Varno, Gabriel Chartrand, Nicolas Chapados, and Stan Matwin. Learning to learn with conditional class dependencies. In ICLR, 2019.

[18] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.

[19] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. Imagenet classification with deep convolutional neural networks. In NIPS, 2012.

[20] Hongyang Li, Xiaoyang Guo, Bo Dai, Wanli Ouyang, and Xiaogang Wang. Neural network encapsulation. In ECCV, 2018.

[21] Kai Li, Martin Renqiang Min, Bing Bai, Yun Fu, and Hans Peter Graf. Network reparameterization for unseen class categorization. In OpenReview, 2019.

[22] Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-SGD: Learning to Learn Quickly for Few-Shot Learning. In arXiv preprint:1707.09835, 2017.

[23] Yanbin Liu, Juho Lee, Minseop Park, Saehoon Kim, and Yi Yang. Transductive Propagation Network for Few-shot Learning. In arXiv preprint:1805.10002, 2018.

[24] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A Simple Neural Attentive Meta-Learner. In ICLR, 2018.

[25] Alex Nichol, Joshua Achiam, and John Schulman. On First-Order Meta-Learning Algorithms. In arXiv preprint:1803.02999, 2018.

[26] Boris N. Oreshkin, Pau Rodriguez, and Alexandre Lacoste. TADAM: Task dependent adaptive metric for improved few-shot learning. In NIPS, 2018.

[27] Siyuan Qiao, Chenxi Liu, Wei Shen, and Alan Yuille. Few-Shot Image Recognition by Predicting Parameters from Activations. In CVPR, 2018.

[28] Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. In ICLR, 2017.

[29] Mengye Ren, Renjie Liao, Ethan Fetaya, and Richard S. Zemel. Incremental few-shot learning with attention attractor networks. In arXiv preprint:1810.07218, 2018.

[30] Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, and Richard S. Zemel. Meta-Learning for Semi-supervised Few-Shot Classification. In ICLR, 2018.


[31] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks. In NIPS, 2015.

[32] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV), 115(3):211–252, 2015.

[33] Andrei A. Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero, and Raia Hadsell. Meta-learning with latent embedding optimization. In ICLR, 2019.

[34] Christian Simon, Piotr Koniusz, and Mehrtash Harandi. Projective subspace networks for few-shot learning. In OpenReview, 2019.

[35] Jake Snell, Kevin Swersky, and Richard S. Zemel. Prototypical Networks for Few-shot Learning. In NIPS, 2017.

[36] Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip H.S. Torr, and Timothy M. Hospedales. Learning to compare: Relation network for few-shot learning. In CVPR, 2018.

[37] Laurens van der Maaten and Geoffrey Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 9:2579–2605, 2008.

[38] Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Koray Kavukcuoglu, and Daan Wierstra. Matching Networks for One Shot Learning. In NIPS, 2016.

[39] Zhirong Wu, Alexei A. Efros, and Stella X. Yu. Improving Generalization via Scalable Neighborhood Component Analysis. In ECCV, 2018.

[40] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In BMVC, 2016.