
Guiding Frozen Language Models with Learned Soft Prompts

Large pre-trained language models, which are continuing to grow in size, achieve state-of-the-art results on many natural language processing (NLP) benchmarks. Since the development of GPT and BERT, standard practice has been to fine-tune models on downstream tasks, which involves adjusting every weight in the network (i.e., model tuning). However, as models become larger, storing and serving a tuned copy of the model for each downstream task becomes impractical.

An appealing alternative is to share across all downstream tasks a single frozen pre-trained language model, in which all weights are fixed. In an exciting development, GPT-3 showed convincingly that a frozen model can be conditioned to perform different tasks through “in-context” learning. With this approach, a user primes the model for a given task through prompt design, i.e., hand-crafting a text prompt with a description or examples of the task at hand. For instance, to condition a model for sentiment analysis, one could attach the prompt, “Is the following movie review positive or negative?” before the input sequence, “This movie was amazing!”

Sharing the same frozen model across tasks greatly simplifies serving and allows for efficient mixed-task inference, but unfortunately, this is at the expense of task performance. Text prompts require manual effort to design, and even well-designed prompts still far underperform compared to model tuning. For instance, the performance of a frozen GPT-3 175B parameter model on the SuperGLUE benchmark is 5 points below a fine-tuned T5 model that uses 800 times fewer parameters.

In “The Power of Scale for Parameter-Efficient Prompt Tuning”, presented at EMNLP 2021, we explore prompt tuning, a more efficient and effective method for conditioning frozen models using tunable soft prompts. Just like engineered text prompts, soft prompts are concatenated to the input text. But rather than selecting from existing vocabulary items, the “tokens” of the soft prompt are learnable vectors. This means a soft prompt can be optimized end-to-end over a training dataset. In addition to removing the need for manual design, this allows the prompt to condense information from datasets containing thousands or millions of examples. By comparison, discrete text prompts are typically limited to under 50 examples due to constraints on model input length. We are also excited to release the code and checkpoints to fully reproduce our experiments.

Prompt tuning retains the strong task performance of model tuning, while keeping the pre-trained model frozen, enabling efficient multitask serving.

Prompt Tuning
To create a soft prompt for a given task, we first initialize the prompt as a fixed-length sequence of vectors (e.g., 20 tokens long). We attach these vectors to the beginning of each embedded input and feed the combined sequence into the model. The model’s prediction is compared to the target to calculate a loss, and the error is back-propagated to calculate gradients; however, we apply these gradient updates only to our new learnable vectors, keeping the core model frozen. While soft prompts learned in this way are not immediately interpretable, at an intuitive level the soft prompt extracts evidence about how to perform the task from the labeled dataset, performing the same role as a manually written text prompt, but without being constrained to discrete language.
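
Below is a minimal, self-contained sketch of this training step in JAX. The toy frozen “model” (an embedding table plus a linear head over mean-pooled features) and all shapes are illustrative assumptions, not the T5X implementation; the point is simply that the gradient is taken only with respect to the soft prompt.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a frozen pre-trained model: an embedding table plus a linear
# classifier over the mean-pooled sequence. These parameters stay frozen; only
# the soft prompt is trained.
VOCAB, EMBED_DIM, NUM_CLASSES, PROMPT_LEN = 1000, 64, 2, 20
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
frozen = {
    "embed": 0.02 * jax.random.normal(k1, (VOCAB, EMBED_DIM)),
    "head": 0.02 * jax.random.normal(k2, (EMBED_DIM, NUM_CLASSES)),
}
soft_prompt = 0.5 * jax.random.normal(k3, (PROMPT_LEN, EMBED_DIM))  # trainable

def loss_fn(soft_prompt, frozen, token_ids, label):
    embedded = frozen["embed"][token_ids]                        # [seq, dim]
    combined = jnp.concatenate([soft_prompt, embedded], axis=0)  # prepend prompt
    logits = combined.mean(axis=0) @ frozen["head"]
    return -jax.nn.log_softmax(logits)[label]

token_ids = jnp.array([5, 42, 7, 301])   # one tokenized training example
label = 1

# Gradients flow only into the soft prompt (argnums=0); the core model stays frozen.
grads = jax.grad(loss_fn, argnums=0)(soft_prompt, frozen, token_ids, label)
soft_prompt = soft_prompt - 0.3 * grads   # the large learning rate (0.3) noted below
```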

Our codebase, implemented in the new JAX-based T5X framework, makes it easy for anyone to replicate this procedure, and provides practical hyperparameter settings, including a large learning rate (0.3), which we found was important for achieving good results.

Since soft prompts have a small parameter footprint (we train prompts with as few as 512 parameters), one can easily pass the model a different prompt along with each input example. This enables mixed-task inference batches, which can streamline serving by sharing one core model across many tasks.
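
As a hedged illustration of mixed-task batching (the shapes and the prompt-table layout are assumptions for this sketch, not the serving system's actual data structures), each example simply looks up the soft prompt for its task ID and prepends it to its own embedded input, so one frozen model can process a batch spanning several tasks:

```python
import jax.numpy as jnp

num_tasks, prompt_len, seq_len, dim = 2, 20, 8, 64
prompt_table = jnp.zeros((num_tasks, prompt_len, dim))  # one learned prompt per task
batch_embeds = jnp.zeros((3, seq_len, dim))             # 3 examples, possibly from different tasks
task_ids = jnp.array([0, 1, 0])                         # which task each example belongs to

prompts = prompt_table[task_ids]                                   # [batch, prompt_len, dim]
batched_inputs = jnp.concatenate([prompts, batch_embeds], axis=1)  # [batch, prompt_len + seq_len, dim]
# batched_inputs can now be fed through the single frozen model in one pass.
```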

Left: With model tuning, incoming data are routed to task-specific models. Right: With prompt tuning, examples and prompts from different tasks can flow through a single frozen model in large batches, better utilizing serving resources.

Improvement with Scale
When evaluated on SuperGLUE and using a frozen T5 model, prompt tuning significantly outperforms prompt design using either GPT-3 or T5. Furthermore, as model size increases, prompt tuning catches up to the performance level of model tuning. Intuitively, the larger the pre-trained model, the less of a “push” it needs to perform a specific task, and the more capable it is of being adapted in a parameter-efficient way.

As scale increases, prompt tuning matches model tuning, despite tuning 25,000 times fewer parameters.

The effectiveness of prompt tuning at large model scales is especially important, since serving separate copies of a large model can incur significant computational overhead. In our paper, we demonstrate that larger models can be conditioned successfully even with soft prompts as short as 5 tokens. For T5 XXL, this means tuning just 20 thousand parameters to guide the behavior of an 11 billion parameter model.
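
As a quick sanity check of that number (assuming T5 XXL's embedding dimension of 4096, which is not stated in the post):

$$5 \text{ prompt tokens} \times 4096 \text{ dimensions} = 20{,}480 \approx 20\text{k trainable parameters, versus } \sim\!1.1 \times 10^{10} \text{ frozen parameters.}$$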

Resilience to Domain Shift
Another advantage of prompt tuning is its resilience to domain shift. Since model tuning touches every weight in the network, it has the capacity to easily overfit on the provided fine-tuning data and may not generalize well to variations in the task at inference time. By comparison, our learned soft prompts have a small number of parameters, so the solutions they represent may be more generalizable.

To test generalizability, we train prompt tuning and model tuning solutions on one task, and evaluate zero-shot on a closely related task. For example, when we train on the Quora Question Pairs task (i.e., detecting if two questions are duplicates) and evaluate on MRPC (i.e., detecting if two sentences from news articles are paraphrases), prompt tuning achieves +3.2 points higher accuracy than model tuning.

Train    Eval    Tuning    Accuracy      F1
QQP      MRPC    Model     73.1 ±0.9     81.2 ±2.1
QQP      MRPC    Prompt    76.3 ±0.1     84.3 ±0.3
MRPC     QQP     Model     74.9 ±1.3     70.9 ±1.2
MRPC     QQP     Prompt    75.4 ±0.8     69.7 ±0.3
On zero-shot domain transfer between two paraphrase detection tasks, prompt tuning matches or outperforms model tuning, depending on the direction of transfer.

Looking Forward
Prompt-based learning is an exciting new area that is quickly evolving. While several similar methods have been proposed, such as Prefix Tuning, WARP, and P-Tuning, we discuss their pros and cons and demonstrate that prompt tuning is the simplest and the most parameter-efficient method.

In addition to the Prompt Tuning codebase, we’ve also released our LM-adapted T5 checkpoints, which we found to be better-suited for prompt tuning compared to the original T5. This codebase was used for the prompt tuning experiments in FLAN, and the checkpoints were used as a starting point for training the BigScience T0 model. We hope that the research community continues to leverage and extend prompt tuning in future research.

Acknowledgements
This project was a collaboration between Brian Lester, Rami Al-Rfou and Noah Constant. We are grateful to the following people for feedback, discussion and assistance: Waleed Ammar, Lucas Dixon, Slav Petrov, Colin Raffel, Adam Roberts, Sebastian Ruder, Noam Shazeer, Tu Vu and Linting Xue.


Nested Hierarchical Transformer: Towards Accurate, Data-Efficient, and Interpretable Visual Understanding

In visual understanding, the Vision Transformer (ViT) and its variants have received significant attention recently due to their superior performance on many core visual applications, such as image classification, object detection, and video understanding. The core idea of ViT is to utilize the power of self-attention layers to learn global relationships between small patches of images. However, the number of connections between patches increases quadratically with image size. Such a design has been observed to be data inefficient — although the original ViT can perform better than convolutional networks when pre-trained on hundreds of millions of images, such a data requirement is not always practical, and it still underperforms compared to convolutional networks when given less data. Many researchers are exploring more suitable architectural re-designs that can learn visual representations effectively, such as by adding convolutional layers and building hierarchical structures with local self-attention.

The principle of hierarchical structure is one of the core ideas in vision models, where bottom layers learn more local object structures in the high-dimensional pixel space and top layers learn more abstracted and high-level knowledge in a low-dimensional feature space. Existing ViT-based methods focus on designing a variety of modifications inside self-attention layers to achieve such a hierarchy, but while these offer promising performance improvements, they often require substantial architectural re-designs. Moreover, these approaches lack an interpretable design, so it is difficult to explain the inner workings of trained models.

To address these challenges, in “Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding”, we present a rethinking of existing hierarchical structure–driven designs, and provide a novel and orthogonal approach to significantly simplify them. The central idea of this work is to decouple feature learning and feature abstraction (pooling) components: nested transformer layers encode visual knowledge of image patches separately, and then the processed information is aggregated. This process is repeated in a hierarchical manner, resulting in a pyramid network structure. The resulting architecture achieves competitive results on ImageNet and outperforms existing methods on data-efficient benchmarks. We show that such a design can meaningfully improve data efficiency with faster convergence and provide valuable interpretability benefits. Moreover, we introduce GradCAT, a new technique for interpreting the decision process of a trained model at inference time.

Architecture Design
The overall architecture is simple to implement by adding just a few lines of Python code to the source code of the original ViT. The original ViT architecture divides an input image into small patches, projects the pixels of each patch to a vector of predefined dimension, and then feeds the sequence of all vectors to the overall ViT architecture containing multiple stacked identical transformer layers. While every layer in ViT processes information of the whole image, with this new method, stacked transformer layers are used to process only a region (i.e., block) of the image containing a few spatially adjacent image patches. This step is independent for each block and is also where feature learning occurs. Finally, a new computational layer called block aggregation combines all of the spatially adjacent blocks. After each block aggregation, the features corresponding to four spatially adjacent blocks are fed to another module with a stack of transformer layers, which then processes those four blocks jointly. This design naturally builds a pyramid hierarchical structure of the network, where bottom layers can focus on local features (such as textures) and upper layers focus on global features (such as object shape) at reduced dimensionality thanks to the block aggregation.
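
A rough sketch of this nesting-and-aggregation loop is shown below, assuming blocks merge 4-to-1 at each stage and using mean pooling as a stand-in for the block aggregation layer; the released implementation differs in the details.

```python
import jax.numpy as jnp

def transformer_stack(x):
    # Stand-in for a stack of identical transformer layers applied to each
    # block independently (feature learning happens only within a block).
    return x  # [num_blocks, tokens_per_block, dim]

def block_aggregate(x):
    # Merge groups of 4 spatially adjacent blocks, then reduce the token count
    # by 4x. Mean pooling over groups of 4 tokens is a simplification of the
    # paper's block aggregation layer.
    num_blocks, tokens, dim = x.shape
    merged = x.reshape(num_blocks // 4, 4 * tokens, dim)
    return merged.reshape(num_blocks // 4, tokens, 4, dim).mean(axis=2)

def nested_hierarchy(patch_embeddings, num_stages=3):
    x = patch_embeddings                      # [num_blocks, tokens_per_block, dim]
    for _ in range(num_stages):
        x = transformer_stack(x)              # independent per-block feature learning
        if x.shape[0] > 1:
            x = block_aggregate(x)            # feature abstraction / pooling
    return x.mean(axis=(0, 1))                # single block at the top -> classifier head

features = nested_hierarchy(jnp.ones((16, 16, 128)))   # 16 blocks of 16 patches each
print(features.shape)                                   # (128,)
```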

A visualization of the network processing an image: Given an input image, the network first partitions it into blocks, where each block contains 4 image patches. Image patches in every block are linearly projected as vectors and processed by a stack of identical transformer layers. Then the proposed block aggregation layer aggregates information from each block and reduces its spatial size by 4 times. The number of blocks is reduced to 1 at the top of the hierarchy, and classification is conducted on its output.

Interpretability
This architecture has a non-overlapping information processing mechanism, independent at every node. This design resembles a decision tree-like structure, which manifests unique interpretability capabilities, because every tree node contains independent information about an image block that is passed up to its parent node. We can trace the information flow through the nodes to understand the importance of each feature. In addition, our hierarchical structure retains the spatial structure of images throughout the network, leading to learned spatial feature maps that are effective for interpretation. Below we showcase two kinds of visual interpretability.

First, we present a method to interpret the trained model on test images, called gradient-based class-aware tree-traversal (GradCAT). GradCAT traces the feature importance of each block (a tree node) from top to bottom of the hierarchical structure. The main idea is to find the most valuable traversal from the root node at the top layer to a child node at the bottom layer that contributes the most to the classification outcome. Since each node processes information from a certain region of the image, such a traversal can be easily mapped to the image space for interpretation (as shown by the overlaid dots and lines in the image below).
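
A schematic of such a traversal is shown below; the node scoring (a single scalar importance per node, e.g., a gradient-weighted activation) and the greedy top-down descent are assumptions about the general shape of the method, not the paper's exact algorithm.

```python
def gradcat_traversal(node_values, children, root="root"):
    """node_values: node_id -> scalar importance for the target class.
    children: node_id -> list of child node_ids (empty/missing at the leaves)."""
    path, node = [root], root
    while children.get(node):
        # Greedily descend into the child block contributing most to the class.
        node = max(children[node], key=lambda c: node_values[c])
        path.append(node)
    return path  # each node covers an image block, so the path maps to image regions

# Toy two-level tree: four child blocks under the root.
values = {"root": 1.0, "b0": 0.2, "b1": 0.9, "b2": 0.4, "b3": 0.1}
tree = {"root": ["b0", "b1", "b2", "b3"]}
print(gradcat_traversal(values, tree))   # ['root', 'b1']
```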

The following is an example of the model’s top-4 predictions and corresponding interpretability results on the left input image (containing 4 animals). As shown below, GradCAT highlights the decision path along the hierarchical structure as well as the corresponding visual cues in local image regions.

Given the left input image (containing four objects), the figure visualizes the interpretability results of the top-4 prediction classes. The traversal locates the model decision path along the tree and simultaneously locates the corresponding image patch (shown by the dotted line on the images) that has the highest impact on the predicted target class.

Moreover, the following figures visualize results on the ImageNet validation set and show how this approach enables some intuitive observations. For instance, the example of the lighter below (upper left panel) is particularly interesting because the ground truth class — lighter/matchstick — actually defines the bottom-right matchstick object, while the most salient visual features (with the highest node values) are actually from the upper-left red light, which conceptually shares visual cues with a lighter. This can also be seen from the overlaid red lines, which indicate the image patches with the highest impact on the prediction. Thus, although the visual cue is a mistake, the output prediction is correct. In addition, the four child nodes of the wooden spoon below have similar feature importance values (see numbers visualized in the nodes; higher indicates more importance), which is because the wooden texture of the table is similar to that of the spoon.

Visualization of the results obtained by the proposed GradCAT. Images are from the ImageNet validation dataset.

Second, different from the original ViT, our hierarchical architecture retains spatial relationships in learned representations. The top layers output low-resolution feature maps of input images, enabling the model to easily perform attention-based interpretation by applying Class Attention Map (CAM) on the learned representations at the top hierarchical level. This enables high-quality weakly-supervised object localization with just image-level labels. See the following figure for examples.

Visualization of CAM-based attention results. Warmer colors indicate higher attention. Images are from the ImageNet validation dataset.

Convergence Advantages
With this design, feature learning only happens at local regions independently, and feature abstraction happens inside the aggregation function. This simple design and implementation are general enough for other types of visual understanding tasks beyond classification. It also greatly improves the model's convergence speed, significantly reducing the training time needed to reach the desired maximum accuracy.

We validate this advantage in two ways. First, we compare ImageNet accuracy against the original ViT structure for different numbers of total training epochs. The results are shown on the left side of the figure below, demonstrating much faster convergence than the original ViT, e.g., around 20% improvement in accuracy over ViT with 30 total training epochs.

Second, we modify the architecture to conduct unconditional image generation tasks, since training ViT-based models for image generation tasks is challenging due to convergence and speed issues. Creating such a generator is straightforward by transposing the proposed architecture: the input is an embedding vector, the output is a full image in RGB channels, and the block aggregation is replaced by a block de-aggregation component supported by Pixel Shuffling. Surprisingly, we find our generator is easy to train and demonstrates faster convergence speed, as well as a better FID score (which measures how similar generated images are to real ones), than the capacity-comparable SAGAN.
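
As a reference for the de-aggregation step, pixel shuffling (depth-to-space) rearranges channels into spatial positions, which is what lets the generator grow the number of blocks instead of shrinking it. The sketch below shows the standard operation only, with the surrounding generator details omitted.

```python
import jax.numpy as jnp

def pixel_shuffle(x, r):
    """Depth-to-space: [H, W, C*r*r] -> [H*r, W*r, C]."""
    H, W, Crr = x.shape
    C = Crr // (r * r)
    x = x.reshape(H, W, r, r, C)
    x = x.transpose(0, 2, 1, 3, 4)   # interleave the r x r factors into H and W
    return x.reshape(H * r, W * r, C)

x = jnp.arange(2 * 2 * 16.0).reshape(2, 2, 16)
print(pixel_shuffle(x, r=2).shape)   # (4, 4, 4)
```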

Left: ImageNet accuracy given different number of total training epochs compared with standard ViT architecture. Right: ImageNet 64×64 image generation FID scores (lower is better) with single 1000-epoch training. On both tasks, our method shows better convergence speed.

Conclusion
In this work we demonstrate the simple idea that decoupling feature learning and feature abstraction in this nested hierarchical design leads to better feature interpretability through a new gradient-based class-aware tree-traversal method. Moreover, the architecture improves convergence on not only classification tasks but also image generation tasks. The proposed idea focuses on the aggregation function and is thereby orthogonal to advanced architecture design for self-attention. We hope this new research encourages future architecture designers to explore more interpretable and data-efficient ViT-based models for visual understanding, like the adoption of this work for high-resolution image generation. We have also released the source code for the image classification portion of this work.

Acknowledgements
We gratefully acknowledge the contributions of other co-authors, including Han Zhang, Long Zhao, Ting Chen, Sercan Arik, and Tomas Pfister. We also thank Xiaohua Zhai, Jeremy Kubica, Kihyuk Sohn, and Madeleine Udell for their valuable feedback on the work.


Robot See, Robot Do

People learn to do things by watching others — from mimicking new dance moves, to watching YouTube cooking videos. We’d like robots to do the same, i.e., to learn new skills by watching people do things during training. Today, however, the predominant paradigm for teaching robots is to remotely control them using specialized hardware for teleoperation and then train them to imitate pre-recorded demonstrations. This limits both who can provide the demonstrations (programmers & roboticists) and where they can be provided (lab settings). If robots could instead self-learn new tasks by watching humans, this capability could allow them to be deployed in more unstructured settings like the home, and make it dramatically easier for anyone to teach or communicate with them, expert or otherwise. Perhaps one day, they might even be able to use YouTube videos to grow their collection of skills over time.

Our motivation is to have robots watch people do tasks, naturally with their hands, and then use that data as demonstrations for learning. Video by Teh Aik Hui and Nathaniel Lim. License: CC-BY

However, an obvious but often overlooked problem is that a robot is physically different from a human, which means it often completes tasks differently than we do. For example, in the pen manipulation task below, the hand can grab all the pens together and quickly transfer them between containers, whereas the two-fingered gripper must transport one at a time. Prior research assumes that humans and robots can do the same task similarly, which makes manually specifying one-to-one correspondences between human and robot actions easy. But with stark differences in physique, defining such correspondences for seemingly easy tasks can be surprisingly difficult and sometimes impossible.

Physically different end-effectors (i.e., “grippers”, the part that interacts with the environment) induce different control strategies when solving the same task. Left: The hand grabs all pens and quickly transfers them between containers. Right: The two-fingered gripper transports one pen at a time.

In “XIRL: Cross-Embodiment Inverse RL”, presented as an oral paper at CoRL 2021, we explore these challenges further and introduce a self-supervised method for Cross-embodiment Inverse Reinforcement Learning (XIRL). Rather than focusing on how individual human actions should correspond to robot actions, XIRL learns the high-level task objective from videos, and summarizes that knowledge in the form of a reward function that is invariant to embodiment differences, such as shape, actions and end-effector dynamics. The learned rewards can then be used together with reinforcement learning to teach the task to agents with new physical embodiments through trial and error. Our approach is general and scales autonomously with data — the more embodiment diversity present in the videos, the more invariant and robust the reward functions become. Experiments show that our learned reward functions lead to significantly more sample-efficient (roughly 2 to 4 times) reinforcement learning on new embodiments compared to alternative methods. To extend and build on our work, we are releasing an accompanying open-source implementation of our method along with X-MAGICAL, our new simulated benchmark for cross-embodiment imitation.

Cross-Embodiment Inverse Reinforcement Learning (XIRL)
The underlying observation in this work is that in spite of the many differences induced by different embodiments, there still exist visual cues that reflect progression towards a common task objective. For example, in the pen manipulation task above, the presence of pens in the cup but not the mug, or the absence of pens on the table, are key frames that are common to different embodiments and indirectly provide cues for how close to being complete a task is. The key idea behind XIRL is to automatically discover these key moments in videos of different length and cluster them meaningfully to encode task progression. This motivation shares many similarities with unsupervised video alignment research, from which we can leverage a method called Temporal Cycle Consistency (TCC), which aligns videos accurately while learning useful visual representations for fine-grained video understanding without requiring any ground-truth correspondences.

We leverage TCC to train an encoder to temporally align video demonstrations of different experts performing the same task. The TCC loss tries to maximize the number of cycle-consistent frames (or mutual nearest-neighbors) between pairs of sequences using a differentiable formulation of soft nearest-neighbors. Once the encoder is trained, we define our reward function as simply the negative Euclidean distance between the current observation and the goal observation in the learned embedding space. We can subsequently insert the reward into a standard MDP and use an RL algorithm to learn the demonstrated behavior. Surprisingly, we find that this simple reward formulation is effective for cross-embodiment imitation.
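
The reward itself is then just a few lines; the sketch below assumes a trained encoder function and a precomputed goal embedding (e.g., an average of final-frame embeddings from the demonstration videos, which is one natural choice rather than a detail taken from the paper).

```python
import jax.numpy as jnp

def xirl_reward(encoder_fn, observation, goal_embedding):
    """Negative Euclidean distance to the goal in the TCC-learned embedding space."""
    z = encoder_fn(observation)                  # embed the current camera frame
    return -jnp.linalg.norm(z - goal_embedding)  # closer to the goal -> higher reward

# Toy usage with a stand-in encoder.
encoder_fn = lambda obs: obs.mean(axis=(0, 1))           # [H, W, C] -> [C]
goal_embedding = jnp.ones(3)
print(xirl_reward(encoder_fn, jnp.zeros((8, 8, 3)), goal_embedding))  # approximately -1.732
```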

XIRL self-supervises reward functions from expert demonstrations using temporal cycle consistency (TCC), then uses them for downstream reinforcement learning to learn new skills from third-person demonstrations.

X-MAGICAL Benchmark
To evaluate the performance of XIRL and baseline alternatives (e.g., TCN, LIFS, Goal Classifier) in a consistent environment, we created X-MAGICAL, which is a simulated benchmark for cross-embodiment imitation. X-MAGICAL features a diverse set of agent embodiments, with differences in their shapes and end-effectors, designed to solve tasks in different ways. This leads to differences in execution speeds and state-action trajectories, which poses challenges for current imitation learning techniques, e.g., ones that use time as a heuristic for weak correspondences between two trajectories. The ability to generalize across embodiments is precisely what X-MAGICAL evaluates.

The SweepToTop task we considered for our experiments is a simplified 2D equivalent of a common household robotic sweeping task, where an agent has to push three objects into a goal zone in the environment. We chose this task specifically because its long-horizon nature highlights how different agent embodiments can generate entirely different trajectories (shown below). X-MAGICAL features a Gym API and is designed to be easily extendable to new tasks and embodiments. You can try it out today with pip install x-magical.

Different agent shapes in the SweepToTop task in the X-MAGICAL benchmark need to use different strategies to reposition objects into the target area (pink), i.e., to “clear the debris”. For example, the long-stick can clear them all in one fell swoop, whereas the short-stick needs to do multiple consecutive back-and-forths.
Left: Heatmap of state visitation for each embodiment across all expert demonstrations. Right: Examples of expert trajectories for each embodiment.

Highlights
In our first set of experiments, we checked whether our learned embodiment-invariant reward function can enable successful reinforcement learning, when the expert demonstrations are provided through the agent itself. We find that XIRL significantly outperforms alternative methods especially on the tougher agents (e.g., short-stick and gripper).

Same-embodiment setting: Comparison of XIRL with baseline reward functions, using SAC for RL policy learning. XIRL is roughly 2 to 4 times more sample efficient than some of the baselines on the harder agents (short-stick and gripper).

We also find that our approach shows great potential for learning reward functions that generalize to novel embodiments. For instance, when reward learning is performed on embodiments that are different from the ones on which the policy is trained, we find that it results in significantly more sample efficient agents compared to the same baselines. Below, in the gripper subplot (bottom right) for example, the reward is first learned on demonstration videos from long-stick, medium-stick and short-stick, after which the reward function is used to train the gripper agent.

Cross-embodiment setting: XIRL performs favorably when compared with other baseline reward functions, trained on observation-only demonstrations from different embodiments. Each agent (long-stick, medium-stick, short-stick, gripper) had its reward trained using demonstrations from the other three embodiments.

We also find that we can train on real-world human demonstrations, and use the learned reward to train a Sawyer arm in simulation to push a puck to a designated target zone. In these experiments as well, our method outperforms baseline alternatives. For example, our XIRL variant trained only on the real-world demonstrations (purple in the plots below) reaches 80% of the total performance roughly 85% faster than the RLV baseline (orange).

What Do The Learned Reward Functions Look Like?
To further explore the qualitative nature of our learned rewards in more challenging real-world scenarios, we collect a dataset of the pen transfer task using various household tools.

Below, we show rewards extracted from a successful (top) and unsuccessful (bottom) demonstration. Both demonstrations follow a similar trajectory at the start of the task execution. The successful one nets a high reward for placing the pens consecutively into the mug then into the glass cup, while the unsuccessful one obtains a low reward because it drops the pens outside the glass cup towards the end of the execution (orange circle). These results are promising because they show that our learned encoder can represent fine-grained visual differences relevant to a task.

Conclusion
We highlighted XIRL, our approach to tackling the cross-embodiment imitation problem. XIRL learns an embodiment-invariant reward function that encodes task progress using a temporal cycle-consistency objective. Policies learned using our reward functions are significantly more sample-efficient than baseline alternatives. Furthermore, the reward functions do not require manually paired video frames between the demonstrator and the learner, giving them the ability to scale to an arbitrary number of embodiments or experts with varying skill levels. Overall, we are excited about this direction of work, and hope that our benchmark promotes further research in this area. For more details, please check out our paper and download the code from our GitHub repository.

Acknowledgments
Kevin and Andy summarized research performed together with Pete Florence, Jonathan Tompson, Jeannette Bohg (faculty at Stanford University) and Debidatta Dwibedi. All authors would additionally like to thank Alex Nichol, Nick Hynes, Sean Kirmani, Brent Yi, Jimmy Wu, Karl Schmeckpeper and Minttu Alakuijala for fruitful technical discussions, and Sam Toyer for invaluable help with setting up the simulated benchmark.


Unlocking the Full Potential of Datacenter ML Accelerators with Platform-Aware Neural Architecture Search

Continuing advances in the design and implementation of datacenter (DC) accelerators for machine learning (ML), such as TPUs and GPUs, have been critical for powering modern ML models and applications at scale. These improved accelerators exhibit peak performance (e.g., FLOPs) that is orders of magnitude better than traditional computing systems. However, there is a fast-widening gap between the available peak performance offered by state-of-the-art hardware and the actual achieved performance when ML models run on that hardware.

One approach to address this gap is to design hardware-specific ML models that optimize both performance (e.g., throughput and latency) and model quality. Recent applications of neural architecture search (NAS), an emerging paradigm to automate the design of ML model architectures, have employed a platform-aware multi-objective approach that includes a hardware performance objective. While this approach has yielded improved model performance in practice, the details of the underlying hardware architecture are opaque to the model. As a result, there is untapped potential to build full-capability, hardware-friendly ML model architectures, with hardware-specific optimizations, for powerful DC ML accelerators.

In “Searching for Fast Model Families on Datacenter Accelerators”, published at CVPR 2021, we advanced the state of the art of hardware-aware NAS by automatically adapting model architectures to the hardware on which they will be executed. The approach we propose finds optimized families of models for which additional hardware performance gains cannot be achieved without loss in model quality (called Pareto optimization). To accomplish this, we infuse a deep understanding of hardware architecture into the design of the NAS search space for discovery of both single models and model families. We provide quantitative analysis of the performance gap between hardware and traditional model architectures and demonstrate the advantages of using true hardware performance (i.e., throughput and latency), instead of the performance proxy (FLOPs), as the performance optimization objective. Leveraging this advanced hardware-aware NAS and building upon the EfficientNet architecture, we developed a family of models, called EfficientNetX, that demonstrate the effectiveness of this approach for Pareto-optimized ML models on TPUs and GPUs.

Platform-Aware NAS for DC ML Accelerators
To achieve high performance, ML models need to adapt to modern ML accelerators. Platform-aware NAS integrates knowledge of the hardware accelerator properties into all three pillars of NAS: (i) the search objectives; (ii) the search space; and (iii) the search algorithm (shown below). We focus on the new search space because it contains the building blocks needed to compose the models and is the key link between the ML model architectures and accelerator hardware architectures.

We construct TPU/GPU specialized search spaces with TPU/GPU-friendly operations to infuse hardware awareness into NAS. For example, a key adaptation is maximizing parallelism to ensure different hardware components inside the accelerators work together as efficiently as possible. This includes the matrix multiplication units (MXUs) in TPUs and the TensorCore in GPUs for matrix/tensor computation, as well as the vector processing units (VPUs) in TPUs and CUDA cores in GPUs for vector processing. Maximizing model arithmetic intensity (i.e., optimizing the parallelism between computation and operations on the high bandwidth memory) is also critical to achieve top performance. To tap into the full potential of the hardware, it is crucial for ML models to achieve high parallelism inside and across these hardware components.
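
For reference, arithmetic intensity has a standard definition (not specific to this work): the ratio of compute performed to data moved through memory,

$$\text{arithmetic intensity} \;=\; \frac{\text{total operations (FLOPs)}}{\text{bytes read from and written to HBM}},$$

so a model with higher arithmetic intensity spends relatively more time in the MXUs/TensorCores and less time waiting on the high bandwidth memory.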

Overview of platform-aware NAS on TPUs/GPUs, highlighting the search space and search objectives.

Advanced platform-aware NAS has an optimized search space containing a set of complementary techniques to holistically improve parallelism for ML model execution on TPUs and GPUs:

  1. It uses specialized tensor reshaping techniques to maximize the parallelism in the MXUs / TensorCores.
  2. It dynamically selects different activation functions depending on matrix operation types to ensure overlapping of vector and matrix/tensor processing.
  3. It employs hybrid convolutions and a novel fusion strategy to strike a balance between total compute and arithmetic intensity to ensure that computation and memory access happens in parallel and to reduce the contention on VPUs / CUDA cores.
  4. With latency-aware compound scaling (LACS), which uses hardware performance instead of FLOPs as the performance objective to search for model depth, width, and resolution (sketched below), we ensure parallelism at all levels for the entire model family on the Pareto front.
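
For context on item 4: EfficientNet's original compound scaling chooses depth, width, and resolution coefficients under a FLOPs-based constraint; a latency-aware variant keeps the same form but constrains measured hardware latency instead. The formulation below is a sketch of that idea, not necessarily the paper's exact notation:

$$d = \alpha^{\phi}, \quad w = \beta^{\phi}, \quad r = \gamma^{\phi}, \qquad \text{subject to } \text{latency}\!\left(\alpha^{\phi}, \beta^{\phi}, \gamma^{\phi}\right) \le T, \quad \alpha, \beta, \gamma \ge 1,$$

where $\phi$ controls the overall model size along the Pareto front and $T$ is the latency target on the accelerator of interest.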

EfficientNet-X: Platform-Aware NAS-Optimized Computer Vision Models for TPUs and GPUs
Using this approach to platform-aware NAS, we have designed EfficientNet-X, an optimized computer vision model family for TPUs and GPUs. This family builds upon the EfficientNet architecture, which itself was originally designed by traditional multi-objective NAS without true hardware awareness, as the baseline. The resulting EfficientNet-X model family achieves an average speedup of ~1.5x–2x over EfficientNet on TPUv3 and GPUv100, with comparable accuracy.

In addition to the improved speeds, EfficientNet-X has shed light on the non-proportionality between FLOPs and true performance. Many think FLOPs are a good ML performance proxy (i.e., FLOPs and performance are proportional), but they are not. While FLOPs are a good performance proxy for simple hardware such as scalar machines, they can exhibit a margin of error of up to 400% on advanced matrix/tensor machines. For example, because of its hardware-friendly model architecture, EfficientNet-X requires ~2x more FLOPs than EfficientNet, but is ~2x faster on TPUs and GPUs.

EfficientNet-X family achieves 1.5x–2x speedup on average over the state-of-the-art EfficientNet family, with comparable accuracy on TPUv3 and GPUv100.

Self-Driving ML Model Performance on New Accelerator Hardware Platforms
Platform-aware NAS exposes the inner workings of the hardware and leverages these properties when designing hardware-optimized ML models. In a sense, the “platform-awareness” of the model is a “gene” that preserves knowledge of how to optimize performance for a hardware family, even on new generations, without the need to redesign the models. For example, TPUv4i delivers up to 3x higher peak performance (FLOPS) than its predecessor TPUv2, but EfficientNet performance only improves by 30% when migrating from TPUv2 to TPUv4i. In comparison, EfficientNet-X retains its platform-aware properties even on new hardware and achieves a 2.6x speedup when migrating from TPUv2 to TPUv4i, utilizing almost all of the 3x peak performance gain expected when upgrading between the two generations.

Hardware peak performance ratio of TPUv2 to TPUv4i and the geometric mean speedup of EfficientNet-X and EfficientNet families, respectively, when migrating from TPUv2 to TPUv4i.

Conclusion and Future Work
We demonstrate how to improve the capabilities of platform-aware NAS for datacenter ML accelerators, especially TPUs and GPUs. Both platform-aware NAS and the EfficientNet-X model family have been deployed in production and materialize up to ~40% efficiency gains and significant quality improvements for various internal computer vision projects across Google. Additionally, because of its deep understanding of accelerator hardware architecture, platform-aware NAS was able to identify critical performance bottlenecks on TPUv2-v4i architectures and has enabled design enhancements to future TPUs with significant potential performance uplift. As next steps, we are working on expanding platform-aware NAS’s capabilities to the ML hardware and model design beyond computer vision.

Acknowledgements
Special thanks to our co-authors: Mingxing Tan, Ruoming Pang, Andrew Li, Liqun Cheng, Quoc Le. We also thank many collaborators including Jeff Dean, David Patterson, Shengqi Zhu, Yun Ni, Gang Wu, Tao Chen, Xin Li, Yuan Qi, Amit Sabne, Shahab Kamali, and many others from the broad Google research and engineering teams who helped on the research and the subsequent broad production deployment of platform-aware NAS.


2021 Retrospective

I had hoped COVID would go away in 2021, but with the emergence of the Omicron variant, difficult situations seem to keep repeating themselves. I hope things get better soon.

Work

The biggest event of 2021 was probably changing jobs. I moved from a large company back to a startup, and just like my first job, I joined a very early-stage startup.

Changing Jobs

The biggest reason for changing jobs was ‘personal growth’. When I first moved to Naver, I focused on the technology of Machine Learning/Deep Learning. The goal was to shift my career toward ML rather than toward a product.

During the first year, I focused entirely on the technology side, designing and developing an NLP framework from the ground up. After that, as I worked on building a company-wide language model and served as project manager for AI projects that answer phone calls, I naturally grew a stronger desire to work in a way that is more directly connected to a product again.

Of course, that doesn't mean such experiences are impossible at Naver. However, the difference in speed between a startup and a large company is a structural issue, and leveraging an already-established brand image versus building a new brand from scratch also leads to big differences in how a product is built.

In short, I wanted to meet customers more directly, understand the problems they are facing, and build products that solve them.


Source: 1. Woowa Brothers exit post-mortem – why I'm leaving, by Hyangro

I also wanted to experience an early-stage startup growing through each funding series. Similar to what Dongwook Lee (who moved from Woowa Brothers to Inflearn) describes in his job-change post-mortem, I too have only experienced a seed-stage startup and a large company. As described above, I believe the experience of growing from the early days to the stage where the team exceeds 100 people (Series C or D) will be an invaluable set of stepping stones.

What I've realized through various experiences is that ‘when you are personally in a situation, you think about it more deeply.’ It was the same when I took on the role of project manager. I thought deeply, looked for answers in books, got help from people around me, and learned a lot by trying many things myself.

So I believe that if I can place myself in the direction I want to grow, I will be able to learn and grow more in that situation.

LBox


Source: https://lbox.kr/

The company I moved to is LBox, a legal-tech startup. There were three main reasons I chose LBox.

  1. It was an early-stage, pre-Series A startup of about 10 people, with a good team already in place.
  2. The legal-tech domain as a whole is growing.
  3. There are a great many problems in the legal domain that can be solved with NLP technology.

Simply put, those reasons boil down to an early-stage team, the market, and what I can do well. I think reasons 2 and 3 in particular can combine very well. Machine learning technology is becoming more and more advanced, and the market is becoming more digitized, so I saw that NLP technology can shine in the legal domain, which is full of text data.

Some of the things I've worked on since joining LBox are listed below. I hope to share the details on the company's tech blog.

  • Improving the quality and speed of similar-precedent and keyword search
  • Improving suggested search terms (search autocomplete)
  • Developing a language model specialized for the legal domain
  • Developing ML APIs, improving the ML serving pipeline, and introducing a data pipeline (Airflow)
  • Introducing a service monitoring system
  • Setting up a BI tool for data-driven decision making
  • Improving the way we work (organizing into cells)

Launching the Membership Plan

One special event after changing jobs was launching a subscription-based membership plan. Monetization lets you see customers' unfiltered reactions to whether the service you've been providing really delivers that much value; rather than an event, it's more accurate to call it a moment of being put to the test.

Watching users actually subscribe to the membership plan, I came to understand the ideas in Hermann Simon's book ≪Pricing≫, that price is directly tied to the value of a service.

Ultimately, consumers are only willing to pay as much as the value they receive. The task of every seller is to find out how much value the consumer perceives and to set the price of the product or service accordingly. Only when the exchange with the seller leaves the impression of having been fair does the consumer remain a loyal customer. Customer satisfaction is the only way to maximize long-term profit.

  • From the preface

“Pricing power is the single most powerful factor in evaluating a company.” So said investor Warren Buffett. “If you're holding a prayer session before raising prices, you're running a terrible business.”

  • From the chapter on pricing power

One thing we often talk about internally is: ‘We should build painkillers, not vitamins.’ Only by building a painkiller can we deliver real value to customers and have the power to set the price.

I think the most important passage in the pricing book is the preface quoted above: focus on ‘customer satisfaction’ to create value, and then be able to translate that value into a price.

Quantified Self

Quantified Self (QS) is the practice of building better habits based on real data about yourself.

This year, too, I'll take a look at various things based on data about myself.


Figure 1: Monthly changes in productivity score and sleep score

One big change is that my commute got much longer when I changed jobs in April. I should go to bed earlier to compensate, but as always, I end up cutting into my sleep instead. When circumstances change like this, it seems to take about six months to readjust my sleep rhythm. My score kept dropping from March, but since September it has been improving through the practice of going to bed and waking up at the same time every day. Separately, the productivity score shows that I've been working hard since changing jobs.


Figure 2: Monthly work-hours graph (research: dark blue, development: blue, meetings: purple, managing: magenta)

Since changing jobs, the time I spend directly developing and modeling has clearly increased. It's still a small startup and I joined to contribute as a machine learning engineer, so I expect I'll keep doing hands-on development that connects directly to the service.

Recently everyone at the company has become wary of meetings getting longer and we've been making efforts to cut them down, and you can see that meeting time suddenly spiked in December. Looking at meeting time again in 2022 should show the results of those efforts.

Rize


Source: rize.io

Rize is a time-tracking tool for focus that I learned about recently. In the way it automatically tracks time and helps you become aware of it, it feels very similar to the things I've been doing as side projects. It makes me really want to try it, and also to try measuring the ‘focus’ that Rize emphasizes.

Below is an excerpt from an interview with Rize's founders.

I think the best way to build a habit is to quantify it and review it daily. People say “you can't improve what you don't measure,” but I think measuring is also a great way to reinforce positive habits. For example, if you want to reclaim your focus time, check how much you're getting each day and whether it's improving over time. On days with much less focus time, dig deeper to understand what made the problem worse (usually meetings) and iterate from there.

I had often thought about building a similar product, so seeing an actual product like this feels strange, and it reminds me once again that what really matters is execution.

Book


Figure 3: Books read in 2021 and the time spent on each

More than anything, 2021 was a year in which my reading habit became more solid. Some books carried over from 2020 and some I'm still reading, but I read roughly 25 books. The time spent reading was about 265.4 hours, up about 2% from last year (259.5 hours). I'm just glad it didn't decrease. Lately I've noticed that I deliberately make time to read, and that I've developed the habit of using unexpected free time for reading.

Paradoxically, the more books I read, the faster the list of books I want to read grows. Books keep introducing me to other books, so the reading list I keep just keeps getting longer..!


Source: ≪The Agony of a Book Collection (장서의 괴로움)≫ – Aladin

The company provides a subscription to ‘Millie's Library’ (밀리의 서재), so I've been trying out e-books, and they are definitely convenient. Having finally filled my bookshelf, one of my long-time dreams, I now worry about moving house; and since digital is clearly more convenient for reading several books at once, I'm gradually considering switching to e-books. It's still a long way off, but after listening to the ‘Red Bookstore’ (빨간책방) episode covering ≪The Agony of a Book Collection≫, I'm starting to think that moving to e-books sooner might be a good idea.

Book of the Year


Source: Stoner – Aladin

My book of the year is ≪Stoner≫. It's a novel about the life of an ordinary professor, and it's a book I kept meaning to read because the way Lee Dong-jin talked about it on ‘Red Bookstore’ strangely stuck in a corner of my mind.

I think readers of this book fall into two groups: those who finish it and those who don't. There may be people who can't finish it, but I don't think anyone who finishes it can dislike it. It's a strangely moving book.

And indeed, after reading it to the end, it was a strangely moving book. When I finished it, I felt bitter thinking about the various events of Stoner's life, but reading the author's interview I thought, ‘Ah, right..’. I'd like to recommend it.

I think he is a real hero. Many people who have read this novel see Stoner's life as sad and unfortunate. But to me, his life was a very good one. He clearly lived a better life than most people. He did what he wanted to do, had some affection for that work, and felt that the work had meaning……. To me, what matters in this novel is Stoner's sense of his own work……. ‘Work’ in the good and honorable sense. Through his work he gained a particular sense of identity.

Blog

Writing consistently is one of those things I resolve to do every time, watch fizzle out, and then resolve to do again. Last year I spent about 56 hours writing; this year it dropped a lot, to about 43 hours. Still, this year I wrote a total of 8 posts: 7 on my personal blog and 1 on the company tech blog. My goal is one post per month, i.e., 12 posts a year. Below is the list of blog posts sorted by page views.

Company tech blog

Personal blog

Page views are far higher for the company tech blog post, which is inevitable since it was content the company could use to appeal externally and it was actively promoted. Even so, it makes me think about the purpose of the posts I'm writing now; after all, posting to a blog assumes that someone will read it. So, just as the company tech blog post was actively promoted, I've been thinking about the need to promote my personal blog posts as well. The biggest reason behind this thinking is ‘personal branding.’

Personal Branding

These days there are many ways to present yourself career-wise beyond your résumé. For developers, the representative ones are GitHub, a development blog, and conferences. Conferences themselves are also becoming more segmented: where there used to be big conferences like PyCon and Deview, recently each company holds its own conference, and platforms like Wanted or Fast Campus gather speakers to create conferences on specific topics. Some people upload content to YouTube, promote their work in Facebook groups, or share related articles.

It's becoming an era in which people talk about their experiences and outputs in all sorts of ways. So when writing blog posts, I try to write posts infused with my own experience and thinking rather than posts anyone could write. But there was one video that made me think I should add one more condition to this goal.

In a recent EO video featuring Lee Jin-sun, there is this remark:

Q. The posts I want to write and the posts people want are different; which should I write?
If you want to write for personal branding, you must find the intersection of what you want to write and what people want to read.

So far I think I've been focused on what I want to write; going forward I want to strengthen my personal branding by gradually learning what people want to read. In the end, you have to know what the market wants.

In Closing

Organization and Management

One of the topics I've thought about most since returning to a startup is organization and management. I used to think about this topic as follows:

1. Hire the right people.
2. Build a structure in which these people can exercise their abilities to the fullest.

What I had overlooked in point 1 was: ‘What if you can't hire the right people?’ When you experience the chronic startup problem of limited resources, you realize how hard point 1 really is. So there is a proposition I've added.

It is Peter Drucker's statement that ‘the purpose of an organization is to enable ordinary people to do extraordinary things.’ With this proposition added, my current thinking on organization and management is as follows:

1. Hire people who are as well suited as possible.
2. Help these people produce outstanding results, and build a structure in which they can exercise their abilities to the fullest.

Above all, I hope 2022 will be a year of trying many things and learning a lot!


The mathematically optimal Wordle strategy


Can Robots Follow Instructions for New Tasks?

People can flexibly maneuver objects in their physical surroundings to accomplish various goals. One of the grand challenges in robotics is to successfully train robots to do the same, i.e., to develop a general-purpose robot capable of performing a multitude of tasks based on arbitrary user commands. Robots that are faced with the real world will also inevitably encounter new user instructions and situations that were not seen during training. Therefore, it is imperative for robots to be trained to perform multiple tasks in a variety of situations and, more importantly, to be capable of solving new tasks as requested by human users, even if the robot was not explicitly trained on those tasks.

Existing robotics research has made strides towards allowing robots to generalize to new objects, task descriptions, and goals. However, enabling robots to complete instructions that describe entirely new tasks has largely remained out-of-reach. This problem is remarkably difficult since it requires robots to both decipher the novel instructions and identify how to complete the task without any training data for that task. This goal becomes even more difficult when a robot needs to simultaneously handle other axes of generalization, such as variability in the scene and positions of objects. So, we ask the question: How can we confer noteworthy generalization capabilities onto real robots capable of performing complex manipulation tasks from raw pixels? Furthermore, can the generalization capabilities of language models help support better generalization in other domains, such as visuomotor control of a real robot?

In “BC-Z: Zero-Shot Task Generalization with Robotic Imitation Learning”, published at CoRL 2021, we present new research that studies how robots can generalize to new tasks that they were not trained to do. The system, called BC-Z, comprises two key components: (i) the collection of a large-scale demonstration dataset covering 100 different tasks and (ii) a neural network policy conditioned on a language or video instruction of the task. The resulting system can perform at least 24 novel tasks, including ones that require interaction with pairs of objects that were not previously seen together. We are also excited to release the robot demonstration dataset used to train our policies, along with pre-computed task embeddings.

The BC-Z system allows a robot to complete instructions for new tasks that the robot was not explicitly trained to do. It does so by training the policy to take as input a description of the task along with the robot’s camera image and to predict the correct action.

Collecting Data for 100 Tasks
Generalizing to a new task altogether is substantially harder than generalizing to held-out variations in training tasks. Simply put, we want robots to have more generalization all around, which requires that we train them on large amounts of diverse data.

We collect data by teleoperating the robot with a virtual reality headset. This data collection follows a scheme similar to how one might teach an autonomous car to drive. First, the human operator records complete demonstrations of each task. Then, once the robot has learned an initial policy, this policy is deployed under close supervision where, if the robot starts to make a mistake or gets stuck, the operator intervenes and demonstrates a correction before allowing the robot to resume.

This mixture of demonstrations and interventions has been shown to significantly improve performance by mitigating compounding errors. In our experiments, we see a 2x improvement in performance when using this data collection strategy compared to only using human demonstrations.
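
The loop below is a schematic sketch of the demonstration-plus-intervention collection scheme described above. The operator, environment, and policy interfaces are illustrative placeholders, not the actual system's APIs, and which transitions are kept for retraining is a design choice the sketch leaves open.

```python
def collect_episode(env, operator, policy, dataset):
    # One episode of shared-autonomy data collection: the robot acts, and the
    # human operator intervenes with a teleoperated correction when needed.
    obs, done = env.reset(), False
    while not done:
        if operator.wants_to_intervene(obs):
            action = operator.teleop_action(obs)   # human correction via teleop
        else:
            action = policy(obs)                   # robot acts autonomously
        dataset.append((obs, action))
        obs, done = env.step(action)
    return dataset
# The policy is periodically retrained on the growing dataset, and collection
# resumes under close human supervision.
```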

Example demonstrations collected for 12 out of the 100 training tasks, visualized from the perspective of the robot and shown at 2x speed.

Training a General-Purpose Policy
For all 100 tasks, we use this data to train a neural network policy to map from camera images to the position and orientation of the robot’s gripper and arm. Crucially, to allow this policy the potential to solve new tasks beyond the 100 training tasks, we also input a description of the task, either in the form of a language command (e.g., “place grapes in red bowl”) or a video of a person doing the task.
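
A toy sketch of such a conditioned policy is shown below: it fuses image features with a task embedding (from a language or video encoder) and regresses an arm/gripper command. All layer shapes here are invented for illustration and do not reflect the actual BC-Z network.

```python
import jax
import jax.numpy as jnp

H, W, TASK_DIM, ACTION_DIM = 8, 8, 16, 7   # tiny, made-up sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "vision": 0.01 * jax.random.normal(k1, (H * W * 3, 32)),  # toy "vision encoder"
    "task":   0.01 * jax.random.normal(k2, (TASK_DIM, 32)),
    "head":   0.01 * jax.random.normal(k3, (64, ACTION_DIM)),
}

def policy(params, image, task_embedding):
    # image: [H, W, 3]; task_embedding: from a language command or human video.
    img_feat = jnp.tanh(image.reshape(-1) @ params["vision"])
    task_feat = jnp.tanh(task_embedding @ params["task"])
    fused = jnp.concatenate([img_feat, task_feat])
    return fused @ params["head"]   # e.g. end-effector position, orientation, gripper

action = policy(params, jnp.zeros((H, W, 3)), jnp.ones(TASK_DIM))
print(action.shape)   # (7,)
```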

To accomplish a variety of tasks, the BC-Z system takes as input either a language command describing the task or a video of a person doing the task, as shown here.

By training the policy on 100 tasks and conditioning the policy on such a description, we unlock the possibility that the neural network will be able to interpret and complete instructions for new tasks. This is a challenge, however, because the neural network needs to correctly interpret the instruction, visually identify relevant objects for that instruction while ignoring other clutter in the scene, and translate the interpreted instruction and perception into the robot’s action space.

Experimental Results
In language models, it is well known that sentence embeddings generalize on compositions of concepts encountered in training data. For instance, if you train a translation model on sentences like “pick up a cup” and “push a bowl”, the model should also translate “push a cup” correctly.

We study the question of whether the compositional generalization capabilities found in language encoders can be transferred to real robots, i.e., being able to compose unseen object-object and task-object pairs.

We test this method by pre-selecting a set of 28 tasks, none of which were among the 100 training tasks. For example, one of these new test tasks is to pick up the grapes and place them into a ceramic bowl, but the training tasks involve doing other things with the grapes and placing other items into the ceramic bowl. The grapes and the ceramic bowl never appeared in the same scene during training.

In our experiments, we see that the robot can complete many tasks that were not included in the training set. Below are a few examples of the robot’s learned policy.

The robot completes three instructions of tasks that were not in its training data, shown at 2x speed.

Quantitatively, we see that the robot can succeed to some degree on a total of 24 out of the 28 held-out tasks, indicating a promising capacity for generalization. Further, we see a notably small gap between the performance on the training tasks and performance on the test tasks. These results indicate that simply improving multi-task visuomotor control could considerably improve performance.

The BC-Z performance on held-out tasks, i.e., tasks that the robot was not trained to perform. The system correctly interprets the language command and translates that into action to complete many of the tasks in our evaluation.

Takeaways
The results of this research show that simple imitation learning approaches can be scaled in a way that enables zero-shot generalization to new tasks. That is, it shows one of the first indications of robots being able to successfully carry out behaviors that were not in the training data. Interestingly, language embeddings pre-trained on ungrounded language corpora make for excellent task conditioners. We demonstrated that natural language models can not only provide a flexible input interface to robots, but that pretrained language representations actually confer new generalization capabilities to the downstream policy, such as composing unseen object pairs together.

In the course of building this system, we confirmed that periodic human interventions are a simple but important technique for achieving good performance. While there is a substantial amount of work to be done in the future, we believe that the zero-shot generalization capabilities of BC-Z are an important advancement towards increasing the generality of robotic learning systems and allowing people to command robots. We have released the teleoperated demonstrations used to train the policy in this paper, which we hope will provide researchers with a valuable resource for future multi-task robotic learning research.

Acknowledgements
We would like to thank the co-authors of this research: Alex Irpan, Mohi Khansari, Daniel Kappler, Frederik Ebert, Corey Lynch, and Sergey Levine. This project was a collaboration between Google Research and the Everyday Robot Project. We would like to give special thanks to Noah Brown, Omar Cortes, Armando Fuentes, Kyle Jeffrey, Linda Luu, Sphurti Kirit More, Jornell Quiambao, Jarek Rettinghouse, Diego Reyes, Rosario Jauregui Ruano, and Clayton Tan for overseeing robot operations and collecting human videos of the tasks, as well as Jeffrey Bingham, Jonathan Weisz, and Kanishka Rao for valuable discussions. We would also like to thank Tom Small for creating animations in this post and Paul Mooney for helping with dataset open-sourcing.


Applying Differential Privacy to Large Scale Image Classification

Machine learning (ML) models are becoming increasingly valuable for improved performance across a variety of consumer products, from recommendations to automatic image classification. However, despite aggregating large amounts of data, in theory it is possible for models to encode characteristics of individual entries from the training set. For example, experiments in controlled settings have shown that language models trained using email datasets may sometimes encode sensitive information included in the training data and may have the potential to reveal the presence of a particular user’s data in the training set. As such, it is important to prevent the encoding of such characteristics from individual training entries. To these ends, researchers are increasingly employing federated learning approaches.

Differential privacy (DP) provides a rigorous mathematical framework that allows researchers to quantify and understand the privacy guarantees of a system or an algorithm. Within the DP framework, privacy guarantees of a system are usually characterized by a positive parameter ε, called the privacy loss bound, with smaller ε corresponding to better privacy. One usually trains a model with DP guarantees using DP-SGD, a specialized training algorithm that provides DP guarantees for the trained model.

However, training with DP-SGD typically has two major drawbacks. First, most existing implementations of DP-SGD are inefficient and slow, which makes it hard to use on large datasets. Second, DP-SGD training often significantly impacts utility (such as model accuracy) to the point that models trained with DP-SGD may become unusable in practice. As a result, most DP research papers evaluate DP algorithms on very small datasets (MNIST, CIFAR-10, or UCI) and don’t even attempt to evaluate larger datasets, such as ImageNet.

In “Toward Training at ImageNet Scale with Differential Privacy”, we share initial results from our ongoing effort to train a large image classification model on ImageNet using DP while maintaining high accuracy and minimizing computational cost. We show that the combination of various training techniques, such as careful choice of the model and hyperparameters, large batch training, and transfer learning from other datasets, can significantly boost accuracy of an ImageNet model trained with DP. To substantiate these discoveries and encourage follow-up research, we are also releasing the associated source code.

Testing Differential Privacy on ImageNet
We choose ImageNet classification as a demonstration of the practicality and efficacy of DP because (1) it is an ambitious task for DP, for which no prior work shows sufficient progress, and (2) it is a public dataset on which other researchers can operate, so it represents an opportunity to collectively improve the utility of real-life DP training. Classification on ImageNet is challenging for DP because it requires large networks with many parameters. This translates into a significant amount of noise added to the computation, because the added noise scales with the size of the model.

Scaling Differential Privacy with JAX
Exploring multiple architectures and training configurations to learn what works for DP can be prohibitively slow. To streamline our efforts, we used JAX, a high-performance computational library based on XLA that performs efficient auto-vectorization and just-in-time compilation of mathematical computations. Using these JAX features was previously recommended as a good way to speed up DP-SGD in the context of smaller datasets such as CIFAR-10.

We created our own implementation of DP-SGD on JAX and benchmarked it on the large ImageNet dataset (the code is included in our release). The implementation in JAX was relatively simple and resulted in noticeable performance gains simply from using the XLA compiler. Compared to other implementations of DP-SGD, such as that in Tensorflow Privacy, the JAX implementation is consistently several times faster. It is typically even faster than the custom-built and optimized PyTorch Opacus.

Each step of our DP-SGD implementation takes approximately two forward-backward passes through the network. While this is slower than non-private training, which requires only a single forward-backward pass, it is still the most efficient known approach for training with the per-example gradients necessary for DP-SGD. The graph below shows training runtimes for two models on ImageNet with DP-SGD vs. non-private SGD, each on JAX. Overall, we find DP-SGD on JAX sufficiently fast to run large experiments just by slightly reducing the number of training runs used to find optimal hyperparameters compared to non-private training. This is significantly better than alternatives, such as Tensorflow Privacy, which we found to be ~5x–10x slower on our CIFAR-10 and MNIST benchmarks.
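
To make the per-example gradient computation concrete, below is a minimal sketch of one DP-SGD gradient step written in JAX. It is not the code from our release: the toy linear model and the hyperparameter names `clip_norm` and `noise_multiplier` are illustrative placeholders, and a real implementation would also track the privacy budget with an accountant.

```python
import jax
import jax.numpy as jnp

# Toy linear classifier standing in for the ResNet models discussed in the post.
def model_apply(params, x):
    return x @ params["w"] + params["b"]

def per_example_loss(params, x, y):
    # Cross entropy for a single example with one-hot label y.
    return -jnp.sum(jax.nn.log_softmax(model_apply(params, x)) * y)

def dp_sgd_grads(params, xs, ys, key, clip_norm=1.0, noise_multiplier=1.0):
    # Per-example gradients: vmap the gradient function over the batch dimension.
    grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(params, xs, ys)

    # Per-example global L2 norm across all parameter leaves, then clipping factors.
    leaves = jax.tree_util.tree_leaves(grads)
    sq_norms = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in leaves)
    scale = jnp.minimum(1.0, clip_norm / (jnp.sqrt(sq_norms) + 1e-12))

    batch = xs.shape[0]
    keys = jax.random.split(key, len(leaves))
    noisy = []
    for g, k in zip(leaves, keys):
        clipped = g * scale.reshape((-1,) + (1,) * (g.ndim - 1))   # clip each example
        summed = jnp.sum(clipped, axis=0)                          # aggregate the batch
        noise = noise_multiplier * clip_norm * jax.random.normal(k, summed.shape)
        noisy.append((summed + noise) / batch)                     # noisy average gradient
    return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(grads), noisy)

# Example usage on random data (shapes are illustrative):
params = {"w": jnp.zeros((8, 3)), "b": jnp.zeros(3)}
xs, ys = jnp.ones((16, 8)), jnp.tile(jnp.eye(3)[0], (16, 1))
grads = dp_sgd_grads(params, xs, ys, jax.random.PRNGKey(0))
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

The resulting gradient pytree can be applied exactly like an ordinary SGD gradient, as in the last line above; the roughly two-pass cost per step comes from computing a separate gradient for every example before clipping and noising.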

Time in seconds per training epoch on ImageNet using a Resnet18 or Resnet50 architecture with 8 V100 GPUs.

Combining Techniques for Improved Accuracy
It is possible that future training algorithms may improve DP’s privacy-utility tradeoff. However, with current algorithms, such as DP-SGD, our experience points to an engineering “bag-of-tricks” approach to make DP more practical on challenging tasks like ImageNet.

Because we can train models faster with JAX, we can iterate quickly and explore multiple configurations to find what works well for DP. We report the following combination of techniques as useful to achieve non-trivial accuracy and privacy on ImageNet:

  • Full-batch training

    Theoretically, it is known that larger minibatch sizes improve the utility of DP-SGD, with full-batch training (i.e., where the full dataset is one batch) giving the best utility [1, 2], and empirical results are emerging to support this theory. Indeed, our experiments demonstrate that increasing the batch size along with the number of training epochs leads to a decrease in ε while still maintaining accuracy. However, training with extremely large batches is non-trivial, as the batch cannot fit into GPU/TPU memory. So, we employed virtual large-batch training: instead of applying a gradient update on every training step, we accumulate gradients over multiple steps and update the weights once (a short sketch of this accumulation follows the list below).

    Batch size             1024        4 × 1024    16 × 1024    64 × 1024
    Number of epochs       10          40          160          640
    Accuracy               56%         57.5%       57.9%        57.2%
    Privacy loss bound ε   9.8 × 10⁸   6.1 × 10⁷   3.5 × 10⁶    6.7 × 10⁴

  • Transfer learning from public data

    Pre-training on public data followed by DP fine-tuning on private data has previously been shown to improve accuracy on other benchmarks [3, 4]. A question that remains is what public data to use for a given task to optimize transfer learning. In this work we simulate a private/public data split by using ImageNet as “private” data and Places365, another image classification dataset, as a proxy for “public” data. We pre-trained our models on Places365 before fine-tuning them with DP-SGD on ImageNet. Unlike ImageNet, Places365 only has images of landscapes and buildings, not of animals, so it is quite different, making it a good candidate to demonstrate the ability of the model to transfer to a different but related domain.

    We found that transfer learning from Places365 gave us 47.5% accuracy on ImageNet with a reasonable level of privacy (ε = 10). This is low compared to the 70% accuracy of a similar non-private model, but compared to naïve DP training on ImageNet, which yields either very low accuracy (2–5%) or no privacy (ε = 10⁹), this is quite good.
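
As referenced in the full-batch training item above, here is a hedged sketch of how gradient accumulation can realize virtual large batches; `dp_sgd_grads` is the hypothetical per-batch function from the earlier sketch and `accumulation_steps` is an illustrative value. Note that a production DP implementation must calibrate the noise and the privacy accounting to the full virtual batch rather than to each physical batch, a detail this sketch elides.

```python
import jax
import jax.numpy as jnp

def virtual_batch_grads(params, batches, key, accumulation_steps=64):
    """Accumulate gradients over several physical batches, then return one averaged
    update, so the effective (virtual) batch is accumulation_steps times larger."""
    accum = None
    for step in range(accumulation_steps):
        xs, ys = batches[step]                     # a physical batch that fits in memory
        key, subkey = jax.random.split(key)
        g = dp_sgd_grads(params, xs, ys, subkey)   # per-batch DP gradients (earlier sketch)
        accum = g if accum is None else jax.tree_util.tree_map(jnp.add, accum, g)
    return jax.tree_util.tree_map(lambda a: a / accumulation_steps, accum)
```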

Privacy-accuracy tradeoff for Resnet-18 on ImageNet using large-batch training with transfer learning from Places365.

Next Steps
We hope these early results and source code provide an impetus for other researchers to work on improving DP for ambitious tasks such as ImageNet as a proxy for challenging production-scale tasks. With the much faster DP-SGD on JAX, we urge DP and ML researchers to explore diverse training regimes, model architectures, and algorithms to make DP more practical. To continue advancing the state of the field, we recommend researchers start with a baseline that incorporates full-batch training plus transfer learning.

Acknowledgments
This work was carried out with the support of the Google Visiting Researcher Program while Prof. Geambasu, an Associate Professor with Columbia University, was on sabbatical with Google Research. This work received substantial contributions from Steve Chien, Shuang Song, Andreas Terzis and Abhradeep Guha Thakurta.

Controlling Neural Networks with Rule Representations

Deep neural networks (DNNs) provide more accurate results as the size and coverage of their training data increases. While investing in high-quality and large-scale labeled datasets is one path to model improvement, another is leveraging prior knowledge, concisely referred to as “rules” — reasoning heuristics, equations, associative logic, or constraints. Consider a common example from physics where a model is given the task of predicting the next state in a double pendulum system. While the model may learn to estimate the total energy of the system at a given point in time only from empirical data, it will frequently overestimate the energy unless also provided an equation that reflects the known physical constraints, e.g., energy conservation. The model fails to capture such well-established physical rules on its own. How could one effectively teach such rules so that DNNs absorb the relevant knowledge beyond simply learning from the data?

In “Controlling Neural Networks with Rule Representations”, published at NeurIPS 2021, we present Deep Neural Networks with Controllable Rule Representations (DeepCTRL), an approach for providing rules to a model that is agnostic to data type and model architecture and can be applied to any kind of rule defined for inputs and outputs. The key advantage of DeepCTRL is that it does not require retraining to adapt the rule strength. At inference, the user can adjust the rule strength based on the desired accuracy operating point. We also propose a novel input perturbation method, which helps generalize DeepCTRL to non-differentiable constraints. In real-world domains where incorporating rules is critical — such as physics and healthcare — we demonstrate the effectiveness of DeepCTRL in teaching rules for deep learning. DeepCTRL ensures that models follow rules more closely while also providing accuracy gains on downstream tasks, thus improving reliability and user trust in the trained models. Additionally, DeepCTRL enables novel use cases, such as hypothesis testing of the rules on data samples and unsupervised adaptation based on shared rules between datasets.

The benefits of learning from rules are multifaceted:

  • Rules can provide extra information for cases with minimal data, improving the test accuracy.
  • A major bottleneck to the widespread use of DNNs is the difficulty of understanding the rationale behind their reasoning and their inconsistencies. By minimizing inconsistencies, rules can improve the reliability of and user trust in DNNs.
  • DNNs are sensitive to slight input changes that are human-imperceptible. With rules, the impact of these changes can be minimized as the model search space is further constrained to reduce underspecification.

Learning Jointly from Rules and Tasks
The conventional approach to incorporating rules is to include them in the calculation of the loss. There are three limitations of this approach that we aim to address: (i) rule strength needs to be defined before learning (thus the trained model cannot operate flexibly based on how much the data satisfies the rule); (ii) rule strength is not adaptable to target data at inference if there is any mismatch with the training setup; and (iii) the rule-based objective needs to be differentiable with respect to learnable parameters (to enable learning from labeled data).

DeepCTRL modifies canonical training by creating rule representations, coupled with data representations, which is the key to enabling the rule strength to be controlled at inference time. During training, these representations are stochastically concatenated with a control parameter, indicated by α, into a single representation. The influence of the rule on the output decision can be increased by increasing the value of α. By modifying α at inference, users can control the behavior of the model to adapt to unseen data.

DeepCTRL pairs a data encoder and a rule encoder, which produce two latent representations that are coupled with corresponding objectives. The control parameter α is adjustable at inference to control the relative weight of each encoder.
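
The following is a minimal sketch of this coupling, assuming (as a simplification) that the two latent representations are scaled by α and 1 − α before being concatenated, and that the rule and task objectives are weighted the same way; the tiny encoders, the sampling distribution for α, and `rule_loss_fn` are illustrative stand-ins rather than the paper’s exact architecture.

```python
import jax
import jax.numpy as jnp

def deepctrl_forward(params, x, alpha):
    # Two encoders produce a data representation and a rule representation.
    z_data = jnp.tanh(x @ params["W_data"])
    z_rule = jnp.tanh(x @ params["W_rule"])
    # Couple them with the control parameter: alpha weights the rule pathway.
    z = jnp.concatenate([(1.0 - alpha) * z_data, alpha * z_rule], axis=-1)
    return z @ params["W_out"]                    # shared decision block

def deepctrl_loss(params, x, y, key, rule_loss_fn):
    # Alpha is sampled during training so the model learns to operate across the whole
    # range of rule strengths; at inference the user sets alpha directly.
    alpha = jax.random.uniform(key)
    pred = deepctrl_forward(params, x, alpha)
    task_loss = jnp.mean((pred - y) ** 2)
    return alpha * rule_loss_fn(pred, x) + (1.0 - alpha) * task_loss
```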

Integrating Rules via Input Perturbations
Training with rule-based objectives requires the objectives to be differentiable with respect to the learnable parameters of the model. There are many valuable rules that are non-differentiable with respect to the input. For example, “higher blood pressure than 140 is likely to lead to cardiovascular disease” is a rule that is hard to combine with conventional DNNs. We therefore introduce a novel input perturbation method that generalizes DeepCTRL to non-differentiable constraints: we apply small perturbations (random noise) to the input features and construct a rule-based constraint based on whether the outcome moves in the desired direction.
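
As an illustration of this idea, here is a hedged sketch of a perturbation-based constraint for the blood-pressure rule above: we nudge the blood-pressure feature upward and penalize the model whenever its predicted risk moves in the wrong direction. The feature index, perturbation scale, and `predict_fn` are assumptions made for the example, not values from the paper.

```python
import jax
import jax.numpy as jnp

def monotonicity_rule_loss(predict_fn, params, x, key, feature_idx=0, eps=0.05):
    # Nudge the blood-pressure feature upward by a small random amount per example.
    delta = eps * jax.random.uniform(key, (x.shape[0],))
    x_pert = x.at[:, feature_idx].add(delta)
    risk, risk_pert = predict_fn(params, x), predict_fn(params, x_pert)
    # Penalize only outcomes that move in the wrong (non-increasing) direction.
    return jnp.mean(jax.nn.relu(risk - risk_pert))
```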

Use Cases
We evaluate DeepCTRL on machine learning use cases from physics and healthcare, where utilization of rules is particularly important.

  • Improved Reliability Given Known Principles in Physics

    We quantify the reliability of a model with the verification ratio, the fraction of output samples that satisfy the rules. Operating at a higher verification ratio can be beneficial, especially if the rules are known to be always valid, as in the natural sciences. By adjusting the control parameter α, a higher rule verification ratio, and thus more reliable predictions, can be achieved.

    To demonstrate this, we consider the time-series data generated from double pendulum dynamics with friction from a given initial state. We define the task as predicting the next state of the double pendulum from the current state while imposing the rule of energy conservation. To quantify how much the rule is learned, we evaluate the verification ratio.

    DeepCTRL enables controlling a model’s behavior after learning, but without retraining. For the example of a double pendulum, conventional learning imposes no constraints to ensure the model follows physical laws, e.g., conservation of energy. The situation is similar for DeepCTRL when the rule strength is low. So, the total energy of the system predicted at time t+1 (blue) can sometimes be greater than that measured at time t (red), which is physically disallowed (bottom left). If the rule strength in DeepCTRL is high, the model may follow the given rule but lose accuracy (the discrepancy between red and blue is larger; bottom right). If the rule strength is between the two extremes, the model may achieve higher accuracy (the blue curve is close to the red one) and follow the rule properly (the blue curve is lower than the red one).

    We compare the performance of DeepCTRL on this task to conventional baselines that train with a fixed rule-based constraint added to the objective as a regularization term with coefficient λ. The highest of these regularization coefficients provides the highest verification ratio (shown by the green line in the second graph below); however, its prediction error is slightly worse than that of λ = 0.1 (orange line). We find that the lowest prediction error of the fixed baseline is comparable to that of DeepCTRL, but the highest verification ratio of the fixed baseline is still lower, which implies that DeepCTRL can provide accurate predictions while following the law of energy conservation. In addition, we consider the benchmark of imposing the rule constraint with the Lagrangian Dual Framework (LDF) and demonstrate two results, where its hyperparameters are chosen by the lowest mean absolute error (LDF-MAE) and the highest rule verification ratio (LDF-Ratio) on the validation set. The performance of the LDF method is highly sensitive to the choice of the main constraint, and its output is not reliable (black and pink dashed lines).

    Experimental results for the double pendulum task, showing the task-based mean absolute error (MAE), which measures the discrepancy between the ground truth and the model prediction, for DeepCTRL as a function of the control parameter α. TaskOnly has no rule constraint, and Task & Rule has different rule strengths (λ). LDF enforces rules by solving a constrained optimization problem.
    As above, but showing the verification ratio from different models.
    Experimental results for the double pendulum task showing the current and predicted energy at time t and t + 1, respectively.

    Additionally, the figures above illustrate the advantage DeepCTRL has over conventional approaches. For example, increasing the rule strength λ from 0.1 to 1.0 improves the verification ratio (from 0.7 to 0.9), but does not improve the mean absolute error. Arbitrarily increasing λ will continue to drive the verification ratio closer to 1, but will result in worse accuracy. Thus, finding the optimal value of λ requires many training runs of the baseline model, whereas DeepCTRL can find the optimal value of the control parameter α much more quickly.

  • Adapting to Distribution Shifts in Healthcare

    The strengths of some rules may differ between subsets of the data. For example, in disease prediction, the correlation between cardiovascular disease and higher blood pressure is stronger for older patients than for younger patients. In such situations, when the task is shared but the data distribution and the validity of the rule differ between datasets, DeepCTRL can adapt to the distribution shifts by controlling α.

    Exploring this example, we focus on the task of predicting whether cardiovascular disease is present or not using a cardiovascular disease dataset. Given that higher systolic blood pressure is known to be strongly associated with cardiovascular disease, we consider the rule: “higher risk if the systolic blood pressure is higher”. Based on this, we split the patients into two groups: (1) unusual, where a patient has high blood pressure but no disease, or low blood pressure but has the disease; and (2) usual, where a patient has high blood pressure and the disease, or low blood pressure and no disease.

    We demonstrate below that the source data do not always follow the rule, and thus the effect of incorporating the rule can depend on the source data. The test cross entropy, which indicates classification accuracy (lower cross entropy is better), vs. rule strength for source or target datasets with varying usual / unusual ratios is visualized below. The error monotonically increases as α → 1 because the enforcement of the imposed rule, which doesn’t accurately reflect the source data, becomes stricter.

    Test cross entropy vs. rule strength for a source dataset with usual / unusual ratio of 0.30.

    When a trained model is transferred to the target domain, the error can be reduced by controlling α. To demonstrate this, we show three domain-specific datasets, which we call Target 1, 2, and 3. In Target 1, where the majority of patients are from the usual group, as α is increased, the rule-based representation has more weight and the resultant error decreases monotonically.

    As above, but for a Target dataset (1) with a usual / unusual ratio of 0.77.

    When the ratio of usual patients is decreased in Target 2 and 3, the optimal α is an intermediate value between 0 and 1. These results demonstrate the capability to adapt a trained model via α; a simple selection procedure is sketched after the figures below.

    As above, but for Target 2 with a usual / unusual ratio of 0.50.
    As above, but for Target 3 with a usual / unusual ratio of 0.40.
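
As mentioned above, a simple way to realize this adaptation is to sweep α on data from the target domain and keep the value with the lowest error. The sketch below assumes a small labeled validation split from the target domain and a hypothetical `eval_loss_fn`; when labels are unavailable, shared rules between datasets can drive the adaptation instead, as noted earlier.

```python
import jax.numpy as jnp

def select_alpha(eval_loss_fn, params, target_val_data, grid=None):
    # Evaluate the trained model at several rule strengths and keep the best one.
    grid = jnp.linspace(0.0, 1.0, 11) if grid is None else grid
    losses = jnp.array([eval_loss_fn(params, target_val_data, float(a)) for a in grid])
    return float(grid[jnp.argmin(losses)])
```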

Conclusions
Learning from rules can be crucial for constructing interpretable, robust, and reliable DNNs. We propose DeepCTRL, a new methodology for incorporating rules into data-learned DNNs. DeepCTRL enables controllability of the rule strength at inference without retraining. We propose a novel perturbation-based rule encoding method to integrate arbitrary rules into meaningful representations. We demonstrate three use cases of DeepCTRL: improving reliability given known principles, examining candidate rules, and domain adaptation using rule strength.

Acknowledgements
We greatly appreciate the contributions of Jinsung Yoon, Xiang Zhang, Kihyuk Sohn and Tomas Pfister.

Does Your Medical Image Classifier Know What It Doesn’t Know?

Deep machine learning (ML) systems have achieved considerable success in medical image analysis in recent years. One major contributing factor is access to abundant labeled datasets, which are used to train highly effective supervised deep learning models. However, in the real world, these models may encounter samples exhibiting rare conditions that are individually too infrequent for per-condition classification. Nevertheless, such conditions can be collectively common because they follow a long-tail distribution and, when taken together, can represent a significant portion of cases — e.g., in a recent deep learning dermatological study, hundreds of rare conditions composed around 20% of cases encountered by the model at test time.

To prevent models from generating erroneous outputs on rare samples at test time, there remains a considerable need for deep learning systems with the ability to recognize when a sample exhibits a condition they cannot identify. Detecting previously unseen conditions can be thought of as an out-of-distribution (OOD) detection task. By successfully identifying OOD samples, preventive measures can be taken, like abstaining from prediction or deferring to a human expert.

Traditional computer vision OOD detection benchmarks work to detect dataset distribution shifts. For example, a model may be trained on CIFAR images but be presented with street view house numbers (SVHN) as OOD samples, two datasets with very different semantic meanings. Other benchmarks seek to detect slight differences in semantic information, e.g., between images of a truck and a pickup truck, or two different skin conditions. The semantic distribution shifts in such near-OOD detection problems are more subtle in comparison to dataset distribution shifts, and thus, are harder to detect.

In “Does Your Dermatology Classifier Know What it Doesn’t Know? Detecting the Long-Tail of Unseen Conditions”, published in Medical Image Analysis, we tackle this near-OOD detection task in the application of dermatology image classification. We propose a novel hierarchical outlier detection (HOD) loss, which leverages existing fine-grained labels of rare conditions from the long tail and modifies the loss function to group unseen conditions and improve identification of these near OOD categories. Coupled with various representation learning methods and the diverse ensemble strategy, this approach enables us to achieve better performance for detecting OOD inputs.

The Near-OOD Dermatology Dataset
We curated a near-OOD dermatology dataset that includes 26 inlier conditions, each of which is represented by at least 100 samples, and 199 rare conditions considered to be outliers. Outlier conditions can have as few as one sample per condition. The separation criterion between inlier and outlier conditions can be specified by the user. Here, the cutoff sample size between inlier and outlier was 100, consistent with our previous study. The outliers are further split into training, validation, and test sets that are intentionally mutually exclusive to mimic real-world scenarios, where rare conditions shown at test time may not have been seen in training.
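
A minimal sketch of this user-specified split, assuming a flat list of per-image condition labels and the cutoff of 100 samples described above:

```python
from collections import Counter

def split_conditions(labels, cutoff=100):
    # Count images per condition, then split by the user-specified cutoff.
    counts = Counter(labels)
    inliers = {c for c, n in counts.items() if n >= cutoff}
    outliers = set(counts) - inliers
    return inliers, outliers
```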

Long-tail distribution of different dermatological conditions in our dataset: the 26 inlier conditions with at least 100 samples (blue) and the remaining 199 rare outlier conditions (orange). Outlier conditions can have as few as one sample per condition.
                     Train set            Validation set       Test set
                     Inlier    Outlier    Inlier    Outlier    Inlier    Outlier
Number of classes    26        68         26        66         26        65
Number of samples    8854      1111       1251      1082       1192      937
Inlier and outlier conditions in our benchmark dataset and detailed dataset split statistics. The outliers are further split into mutually exclusive train, validation, and test sets.

Hierarchical Outlier Detection Loss
We propose to use “known outlier” samples during training to aid the detection of “unknown outlier” samples at test time. Our novel hierarchical outlier detection (HOD) loss performs a fine-grained classification of individual classes for all inlier and outlier classes and, in parallel, a coarse-grained binary classification of inliers vs. outliers in a hierarchical setup (see the figure below). Our experiments confirmed that HOD is more effective than performing a coarse-grained classification followed by a fine-grained classification, as the latter can create a bottleneck that impacts the performance of the fine-grained classifier.

We use the sum of the predictive probabilities of the outlier classes as the OOD score. As the primary OOD detection metric we use the area under the receiver operating characteristic (AUROC) curve, which ranges between 0 and 1 and gives a measure of separability between inliers and outliers. A perfect OOD detector, which separates all inliers from outliers, is assigned an AUROC score of 1. A popular baseline method, called the reject bucket, separates each inlier individually from the outliers, which are grouped into a dedicated single abstention class. In addition to a fine-grained classification for each individual inlier and outlier class, the HOD loss–based approach separates the inliers collectively from the outliers with a coarse-grained prediction loss, resulting in better generalization. While similar, we demonstrate that our HOD loss–based approach outperforms other baseline methods that leverage outlier data during training, achieving an AUROC score of 79.4% on the benchmark, a significant improvement over the reject bucket, which achieves 75.6%.

Our model architecture and the HOD loss. The encoder (green) represents the wide ResNet 101×3 model pre-trained with different representation learning models (ImageNet, BiT, SimCLR, and MICLe; see below). The output of the encoder is sent to the HOD loss where fine-grained and coarse-grained predictions for inliers (blue) and outliers (orange) are obtained. The coarse predictions are obtained by summing over the fine-grained probabilities as indicated in the figure. The OOD score is defined as the sum of the probabilities of outlier classes.
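
To illustrate the hierarchical setup, the sketch below computes a per-example HOD-style loss and the OOD score for a single image, assuming the model emits fine-grained logits over all inlier and outlier classes; the weighting between the fine- and coarse-grained terms (`coarse_weight`) is an illustrative hyperparameter, not the value used in the paper.

```python
import jax
import jax.numpy as jnp

def hod_loss_and_ood_score(logits, label, outlier_mask, coarse_weight=0.1):
    # Fine-grained distribution over all inlier and outlier classes.
    log_probs = jax.nn.log_softmax(logits)
    probs = jnp.exp(log_probs)

    fine_loss = -log_probs[label]                        # fine-grained cross entropy
    p_outlier = jnp.sum(probs * outlier_mask)            # coarse probability by summation
    is_outlier = outlier_mask[label]
    coarse_loss = -(is_outlier * jnp.log(p_outlier + 1e-8)
                    + (1.0 - is_outlier) * jnp.log(1.0 - p_outlier + 1e-8))

    ood_score = p_outlier                                 # used at test time for detection
    return fine_loss + coarse_weight * coarse_loss, ood_score
```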

Representation Learning and the Diverse Ensemble Strategy
We also investigate how different types of representation learning help OOD detection in conjunction with HOD, by pre-training with ImageNet, BiT-L, SimCLR, and MICLe models. We observe that including the HOD loss improves OOD performance compared to the reject bucket baseline for all four representation learning methods.

Representation learning method    AUROC (%) with reject bucket    AUROC (%) with HOD loss
ImageNet                          74.7%                           77%
BiT-L                             75.6%                           79.4%
SimCLR                            75.2%                           77.2%
MICLe                             76.7%                           78.8%
OOD detection performance for different representation learning models with reject bucket and with HOD loss.

Another, orthogonal approach for improving OOD detection performance and accuracy is the deep ensemble, which aggregates outputs from multiple independently trained models to provide a final prediction. We build upon deep ensembles, but instead of using a fixed architecture with fixed pre-training, we combine different representation learning methods (ImageNet, BiT-L, SimCLR, and MICLe) and different objective loss functions (HOD and reject bucket). We call this a diverse ensemble strategy, which we demonstrate outperforms the deep ensemble in both OOD performance and inlier accuracy.
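
A hedged sketch of how such a diverse ensemble can be aggregated at test time, assuming each member exposes a callable that returns an OOD score in [0, 1]; simple score averaging is one aggregation choice for illustration and may differ from the paper’s exact procedure.

```python
def diverse_ensemble_ood_score(member_fns, image):
    # Each member differs in pre-training (ImageNet, BiT-L, SimCLR, MICLe) and in
    # training objective (HOD loss or reject bucket); scores are simply averaged here.
    scores = [fn(image) for fn in member_fns]
    return sum(scores) / len(scores)
```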

Downstream Clinical Trust Analysis
While we mainly focus on improving performance for OOD detection, the ultimate goal of our dermatology model is to have high accuracy in predicting inlier and outlier conditions. We go beyond traditional performance metrics and introduce a “penalty” matrix that jointly evaluates inlier and outlier predictions for model trust analysis to approximate downstream impact. For a fixed confidence threshold, we count the following types of mistakes: (i) incorrect inlier predictions (i.e., mistaking inlier condition A for inlier condition B); (ii) incorrect abstention on inliers (i.e., abstaining from making a prediction for an inlier); and (iii) incorrect prediction of outliers as one of the inlier classes.

To account for the asymmetrical consequences of the different types of mistakes, penalties can be 0, 0.5, or 1. Both incorrect inlier and outlier-as-inlier predictions can potentially erode user trust in the model and were penalized with a score of 1. Incorrect abstention of an inlier as an outlier was penalized with a score of 0.5, indicating that potential model users should seek additional guidance given the model-expressed uncertainty or abstention. For correct decisions no cost is incurred, indicated by a score of 0.

                       Action of the Model
                       Prediction as Inlier                               Abstain
True class: Inlier     0 (correct) or 1 (incorrect, may erode trust)      0.5 (incorrect abstention on an inlier)
True class: Outlier    1 (incorrect, may erode trust)                     0 (correct)
The penalty matrix is designed to capture the potential impact of different types of model errors.

Because real-world scenarios are more complex and contain a variety of unknown variables, the numbers used here are simplifications intended to enable qualitative approximations of the downstream impact of outlier detection models on user trust, which we refer to as “cost”. We use the penalty matrix to estimate this downstream cost on the test set and compare our method against the baseline, thereby making a stronger case for its effectiveness in real-world scenarios. As shown in the plot below, our proposed solution incurs a much lower estimated cost than the baseline over all possible operating points.
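
The cost estimate in the plot below can be reproduced in spirit with a few lines; the record format and the threshold-based abstention rule here are assumptions made for illustration, not the evaluation code from the paper.

```python
def estimated_cost(records, threshold):
    # records: (is_outlier, ood_score, inlier_correct) per test sample (illustrative format).
    total = 0.0
    for is_outlier, ood_score, inlier_correct in records:
        abstains = ood_score >= threshold              # the model abstains above the threshold
        if is_outlier:
            total += 0.0 if abstains else 1.0          # outlier predicted as inlier erodes trust
        elif abstains:
            total += 0.5                               # incorrect abstention on an inlier
        else:
            total += 0.0 if inlier_correct else 1.0    # wrong inlier label erodes trust
    return total / len(records)
```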

Trust analysis comparing our proposed method to the baseline (reject bucket) for a range of outlier recall rates, indicated by 𝛕. We show that our method reduces downstream estimated cost, potentially reflecting improved downstream impact.

Conclusion
In real-world deployment, medical ML models may encounter conditions that were not seen in training, and it’s important that they accurately identify when they do not know a specific condition. Detecting those OOD inputs is an important step to improving safety. We develop an HOD loss that leverages outlier data during training, and combine it with pre-trained representation learning models and a diverse ensemble to further boost performance, significantly outperforming the baseline approach on our new dermatology benchmark dataset. We believe that our approach, aligned with our AI Principles, can aid successful translation of ML algorithms into real-world scenarios. Although we have primarily focused on OOD detection for dermatology, most of our contributions are fairly generic and can be easily incorporated into OOD detection for other applications.

Acknowledgements
We would like to thank Shekoofeh Azizi, Aaron Loh, Vivek Natarajan, Basil Mustafa, Nick Pawlowski, Jan Freyberg, Yuan Liu, Zach Beaver, Nam Vo, Peggy Bui, Samantha Winter, Patricia MacWilliams, Greg S. Corrado, Umesh Telang, Yun Liu, Taylan Cemgil, Alan Karthikesalingam, Balaji Lakshminarayanan, and Jim Winkens for their contributions. We would also like to thank Tom Small for creating the post animation.