How to hire JAX and distributed training engineers in 2026: A sourcing guide
JAX powers the training infrastructure at Anthropic and Google DeepMind. Distributed training specialists who can coordinate 5,000+ GPU clusters are among the rarest engineers in tech. The talent pool is small, concentrated at a handful of labs, and not looking for jobs. Here's how to find them anyway.
Anthropic trains Claude on JAX. Google DeepMind trains Gemini on JAX. These are two of the most capable AI systems in the world, and they both run on a framework that most ML engineers have never used in production. Meanwhile, every company building large language models needs engineers who can make training work across thousands of GPUs without burning millions of dollars on wasted compute. That creates one of the tightest labor markets in software engineering.
This is not a guide about hiring general ML engineers. There are plenty of those. This is about the specific, much smaller group of engineers who know how to write functional ML code in JAX, who understand XLA compilation, who have debugged sharding strategies across a pod of TPUs or a cluster of A100s. And it is about the overlapping group of distributed training specialists who coordinate model parallelism, pipeline parallelism, and gradient synchronization at scales where a single misconfigured NCCL call can waste $50,000 in compute per hour.
We covered CUDA and GPU engineers and niche language sourcing strategies in previous guides. JAX and distributed training combine a niche framework with a specialized systems discipline, both driven by a single industry's demand for large-scale AI training.
The JAX and distributed training market in 2026
JAX accounts for roughly 1,200 to 1,500 job postings globally at any point. PyTorch, by contrast, appears in 37.7% of all AI job listings. That ratio does not reflect JAX's quality. It reflects concentration. JAX production usage is almost entirely at Google DeepMind and Anthropic, with smaller pockets at research labs, quantitative trading firms, and a handful of AI startups. When a company posts a JAX role, the candidate pool is tiny.
Distributed training specialists are slightly more numerous but still scarce. Somewhere between 2,000 and 5,000 engineers globally have genuine experience training models on clusters of 5,000 or more GPUs. This is not a soft number that varies based on how you define the skill. Only so many organizations operate at that scale, and the people who have done the work are individually identifiable. They are at Anthropic, Google DeepMind, OpenAI, Meta FAIR, NVIDIA, Microsoft Research, and a small number of well-funded startups.
Compensation matches the scarcity. Distributed training engineers typically earn $200,000 to $400,000 or more in total compensation. At the top end, Anthropic pays engineers between $550,000 and $759,000 TC. Google DeepMind pays comparably for senior distributed systems researchers. These are not inflated numbers. They are what companies pay to attract people with experience that only a few thousand humans have.
Hiring cycles for these roles are long. Expect 60 to 120 days minimum, and some searches take longer. The bottleneck is not your interview pipeline or hiring committee. It is finding people who qualify and who are willing to move. Most distributed training engineers are on multi-year equity vesting schedules at companies with strong retention packages. They are not browsing job boards or updating their LinkedIn. They are training models.
Demand keeps growing. Every major AI lab is scaling training clusters. Enterprise companies are building internal training capabilities. Startups are raising rounds specifically to train foundation models. Sovereign AI programs in the EU, Middle East, and Asia are building national compute infrastructure and need engineers to run it. The supply of qualified candidates is growing too, but much more slowly. University ML programs produce thousands of graduates annually, but the jump from "can fine-tune a model on a single GPU" to "can coordinate a 10,000-GPU training run with custom XLA kernels and pipeline parallelism" takes years of hands-on work at the right organizations.
Why JAX is harder to hire for than PyTorch
If you have hired ML engineers before, you might assume JAX hiring is the same process with a different keyword filter. It is not. The gap between JAX and PyTorch hiring is structural, not cosmetic.
The functional paradigm. JAX is a functional framework. Pure functions, immutable arrays, explicit random number key management, no hidden state. PyTorch is imperative and object-oriented: you subclass nn.Module, mutate parameters in place, and rely on autograd's tape-based differentiation. These are different ways of thinking about computation. A strong PyTorch engineer cannot switch to JAX over a weekend. They have to unlearn patterns they have used for years: mutable state, in-place operations, and Python control flow woven freely through model definitions. The shift is closer to moving from Java to Haskell than from React to Vue.
XLA compilation. JAX compiles computation graphs through XLA (Accelerated Linear Algebra), which means code that runs eagerly may not compile under jax.jit. Python control flow that depends on array values and data-dependent output shapes fail to trace, and Python side effects inside jitted functions run only once, at trace time, rather than on every call. An experienced JAX engineer thinks about what will and will not trace through the compiler at write time. That skill takes months to develop and does not transfer from PyTorch, which uses eager execution by default.
Community size. PyTorch has millions of users across academia and industry. The ecosystem is huge: tutorials, Stack Overflow answers, blog posts, courses, videos, books. When a PyTorch engineer hits a problem, they Google it and usually find an answer. JAX's community is a fraction of that size. The documentation is good but not as comprehensive. Many problems require reading the JAX source code, digging into XLA internals, or asking in a GitHub issue. JAX engineers tend to be more self-sufficient and have stronger debugging instincts, but fewer engineers choose to invest the learning time.
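A minimal sketch of these tracing rules, assuming a recent JAX release; the function names are hypothetical. A Python `if` on a traced value raises an error under `jax.jit`, `jnp.where` keeps the data-dependent choice inside the compiled graph, and the list append shows that Python side effects run only when JAX traces (or retraces for a new input shape), not on every call.

```python
import jax
import jax.numpy as jnp

trace_count = []  # records Python side effects, which run only while tracing

def relu_python_if(x):
    # Python `if` on a traced value: under jax.jit the branch depends on data
    # that is unknown at trace time, so tracing fails.
    if x > 0:
        return x
    return jnp.zeros_like(x)

@jax.jit
def relu_traceable(x):
    trace_count.append(x.shape)  # side effect: executes on (re)trace only
    # jnp.where keeps the data-dependent choice inside the compiled graph.
    return jnp.where(x > 0, x, jnp.zeros_like(x))

def try_jit_python_if():
    """Return 'compiled' if jitting the Python-if version works, else the error name."""
    try:
        jax.jit(relu_python_if)(jnp.float32(1.0))
        return "compiled"
    except jax.errors.ConcretizationTypeError as err:
        return type(err).__name__
```

Calling `relu_traceable` twice with arrays of the same shape appends to `trace_count` once; a new shape triggers a retrace and a second append.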
Two-company concentration. Google DeepMind and Anthropic account for the majority of production JAX usage. Almost everyone with production JAX experience works at one of two organizations, both of which pay extremely well and work on some of the most interesting problems in the field. Convincing someone to leave requires a compelling technical mission, competitive compensation, or both. There is no pool of JAX engineers at mid-tier companies waiting for a better offer.
Anthropic's decision to build on JAX is worth understanding as a market signal. They chose JAX not because it was popular but because functional purity and XLA compilation fit large-scale, reproducible training better. When you coordinate a training run across thousands of devices, you need deterministic behavior, explicit state management, and a compilation model that can optimize across the entire computation graph. JAX gives you that. PyTorch's flexibility and ease of use help with research iteration on a single GPU, but they become liabilities at scale where every implicit state mutation is a potential source of non-determinism.
Where JAX and distributed training engineers contribute on GitHub
JAX and distributed training engineers are disproportionately active on GitHub. The core frameworks, neural network libraries, training infrastructure, and research implementations are all developed in the open. GitHub is the most effective sourcing channel for these engineers because the code they write tells you exactly what they know and how well they know it.
JAX core. jax-ml/jax is the framework itself: the function transformations (jit, grad, vmap, pmap), the NumPy-compatible API, and jaxlib, the compiled package that connects JAX to the XLA compiler. Contributors to this repo work at the lowest level of the stack. They understand how JAX traces Python functions into XLA HLO graphs, how device placement works, and how the sharding system maps logical arrays to physical devices. Even substantial issue participation here signals someone operating at the framework level, not just using it.
Neural network libraries. JAX does not have a single dominant neural network library the way PyTorch has torch.nn. There are several competing approaches, and a candidate's choice tells you something about their style. google/flax is Google's official library, the most widely used, and the one most JAX tutorials teach. It uses a "lifted transform" pattern and explicit parameter management. google-deepmind/dm-haiku (Haiku) was DeepMind's library, written by some of the authors of the TensorFlow library Sonnet to bring Sonnet's module style to JAX, with explicit init/apply separation. patrick-kidger/equinox is the most "Pythonic" option: it uses PyTrees to represent modules, embraces JAX's functional model fully, and has a growing following among researchers who want clean, composable code. Equinox contributors tend to be the strongest JAX programmers because the library requires deep understanding of JAX's transformation system.
Distributed training frameworks. NVIDIA/Megatron-LM is the reference implementation for large-scale transformer training. It implements tensor parallelism, pipeline parallelism, and data parallelism for models that cannot fit on a single GPU. Megatron-LM contributors have direct experience with the hardest part of distributed training: getting parallelism strategies to work correctly and efficiently across hundreds or thousands of GPUs. microsoft/DeepSpeed provides ZeRO optimization (partitioning optimizer state, gradients, and parameters across devices), mixed-precision training, and pipeline parallelism. DeepSpeed is used by a broader range of organizations than Megatron because it integrates with the Hugging Face ecosystem. Its contributors understand memory optimization, communication patterns, and the tradeoffs between parallelism strategies.
PyTorch distributed. Even if you are hiring for a JAX stack, PyTorch's distributed training modules are relevant. pytorch/pytorch includes FSDP (Fully Sharded Data Parallel), which has become the standard approach for distributed training in the PyTorch ecosystem. Engineers who have contributed to or extensively used FSDP understand the same core concepts that JAX's sharding system (jax.jit with sharding annotations, which absorbed the older pjit API) addresses, just with different APIs. The mental model transfers even if the code does not.
XLA and compiler infrastructure. openxla/xla is the compiler backend that JAX targets. Engineers who work at this level understand how high-level operations become low-level device instructions. They can write custom XLA operations, optimize memory layout, and debug compilation failures. This is uncommon expertise. Anyone contributing to the XLA repository is worth contacting immediately. openxla/stablehlo, the portable ML operation set, is another signal of deep compiler knowledge.
Research implementations. google-deepmind/alphafold is written in JAX and Haiku. google-deepmind/gemma provides open weights with JAX-native implementations. Many NeurIPS, ICML, and ICLR papers release JAX research code. Contributors to these repositories bridge research and engineering. They can read a paper, implement the math in JAX, and make it work on real hardware. This combination of mathematical understanding and systems skill is exactly what training teams need.
Quality signals in JAX and distributed training code
Reading JAX code requires different evaluation criteria than reading PyTorch code. The functional paradigm changes what "good" looks like. General seniority signals on GitHub still apply, but JAX and distributed training have specific markers that distinguish experienced practitioners from people who followed a tutorial.
Function transformation composition. JAX's core power is composable transformations: jit for compilation, grad for automatic differentiation, vmap for automatic batching, pmap for single-program-multiple-data parallelism. An experienced JAX engineer composes these fluently. They write a loss function, take its gradient, vectorize it across a batch, and compile the result: jit(vmap(grad(loss_fn))). The order of transformations matters. vmap(grad(loss_fn)) produces one gradient per example, while grad(vmap(loss_fn)) fails because grad requires a scalar output, so the batched losses must be reduced (for example, averaged) first. Getting the order wrong produces subtle bugs or compilation failures. A candidate who can explain why vmap(jit(f)) and jit(vmap(f)) have different performance characteristics understands JAX at a fundamental level.
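A short sketch of transformation composition, assuming a toy linear model with a squared-error loss (both hypothetical). `jit(vmap(grad(loss_fn)))` returns one gradient per example; averaging those matches the gradient of the mean loss, which is the kind of equivalence an experienced JAX engineer checks almost by reflex.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Squared error for ONE example: x has shape (d,), y is a scalar.
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad -> per-example gradient; vmap -> batch it (w shared, x and y mapped);
# jit -> compile the composed function once.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))

# grad needs a scalar output, so a batched loss must be reduced before
# differentiating it.
def mean_loss(w, xs, ys):
    return jnp.mean(jax.vmap(loss_fn, in_axes=(None, 0, 0))(w, xs, ys))

mean_grad = jax.jit(jax.grad(mean_loss))
```

With `xs` of shape `(batch, d)`, `per_example_grads(w, xs, ys)` has shape `(batch, d)`, and its mean over the batch axis equals `mean_grad(w, xs, ys)`.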
Explicit state management. In JAX, there is no hidden state. Random number generators require explicit key splitting. Model parameters are passed as arguments, not stored as object attributes. Optimizer state is a separate data structure that flows through the training loop. An experienced JAX engineer treats this as a feature, not a burden. Their code makes state flow visible: you can read a training step function and see exactly what goes in, what comes out, and what changes. Compare this to a developer who fights the functional model by stuffing state into global variables or using mutable containers to work around immutability. That is a sign of someone translating imperative habits into JAX rather than thinking functionally.
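The sketch below shows explicit state flow in a training step, assuming a toy linear model and a hand-written SGD-with-momentum update (in practice a library such as Optax would supply the optimizer). `init_params`, `loss`, and `train_step` are illustrative names, not a library API.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim, out_dim):
    # Split the key explicitly: each consumer gets its own independent subkey.
    w_key, b_key = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (in_dim, out_dim)) * 0.1,
        "b": jax.random.normal(b_key, (out_dim,)) * 0.01,
    }

def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, momentum, x, y, lr=0.1, beta=0.9):
    # Everything the step reads comes in as an argument and everything it
    # changes goes out as a return value; nothing is mutated in place.
    loss_value, grads = jax.value_and_grad(loss)(params, x, y)
    new_momentum = jax.tree_util.tree_map(lambda m, g: beta * m + g, momentum, grads)
    new_params = jax.tree_util.tree_map(lambda p, m: p - lr * m, params, new_momentum)
    return new_params, new_momentum, loss_value
```

Because the key is explicit, calling `init_params` twice with `jax.random.PRNGKey(0)` yields identical parameters, and a loop of `params, momentum, loss_value = train_step(params, momentum, x, y)` makes every piece of changing state visible at the call site.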
Sharding annotations and strategies. For distributed training, sharding is where theory meets reality. JAX's jax.sharding module (Mesh, PartitionSpec, and NamedSharding) lets you specify how arrays are distributed across devices. An experienced engineer chooses sharding strategies based on the communication patterns of their model: which dimensions to shard, how to handle all-reduce operations, when to replicate versus partition. Look for code that defines explicit mesh shapes, uses with_sharding_constraint to control intermediate computations, and handles the transition between different sharding specs across layers. Bad sharding code either replicates everything (wasting memory) or shards naively (creating communication bottlenecks).
Custom XLA kernels. At the extreme end, some engineers write custom operations: kernels in Pallas (JAX's kernel language for writing custom GPU/TPU operations) or foreign-function custom calls exposed to XLA through jax.ffi (formerly jax.extend.ffi). This is rare and highly valuable. Anyone who has written a custom kernel that compiles and runs correctly on real hardware understands the full stack from Python down to the device. It is the JAX equivalent of writing custom CUDA kernels and signals deep systems expertise.
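A minimal sharding sketch, assuming a recent JAX with the `jax.sharding` module and a 1-D mesh over whatever devices are visible (on a laptop that is usually a single CPU device, so the sharding is trivial but the code path is the same). The axis name `"data"` and the layer shapes are arbitrary choices for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all visible devices, with its single axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharded = NamedSharding(mesh, P("data", None))  # split rows, replicate columns
replicated = NamedSharding(mesh, P())                 # full copy on every device

@jax.jit
def forward(w, x):
    h = jnp.tanh(x @ w)
    # Pin the intermediate to the batch-sharded layout so the compiler does
    # not pick a different (possibly communication-heavy) layout on its own.
    h = jax.lax.with_sharding_constraint(h, batch_sharded)
    return h.sum(axis=-1)

n_dev = len(jax.devices())
x = jax.device_put(jnp.ones((8 * n_dev, 16)), batch_sharded)
w = jax.device_put(jnp.full((16, 4), 0.01), replicated)
out = forward(w, x)
```

The batch dimension is sized as a multiple of the device count so it divides evenly across the `"data"` axis; the weights are replicated because they are small relative to the activations in this toy case.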
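For a sense of what Pallas code looks like, here is the standard element-wise add kernel, assuming `jax.experimental.pallas` is available in your JAX install. `interpret=True` runs it through the Pallas interpreter so it works on CPU; on real GPU/TPU hardware you would drop that flag and usually add block specs and a grid to tile the computation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at blocks of memory: read the inputs, compute, write the output.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def pallas_add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU-friendly interpreter mode; remove on GPU/TPU
    )(x, y)
```

The toy kernel is trivial on purpose; the hard part of real kernel work is choosing block shapes and memory movement, which is what distinguishes experienced kernel authors.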
Gradient checkpointing and memory optimization. Large model training requires careful memory management. jax.checkpoint (also called jax.remat) lets you trade compute for memory by recomputing activations during the backward pass instead of storing them. An experienced engineer uses gradient checkpointing strategically: applying it to specific layers based on their memory footprint and recomputation cost, not blanket-applying it everywhere. They also understand activation memory vs parameter memory vs optimizer state memory and can estimate the total memory budget for a training run before launching it.
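A sketch of selective rematerialization, assuming a toy stack of residual MLP blocks (hypothetical) in which only chosen layers are wrapped in `jax.checkpoint`. The policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable` keeps the matmul outputs and recomputes the cheap element-wise work; whether that is the right trade for a real model depends on its memory profile.

```python
import jax
import jax.numpy as jnp

def mlp_block(params, x):
    # A residual MLP block (hypothetical toy layer).
    h = jnp.tanh(x @ params["w1"])
    return x + h @ params["w2"]

# Remat version: activations inside the block are recomputed in the backward
# pass, except matmul outputs without batch dimensions, which the policy saves.
mlp_block_remat = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

def make_loss(remat_layers):
    """Build a loss that applies remat only to the given layer indices."""
    def loss(all_params, x):
        for i, p in enumerate(all_params):
            block = mlp_block_remat if i in remat_layers else mlp_block
            x = block(p, x)
        return jnp.mean(x ** 2)
    return loss
```

Gradients are identical with or without remat; only the memory/compute trade-off changes. For example, `jax.grad(make_loss(remat_layers=(0, 2)))(params, x)` rematerializes the first and third blocks and stores activations for the second as usual.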
Communication patterns in distributed code. In any distributed training framework, the choice of communication primitives matters enormously. All-reduce, all-gather, reduce-scatter, point-to-point: each has different bandwidth and latency characteristics. An experienced distributed training engineer thinks about communication topology, overlaps communication with computation, and minimizes synchronization points. In their code, you will see explicit handling of gradient accumulation steps, careful placement of synchronization barriers, and awareness of network bandwidth constraints. The difference between a 70% and 95% GPU utilization rate during distributed training often comes down to how well communication is overlapped with compute.
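Gradient accumulation is one place where this shows up in code. The sketch below, assuming a toy linear model and equal-size micro-batches, accumulates gradients with `jax.lax.scan`, so a data-parallel job would need only one gradient synchronization per optimizer step instead of one per micro-batch.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def accumulated_grad(w, xs, ys):
    # xs: (num_micro, micro_batch, d), ys: (num_micro, micro_batch)
    def body(acc, batch):
        x, y = batch
        return acc + jax.grad(loss)(w, x, y), None
    total, _ = jax.lax.scan(body, jnp.zeros_like(w), (xs, ys))
    # With equal-size micro-batches, the mean of the micro-batch gradients
    # equals the full-batch gradient of the mean loss.
    return total / xs.shape[0]
```

In a multi-device setup, the single all-reduce would be applied to the returned gradient, after the scan, rather than inside the loop body.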
Flax/Haiku/Equinox patterns. The neural network library a candidate uses and how they use it is informative. Flax users who leverage nn.scan for recurrent patterns and repeated layers, choose between nn.compact and setup methods appropriately, and handle variable collections cleanly show framework mastery. Equinox users who write models as filtered PyTrees and use eqx.filter_jit and eqx.filter_grad show deep understanding of JAX's PyTree model. Haiku users who properly separate init and apply and manage state without leaking it between calls demonstrate clean functional design.
The distributed training stack: data, model, and pipeline parallelism
If you are sourcing distributed training engineers, you need to understand what they actually do. This section is not a full technical tutorial, but it gives you enough context to evaluate candidates and write outreach that does not sound generic.
Data parallelism is the simplest form: replicate the model on every device, split the batch across devices, compute gradients independently, then synchronize gradients via all-reduce. Every distributed training engineer understands this. It is the baseline. The subtlety is in gradient synchronization: when to use synchronous vs asynchronous updates, how to handle gradient accumulation for effective batch sizes larger than what fits in memory, and how to scale learning rates when you scale the number of devices. An engineer who only knows data parallelism is not a distributed training specialist. It is the entry point.
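A minimal data-parallel sketch using `jax.pmap`, where `jax.lax.pmean` performs the gradient all-reduce across the named device axis. It assumes a toy linear model and inputs carrying a leading device axis; newer JAX code often expresses the same pattern with `jax.jit` plus shardings or with `shard_map`, but `pmap` keeps the all-reduce explicit.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def dp_grad(w, x, y):
    g = jax.grad(loss)(w, x, y)               # local gradient on this device's shard
    return jax.lax.pmean(g, axis_name="dp")   # all-reduce (mean) across devices

# w is broadcast to every device (in_axes=None); x and y are split along axis 0.
dp_grad_pmap = jax.pmap(dp_grad, axis_name="dp", in_axes=(None, 0, 0))
```

With `n = jax.local_device_count()` and inputs shaped `(n, per_device_batch, ...)`, every row of the result is the same synchronized gradient, equal to the full-batch gradient when the shards are the same size.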
Model parallelism splits the model itself across devices. Tensor parallelism shards individual layers (splitting a large matrix multiplication across multiple GPUs). This is what Megatron-LM pioneered for transformers: splitting attention heads and feed-forward layers across a tensor-parallel group, typically within a single node where NVLink provides high-bandwidth inter-GPU communication. An engineer who understands tensor parallelism can explain how column-parallel and row-parallel linear layers work, why each column-parallel/row-parallel pair needs one all-reduce in the forward pass (summing the row-parallel layer's partial outputs) and another in the backward pass (summing the column-parallel layer's input gradients), and how to handle the activation memory implications.
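The column-parallel/row-parallel pattern can be checked on one machine by simulating tensor-parallel ranks with array slices. This NumPy sketch assumes a two-layer MLP with ReLU (Megatron-LM's transformer MLP uses GeLU, but any element-wise nonlinearity behaves the same way here) and shows that summing the per-rank partial outputs, which is the job of the forward all-reduce, reproduces the unsharded result.

```python
import numpy as np

def megatron_mlp_shards(x, A, B, tp):
    # Column-parallel: split A by columns; each rank computes its slice of the
    # hidden layer with no communication (the nonlinearity is element-wise).
    A_shards = np.split(A, tp, axis=1)
    # Row-parallel: split B by rows to match; each rank gets a partial output.
    B_shards = np.split(B, tp, axis=0)
    partials = [np.maximum(x @ a, 0.0) @ b for a, b in zip(A_shards, B_shards)]
    # Forward all-reduce: sum the partial outputs across ranks.
    return sum(partials)

def reference_mlp(x, A, B):
    return np.maximum(x @ A, 0.0) @ B
```

Splitting the hidden dimension this way is what lets the first matmul and the nonlinearity run with no communication at all.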
Pipeline parallelism splits the model into stages, with each stage running on different devices. Micro-batches flow through the pipeline, allowing stages to compute in parallel. The challenge is "pipeline bubbles": periods when stages are idle waiting for data. GPipe, PipeDream, and interleaved schedules each handle this differently, trading off memory, throughput, and implementation complexity. An engineer who can discuss the tradeoffs between 1F1B (one forward, one backward) schedules and interleaved schedules, and who understands how pipeline parallelism interacts with tensor parallelism, has real experience at scale.
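The bubble cost can be estimated with the formula from the Megatron-LM paper on pipeline parallelism (Narayanan et al., 2021): with p stages and m micro-batches, idle time is (p - 1)/m of the ideal compute time, and an interleaved schedule with v model chunks per device divides that by v. GPipe and non-interleaved 1F1B have the same bubble; 1F1B's advantage is that it caps in-flight micro-batches at roughly p, which lowers activation memory. A tiny helper, assuming a synchronous schedule and uniform stage times:

```python
def bubble_fraction(p, m, v=1):
    """Idle (bubble) time relative to ideal compute time for a synchronous
    pipeline with p stages, m micro-batches, and v model chunks per device
    (v=1: GPipe or non-interleaved 1F1B; v>1: interleaved 1F1B)."""
    return (p - 1) / (v * m)
```

For 8 stages and 32 micro-batches the bubble is 7/32 of compute time; interleaving with v=2 halves it to 7/64, at the cost of more point-to-point communication.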
ZeRO optimization (from DeepSpeed) partitions optimizer states, gradients, and optionally parameters across data-parallel ranks. ZeRO Stage 1 shards optimizer states. Stage 2 adds gradient partitioning. Stage 3 adds parameter partitioning, which allows training models larger than any single GPU's memory. This is conceptually similar to PyTorch FSDP and JAX's sharding-based approaches, but with different implementation tradeoffs. An engineer who can compare ZeRO, FSDP, and JAX's approach to the same problem understands distributed training at a conceptual level, not just an API level.
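The memory arithmetic behind the ZeRO stages is simple enough to sketch. This helper follows the accounting in the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and K = 12 for FP32 master weights, momentum, and variance, with activations ignored. "GB" here means 10^9 bytes.

```python
def zero_memory_gb(params_billion, dp_ranks, stage, k=12):
    """Per-GPU memory (GB) for weights, gradients, and optimizer state.
    stage 0: no partitioning; 1: optimizer state; 2: + gradients; 3: + weights."""
    psi, n = params_billion, dp_ranks  # billions of params x bytes -> GB
    if stage == 0:
        return (2 + 2 + k) * psi
    if stage == 1:
        return 4 * psi + k * psi / n
    if stage == 2:
        return 2 * psi + (2 + k) * psi / n
    if stage == 3:
        return (2 + 2 + k) * psi / n
    raise ValueError("stage must be 0, 1, 2, or 3")
```

For a 7.5B-parameter model on 64 data-parallel ranks this gives 120, about 31.4, about 16.6, and about 1.9 GB per GPU for stages 0 through 3, the figures reported in the ZeRO paper.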
Mixed-precision training uses lower-precision formats (FP16, BF16, FP8) for most computation while maintaining FP32 master weights for numerical stability. On tensor-core GPUs and TPUs this often doubles throughput or better and roughly halves activation memory, although the FP32 master weights and optimizer state keep total memory well above half of an FP32 run. But getting mixed precision right at scale is tricky: loss scaling to prevent underflow, handling operations that require higher precision (like normalization layers), and managing the interaction between mixed precision and distributed sharding. BF16 has become the default at most labs because it has the same exponent range as FP32, avoiding the loss scaling complexity of FP16. FP8 is emerging for even greater throughput on newer hardware (H100, B200) but requires careful handling of quantization ranges.
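A small illustration of why BF16 sidesteps loss scaling, assuming a recent JAX (which exposes `jnp.bfloat16`). The gradient value is made up; the point is that it underflows to zero in FP16, survives in BF16, and survives in FP16 only after multiplication by a loss scale (here a fixed 2^14; real FP16 training adjusts the scale dynamically).

```python
import jax.numpy as jnp

tiny_grad = 1e-8  # hypothetical gradient value, below FP16's smallest subnormal (~6e-8)

fp16_value = jnp.asarray(tiny_grad, dtype=jnp.float16)   # rounds to 0
bf16_value = jnp.asarray(tiny_grad, dtype=jnp.bfloat16)  # BF16 has FP32's 8-bit exponent

# Static loss scaling for FP16: scale up before the cast, unscale in FP32.
loss_scale = 2.0 ** 14
scaled_fp16 = jnp.asarray(tiny_grad * loss_scale, dtype=jnp.float16)
recovered = scaled_fp16.astype(jnp.float32) / loss_scale
```

BF16 gives up mantissa precision (7 bits versus FP16's 10) in exchange for that range, which is usually the better trade for gradients.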
A strong distributed training candidate can discuss all of these parallelism strategies and, more importantly, explain when to use which combination. A 7B parameter model does not need the same parallelism strategy as a 400B model. The right combination depends on model architecture, cluster topology, network bandwidth, and training budget. The ability to reason about these tradeoffs, not just implement one strategy, is what separates a distributed training engineer from someone who ran a DeepSpeed script once.
How to search for these engineers on GitHub
GitHub search for JAX and distributed training engineers requires different tactics than searching for web developers or even general ML engineers. The niche stack sourcing approaches apply here with adjustments specific to the ML infrastructure world.
Repository contributor graphs. Start with the repositories listed above: jax-ml/jax, google/flax, patrick-kidger/equinox, NVIDIA/Megatron-LM, microsoft/DeepSpeed. Pull up the contributor list and sort by recent activity. Ignore bot accounts and automated commits. Focus on people who write substantial code changes, review others' pull requests, and participate in design discussions. A contributor who added a new sharding strategy to Megatron-LM or fixed an XLA compilation issue in JAX core has demonstrated exactly the expertise you are hiring for.
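If you want to pull contributor lists programmatically, GitHub's REST endpoint `GET /repos/{owner}/{repo}/contributors` returns each contributor's login, account type, and all-time commit count. The sketch below uses only the Python standard library; the five-commit threshold is an arbitrary assumption, and because the counts are all-time rather than recent, judging recency still requires commit history (for example, the commits endpoint's `since` parameter).

```python
import json
import urllib.request

API = "https://api.github.com/repos/{owner}/{repo}/contributors?per_page=100&page={page}"

def fetch_contributors(owner, repo, token=None, max_pages=3):
    """Page through GitHub's REST 'list repository contributors' endpoint."""
    headers = {"Accept": "application/vnd.github+json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    results = []
    for page in range(1, max_pages + 1):
        url = API.format(owner=owner, repo=repo, page=page)
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as resp:
            batch = json.load(resp)
        if not batch:
            break
        results.extend(batch)
    return results

def human_contributors(contributors, min_commits=5):
    """Drop bot accounts and drive-by contributors; sort by commit count."""
    humans = [
        c for c in contributors
        if c.get("type") == "User"
        and not c.get("login", "").endswith("[bot]")
        and c.get("contributions", 0) >= min_commits
    ]
    return sorted(humans, key=lambda c: c["contributions"], reverse=True)
```

For example, `human_contributors(fetch_contributors("NVIDIA", "Megatron-LM"))` returns a ranked shortlist; unauthenticated requests are heavily rate-limited, so pass a token for anything beyond a quick look.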
Personal training implementations. Many JAX engineers write their own training loops, model implementations, or research reproductions as personal projects. Search for repositories with names containing "jax-", "distributed-training", or "large-scale-ml". Look for README files that describe custom sharding strategies, multi-node training setups, or XLA optimization experiments. A personal repo that implements a transformer from scratch in JAX with proper sharding, mixed precision, and gradient checkpointing is worth more than a dozen PyTorch tutorial completions.
Paper implementations. ML researchers often release code alongside their publications. Search GitHub for repository descriptions mentioning NeurIPS, ICML, or ICLR paper titles related to distributed training, model parallelism, or communication efficiency. These repositories often see most of their activity around conference deadlines and little afterward, but their authors are exactly the people who understand both the theory and implementation of large-scale training.
Conference and workshop activity. The MLSys conference, the Distributed ML Workshop at NeurIPS, and the Efficient ML workshop series produce publicly available papers and talks. Cross-referencing author names with GitHub profiles surfaces engineers who both publish research on distributed training and write code that implements it. This dual signal is strong: many ML researchers write proof-of-concept code but not production systems. Those who do both are the most valuable hires for training teams.
Cross-framework contributions. Engineers who contribute to both JAX and PyTorch distributed training infrastructure have the broadest understanding of the design space. They have seen how different frameworks solve the same problems and can make informed tradeoff decisions. Look for GitHub profiles with contributions to both jax-ml/jax and pytorch/pytorch, or to both NVIDIA/Megatron-LM and a JAX-based training framework. These people are rare, and they are worth reaching out to even if they are currently employed somewhere you assume they would not leave.
A practical sourcing workflow
Here is a step-by-step approach, from role definition to outreach.
Step 1: Separate JAX from distributed training in the role definition. These are related but different skills. A JAX expert who has trained a 1B model on 8 GPUs is not the same as a distributed training engineer who has coordinated 10,000 H100s using Megatron-LM. Some roles need both. Others need one or the other. A research scientist writing new architectures in JAX needs JAX fluency but may not need deep distributed systems expertise. An infrastructure engineer building the training platform needs distributed systems skills and should be comfortable with JAX but does not need to write novel architectures. Get this distinction right before you start sourcing. Blurring it wastes everyone's time.
Step 2: Map the role to target repositories. For JAX-focused roles: jax-ml/jax, google/flax, patrick-kidger/equinox, google-deepmind/dm-haiku, openxla/xla. For distributed training roles: NVIDIA/Megatron-LM, microsoft/DeepSpeed, PyTorch FSDP code within pytorch/pytorch. For roles that need both: look for contributors who appear in repos from both lists, or who have personal repositories implementing distributed training in JAX. For research-adjacent roles: add google-deepmind/alphafold, google-deepmind/gemma, and recent paper implementation repos.
Step 3: Extract and filter contributors. Use GitHub's contributor views or tools like riem.ai that index GitHub events at scale. For each target repo, identify contributors from the last 6 to 12 months with meaningful code contributions. Filter out: bot accounts, single-typo-fix contributors, documentation-only contributors (unless the docs show deep technical knowledge). For distributed training repos like Megatron-LM, even a small code contribution is significant because the codebase is complex and the bar for getting changes merged is high.
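The filtering rules in this step translate directly into code. This sketch assumes you have already fetched commit records into a hypothetical schema (`login`, timezone-aware `date`, changed `files`, and `changed_lines`); the 12-month window and 20-line threshold are assumptions to tune, and for repos like Megatron-LM you may want to lower the threshold, since even small merged changes are meaningful there. Docs-only commits are dropped automatically, so anyone whose documentation shows deep technical knowledge still needs a manual look.

```python
from datetime import datetime, timedelta, timezone

DOC_SUFFIXES = (".md", ".rst", ".txt")

def meaningful_recent_contributors(commits, months=12, min_changed_lines=20, now=None):
    """commits: hypothetical records like
    {"login": str, "date": datetime, "files": [str], "changed_lines": int}.
    Returns non-bot authors with at least one qualifying commit in the window,
    ordered by number of qualifying commits (ties broken by login)."""
    now = now or datetime.now(timezone.utc)
    cutoff = now - timedelta(days=30 * months)
    counts = {}
    for c in commits:
        login = c["login"]
        if login.endswith("[bot]") or c["date"] < cutoff:
            continue
        code_files = [
            f for f in c["files"]
            if not f.endswith(DOC_SUFFIXES) and not f.startswith("docs/")
        ]
        if not code_files or c["changed_lines"] < min_changed_lines:
            continue  # docs-only or typo-sized change
        counts[login] = counts.get(login, 0) + 1
    return sorted(counts, key=lambda k: (-counts[k], k))
```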
Step 4: Evaluate individual profiles. For each candidate, look for the quality signals described above. Do they compose JAX transformations correctly? Do their personal repos show clean functional style? For distributed training candidates, look for evidence of working at scale: configuration code for multi-node setups, custom communication kernels, sharding strategy implementations. Cross-project contributions and code review activity are strong seniority signals in this community, just like anywhere else. Check if they review pull requests, not just author them. Review participation in JAX core or Megatron-LM indicates deep expertise and community trust.
Step 5: Expand to adjacent pools. Pure JAX engineers are rare. Expand your search to: PyTorch distributed training engineers (the parallelism concepts transfer), CUDA and GPU engineers (they understand the hardware stack), XLA and compiler engineers (they understand JAX's compilation model), and researchers who publish in ML systems venues (MLSys, OSDI, SOSP). Also consider engineers at hardware companies (NVIDIA, AMD, Intel) who work on training framework support. They may not write JAX day-to-day, but they understand the lower levels of the stack that JAX compiles to. Engineers with strong functional programming backgrounds (Haskell, OCaml, Scala) can pick up JAX's paradigm faster than most. We covered the CUDA and GPU engineering talent pool in detail, and there is real overlap with the distributed training pool.
Step 6: Write outreach that demonstrates understanding. These engineers receive recruiter messages constantly. Generic outreach gets deleted. Effective developer outreach must reference specific work: "I saw your implementation of interleaved pipeline parallelism in the Megatron-LM repo. We're building a training platform that needs to handle heterogeneous GPU clusters, and your approach to handling different micro-batch schedules across pipeline stages is exactly the kind of thinking we need." Mention the scale you operate at (number of GPUs, model sizes), the technical problems that are unsolved, and why their specific contribution is relevant. Do not say "we're looking for a senior ML engineer." Say what you are building and why their GitHub work tells you they can help build it.
Step 7: Automate discovery. Manually scanning contributor lists across a dozen repositories, cross-referencing profiles, and evaluating code quality for each candidate does not scale. Tools like riem.ai analyze 30 million-plus GitHub events per month and can surface JAX and distributed training engineers based on contribution patterns. Instead of spending hours on GitHub's contributor page, you describe the technical profile ("JAX engineers who contribute to Flax or Equinox with experience in sharding strategies and multi-device training") and get a ranked list with contribution summaries, quality scores, and context about what each person has actually built. The initial candidate list that takes days manually takes minutes with the right tooling.
Frequently asked questions
How many JAX engineers are there?
JAX has a small but concentrated developer base. There are roughly 1,200 to 1,500 JAX-specific job postings globally at any given time, compared to PyTorch's 37.7% share of all AI job postings. The active JAX contributor community is heavily concentrated at two organizations: Google DeepMind and Anthropic. Outside those two, adoption exists at research labs, some hedge funds, and a handful of AI startups, but the pool of production JAX engineers is measured in the low thousands. This concentration makes JAX engineers among the hardest AI specialists to source.
What salary should I expect to pay JAX and distributed training engineers?
Distributed training engineers command $200,000 to $400,000 or more in total compensation. At the top end, Anthropic pays engineers between $550,000 and $759,000 in total compensation. Google DeepMind compensation is similarly aggressive. These numbers reflect extreme scarcity: the number of people who have trained models across 5,000 or more GPUs is tiny. Even at lower seniority, distributed training experience commands a significant premium over general ML engineering. Companies that cannot match top-of-market cash compensation need to compete on equity, research freedom, or the quality of their compute infrastructure.
Should I require JAX experience or accept PyTorch engineers?
It depends on what you are building. JAX's functional paradigm is genuinely different from PyTorch's imperative style. A PyTorch engineer cannot switch to JAX overnight. They need to unlearn mutable state, object-oriented module design, and in-place operations. That said, strong PyTorch engineers with functional programming experience (Haskell, OCaml, Scala) or deep XLA/compiler knowledge can transition faster. For distributed training roles specifically, the parallelism concepts (data parallel, model parallel, pipeline parallel) transfer well even if the implementation APIs differ. If your stack is JAX, prefer candidates with JAX experience for senior roles. For mid-level positions, strong PyTorch engineers who show interest in functional ML are viable if you can invest in ramp-up time.
What GitHub repositories should I look for on candidate profiles?
For JAX core: jax-ml/jax. For neural network frameworks: google/flax, google-deepmind/dm-haiku, patrick-kidger/equinox. For distributed training: NVIDIA/Megatron-LM, microsoft/DeepSpeed, pytorch/pytorch (FSDP modules). For XLA and compilers: openxla/xla and openxla/stablehlo (jaxlib itself is built from the jax-ml/jax repository). For specific domains: google-deepmind/alphafold, google-deepmind/gemma. Contributors to Equinox or JAX core tend to be the most technically strong. Megatron-LM and DeepSpeed contributors have direct large-scale training experience. Also look for custom training loop implementations, sharding strategy code, and custom kernel contributions in personal repositories.
How long does it take to hire a distributed training engineer?
Expect 60 to 120 days minimum. The candidate pool is extremely small: perhaps 2,000 to 5,000 people globally who have real experience training models on clusters of 5,000 or more GPUs. Most of them work at well-funded AI labs (Anthropic, Google DeepMind, OpenAI, Meta FAIR, NVIDIA) with strong retention incentives. Non-competes are not common, but golden handcuffs via equity vesting schedules are. Sourcing from GitHub contribution data can shorten the timeline by identifying engineers who are actively building distributed training infrastructure but may not be on traditional job boards. Referral networks within the ML research community are also high-yield.
Why is JAX harder to hire for than PyTorch?
Three reasons. First, the community is much smaller. PyTorch dominates academic ML and has millions of users; JAX has a fraction of that. Second, JAX requires a functional programming mindset (pure functions, immutable state, explicit random number handling) which is a steeper learning curve for most ML practitioners trained on imperative PyTorch or TensorFlow. Third, JAX production usage is concentrated at essentially two companies (Google DeepMind and Anthropic), which means the pool of people with production JAX experience at scale is very small and those people are well-compensated. PyTorch engineers are abundant by comparison. JAX engineers are rare, expensive, and not looking.