Scaling a Model in a Month: ML Engineering at Reticular AI (YC F'24)

Introduction

In Winter 2025 I had the unique opportunity to work at Reticular AI, a Y Combinator F'24 startup focused on AI for biology.

We published "Towards Interpretable Protein Structure Prediction with Sparse Autoencoders" at ICLR workshops, and as of July 2025 we were highlighted in Anthropic's highly influential Transformer Circuits Thread July update.

My role: I owned scaling our model. Keeping compute cost per SAE roughly the same, I scaled the model 10x and cut training time 10x.

For my efforts I was second author alongside the two founders, though all three of us worked tirelessly around the clock on different parts of the paper for the better part of two months.

What the Scaling Unlocked

When I joined, the research questions were completely hidden behind the golden door of "a scaled model." We didn't know what embeddings we were scaling to, nor what performance we'd see on any given metric. We didn't even have scaling laws to work off of. That's the beauty of working in research: this had never been done.

The scaled model had two contributions: size and speed.

Size gave us our headline. We realized we were extending interpretability to structure prediction, something nobody saw coming until I proved we could train SAEs efficiently on ESM-3B (the backbone of ESMFold). From there we could steer structure predictions, as the paper describes.

We didn't see crazy results from our initial scaled SAEs; it's not as if gold fell out of the embedding layers and we just picked it up and bundled it. Rather, we now had far more latents from which to tease out interesting research directions.

Speed gave us room to actually do research: broader hyperparameter sweeps and enough runs to experiment with entirely different SAE architectures. That's what led to the Matryoshka finding, that structure prediction performance is recoverable with only 8 to 32 active latents, fewer than you'd expect for something as complex as a 3D fold.

Neither result was on the roadmap when I started. Both fell out of simply having the scale to look.

My Challenge

When I joined, the team had already found their interesting problem: we were going to extend a paper called InterPLM.

InterPLM had shown that features from protein language model layer embeddings can be mechanistically interpreted with SAEs, but it had only interpreted the smallest open-source model, ESM-8M.

We spoke with Elana (the author of InterPLM) about why she hadn't been able to scale more aggressively. At Stanford in 2024, the available GPUs were mostly memory-limited gaming cards (think RTX 4070, 12GB), and on any given day she couldn't guarantee access to the same GPU on the Stanford cluster, which constrained her batch size and model size choices.

Working at a company with fresh YC compute credits to burn was a huge opportunity to extend this work.

I was a good fit because I was (at that point) newly armed with the systems and ML knowledge to scale an in-house model that the two non-CS founders would likely have outsourced. In fact, before I came on, Reticular was considering paying another startup $150k+ to train this model. We saved that cost by having me own the entire in-house training pipeline, a very special opportunity that I am proud to say I made the most of.

What We Were Scaling

To explain my contribution it's not necessary to deep-dive the SAE architecture, so for simplicity's sake I'll just drop a link to the seminal Anthropic blog post here and abstract the models as I see fit for blog-post-level detail.

In short, SAEs are used to interpret transformer models: the activations of a network layer, captured as we run inference, are the input to the SAE.
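For a sense of shape, here's a minimal sketch of what I mean; the dictionary size and top-k value here are illustrative, not our exact architecture:

```python
# Minimal sparse autoencoder over transformer activations (illustrative sizes).
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    def __init__(self, d_model: int = 2560, d_dict: int = 10240, k: int = 32):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_dict)
        self.decoder = nn.Linear(d_dict, d_model)
        self.k = k  # number of active latents kept per activation

    def forward(self, acts: torch.Tensor) -> torch.Tensor:
        latents = torch.relu(self.encoder(acts))
        # Keep only the top-k latents per row (the sparsity "filtering" step).
        topk = torch.topk(latents, self.k, dim=-1)
        sparse = torch.zeros_like(latents).scatter_(-1, topk.indices, topk.values)
        return self.decoder(sparse)

# Training minimizes reconstruction error of the layer activations,
# e.g. loss = F.mse_loss(sae(acts), acts).
```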

Our goal was to scale the InterPLM approach to interpret layers from models larger than ESM-8M. Our end goal wasn't fully decided at first, but we chose 3B because of its strong performance and because it let us link the model with ESMFold, which unlocked structure prediction.

To provide some quantitative heuristics: the backbone's parameter count grew from 8 million to 3 billion (a 375x increase). That sounds like a complete redesign is in order, but SAEs train on layer activations, and the embedding dimension (our input space) only grew from 320 to 2560 (8x). Since the SAE's weight matrices scale with both the input dimension and the dictionary width, an 8x bump in each works out to roughly a 64x, call it 100x, increase. So we still had to speed up training by a lot, but only by about 100x, and memory usage grew by about the same factor rather than 375x.
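The back-of-envelope, assuming the dictionary width grows proportionally with the embedding dimension:

```python
# Why 375x backbone growth is only ~100x SAE growth.
d_in_small, d_in_large = 320, 2560      # ESM-8M vs ESM-3B embedding dims
dim_ratio = d_in_large / d_in_small     # 8x input space
param_ratio = dim_ratio ** 2            # encoder/decoder are d_model x d_dict
print(dim_ratio, param_ratio)           # 8.0, 64.0 -- call it ~100x
```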

InterPLM got its results using 10 million proteins from UniProt as the inputs to embeddings, and we ended up using the same amount for our final SAEs. To turn those sequences into data (embeddings), we had to run ESM-3B inference, pull the activations, and store them. Those 10 million proteins' embeddings worked out to about 10TB(!) of data going through the model.
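A sketch of what that extraction loop looks like, using the fair-esm package; the layer index and output path here are illustrative rather than our exact pipeline:

```python
# Pull per-token hidden states from ESM-3B for a batch of sequences.
import torch
import esm

model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
model = model.eval().cuda()
batch_converter = alphabet.get_batch_converter()

def dump_activations(named_seqs, layer: int = 24, out_path: str = "acts.pt"):
    _, _, tokens = batch_converter(named_seqs)          # (batch, seq_len)
    with torch.no_grad():
        out = model(tokens.cuda(), repr_layers=[layer])
    # (batch, seq_len, 2560) hidden states for the requested layer
    acts = out["representations"][layer].cpu()
    torch.save(acts, out_path)

dump_activations([("prot1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")])
```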

In terms of training time, we wanted it as fast as possible: we didn't want to just train one SAE per layer and be done, but to iterate on the dictionary learning framework and hyperparameter-tune as much as possible.

Skipping forward to the result, we were able to train an SAE in 12 hours on an 8xL40S instance, and getting it that fast was a huge bonus for pushing the paper out.

Getting the Input Materials

The first goal was as follows: get the embeddings out of ESM-3B.

This ended up requiring our first process improvement, which was asynchronous job scheduling.

We were using Lightning AI as our compute platform, a useful service that abstracts some of the headaches of AWS SageMaker into click-and-run items in the browser.

By running our inference tasks across several cheap AWS GPU EC2 instances through Lightning, it took only a couple of days to grab all of the embeddings we needed.
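In shape, the scheduling looked something like this; this is an illustrative sketch, not the Lightning AI API, and run_inference_job is a hypothetical helper standing in for whatever launches a shard on a remote GPU:

```python
# Shard the 10M sequences and fan shards out to GPU instances as slots free up.
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_inference_job(shard_path: str):
    """Hypothetical helper: launch embedding extraction for one shard on a
    remote GPU instance and block until its output lands in storage."""
    ...

shards = [f"shards/{i:04d}.fasta" for i in range(1000)]
with ThreadPoolExecutor(max_workers=40) as pool:   # ~one slot per GPU
    futures = {pool.submit(run_inference_job, s): s for s in shards}
    for fut in as_completed(futures):
        print("finished", futures[fut])
```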

We started by storing embeddings on the filesystem, which promptly broke at scale; once we moved to 3B-scale data and dataloader optimization became necessary, we switched to S3 buckets for our embeddings.

Distributing Training

I'm going to handwave some of this section and just say that PyTorch Lightning was a mature and very easy-to-use library for our case. It is an open-source project that abstracts distributed training away from the user, which was super useful: we could just run our models on as many GPUs as we needed.

This gave us a nearly clean 8x bump, because we could go from single-GPU machines to 8-GPU machines.

I went through and rewrote the entire training step to fit the PyTorch Lightning framework, while making sure our code was still learning the same way, which we verified by tracking our loss curves on Weights & Biases.
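The port roughly amounted to this; a sketch reusing the SAE from earlier, with the loss and optimizer settings being illustrative:

```python
# Wrap the SAE training step in a LightningModule so DDP comes for free.
import lightning as L
import torch
import torch.nn.functional as F

class SAETrainer(L.LightningModule):
    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def training_step(self, batch, batch_idx):
        acts = batch                       # e.g. a (2048, 2560) activation batch
        recon = self.sae(acts)
        loss = F.mse_loss(recon, acts)
        self.log("train/loss", loss)       # mirrored to Weights & Biases
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

# Same script on 1 GPU or 8; Lightning handles the distribution.
trainer = L.Trainer(accelerator="gpu", devices=8, strategy="ddp")
# trainer.fit(SAETrainer(sae), train_dataloaders=loader)
```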

Optimizing the Dataset and Dataloader: Using Streaming

As I mentioned, our input data totaled on the order of 10TB for 10M protein sequences' worth of ESM-3B embeddings. To feed all of this into our model, the dataloader would need to be incredibly good. This is doubly true for our model: an SAE is incredibly simple, with one forward pass, one filtering step, and one backprop through what is essentially a single matmul. Our training speed was therefore completely bound by data ingestion speed.

When we profiled our training step, though, we found data ingestion was slow at two points: pulling from the dataset, and the dataloader itself. Lightning AI provides a streaming dataset/dataloader library called Litdata to improve both.
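The read side is pleasantly small. Roughly how we consumed it, with the bucket name and worker count being illustrative:

```python
# Stream optimized embedding batches straight from S3 into training.
from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(input_dir="s3://our-bucket/esm3b-embeddings")
loader = StreamingDataLoader(dataset, batch_size=2048, num_workers=8)

for batch in loader:
    ...  # feed activation batches into the SAE training step
```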

Decision-Making

At this point we had 3 weeks to finalize the SAEs we would analyze deeply for submission. Each SAE could take upwards of a week to run for 5 epochs, even with distributed training. With so many hyperparameters to tune and work with, roughly 3 full iterations was nowhere near enough and would have severely hindered our ability to tease out interesting results.

Getting Things to Work

To use Litdata we had to preprocess our dataset into Litdata's StreamingDataset format.

The conversion to StreamingDataset happens through the .optimize() function, which would never terminate on our system: throughput would fall off hard before we could hit 1TB of optimized data.

Profiling revealed a number of weak points in our system. The first was that storing on the filesystem was simply too slow: we had so much file I/O overhead trying to optimize our embeddings that we had to move to AWS S3, which helped a lot.

At this point, Litdata had a glitch where S3 wasn't properly supported. I helped diagnose the problem with (more) profiling and suggested the fix to the researchers who own the project.

The second issue was with what I fed to each stage of .optimize().

Like InterPLM, we had been prebatching our data: we randomized the order we passed tokens in, but stored each batch of embeddings together in one file, which would then be loaded and processed as a single training batch.

Example: in our case the batch dimension was 2048 activations (what gets fed to the GPU in one training step). Without prebatching, we would have to load activations one by one and recombine them into a 2048-sized batch at train time.

It took me a while to see I was shooting myself in the foot by optimizing one embedding at a time: exactly that situation, throwing away our 2048x prebatching and forcing the shuffled dataloader to rebatch activations on the fly.
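The fix in miniature, with the paths, counts, and chunk size here being illustrative: emit each prebatched file as a single Litdata sample instead of 2048 separate ones.

```python
# Convert raw prebatched files into an optimized streaming dataset on S3.
import torch
from litdata import optimize

def load_prebatched(file_path: str):
    # Each raw file already holds one shuffled (2048, 2560) batch,
    # so emit it whole rather than as individual activations.
    return torch.load(file_path)

optimize(
    fn=load_prebatched,
    inputs=[f"raw/batch_{i:06d}.pt" for i in range(5000)],
    output_dir="s3://our-bucket/esm3b-embeddings",
    num_workers=96,
    chunk_bytes="128MB",
)
```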

Our platform weakness boiled down to job management: we had a lot of multiprocessing to handle in Python. Luckily, PyTorch's dataloader workers are separate processes that can be tracked by PID. Leveraging this, I made every single job massively parallel, with our async job scheduler and our CPU workers all maximally utilized.
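A toy version of the tracking trick; the dataset here is a stand-in:

```python
# Log each dataloader worker's PID so the swarm is visible in htop.
import os
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10_000, 2560))  # stand-in activations

def log_worker(worker_id: int) -> None:
    # Each worker is a separate process with its own PID.
    print(f"worker {worker_id} running as PID {os.getpid()}")

loader = DataLoader(dataset, batch_size=2048, num_workers=96,
                    worker_init_fn=log_worker)
```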

I believe I had three 96-core VMs running with 96 workers each at some points, and I definitely had 40 L40Ses doing inference at once to pull activations, not to mention the GPUs used for hyperparameter sweeps.

After adjusting the training script to use the new batching, our previously week-long training runs completed in less than a day, a very ergonomic turnaround for quick iteration and efficient hyperparameter tuning.

Wrapping Up

The three weeks spent getting Litdata working ended up being the most interesting contribution I'd made to any project at the time. Processing our data with Litdata was definitely a stretch goal, but it paid dividends as a clean 10x difference in speed.

Overall, playing with hundreds of thousands of dollars in compute for frontier ML research is not a chance undergrads often get, and it never stopped amusing me to send off runs costing thousands of dollars of compute.