JAX Training (JAXJob)

Using JAXJob to train a model with JAX

This page describes JAXJob for training a machine learning model with JAX.

The JAXJob is a Kubernetes custom resource to run JAX training jobs on Kubernetes. The Kubeflow implementation of the JAXJob is in the training-operator.

The current custom resource for JAX has been tested to run multiple processes on CPUs using gloo for communication between CPUs. Worker with replica 0 is recognized as a JAX coordinator. Process 0 will start a JAX coordinator service exposed via the IP address of process 0 in your cluster, together with a port available on that process, to which the other processes in the cluster will connect. We are looking for user feedback to run JAXJob on GPUs and TPUs.

Creating a JAX training job

You can create a training job by defining a JAXJob config file. See the manifests for the simple JAXJob example. You may change the Job config file based on your requirements.

Deploy the JAXJob resource to start training:

kubectl create -f https://raw.githubusercontent.com/kubeflow/training-operator/refs/heads/master/examples/jax/cpu-demo/demo.yaml

You should now be able to see the created pods matching the specified number of replicas.

kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple

Distributed computation takes several minutes on a CPU cluster. Logs can be inspected to see its progress.

PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
kubectl logs -f ${PODNAME} -n kubeflow
I1016 14:30:28.956959 139643066051456 distributed.py:106] Starting JAX distributed service on [::]:6666
I1016 14:30:28.959352 139643066051456 distributed.py:119] Connecting to JAX distributed service on jaxjob-simple-worker-0:6666
I1016 14:30:30.633651 139643066051456 xla_bridge.py:895] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1016 14:30:30.638316 139643066051456 xla_bridge.py:895] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
JAX process 0/1 initialized on jaxjob-simple-worker-0
JAX global devices:[CpuDevice(id=0), CpuDevice(id=131072)]
JAX local devices:[CpuDevice(id=0)]
JAX device count:2
JAX local device count:1
[2.]

Monitoring a JAXJob

kubectl get -o yaml jaxjobs jaxjob-simple -n kubeflow

See the status section to monitor the job status. Here is sample output when the job is successfully completed.

apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-simple","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"command":["python3","train.py"],"image":"docker.io/kubeflow/jaxjob-simple:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}      
  creationTimestamp: "2024-09-22T20:07:59Z"
  generation: 1
  name: jaxjob-simple
  namespace: kubeflow
  resourceVersion: "1972"
  uid: eb20c874-44fc-459b-b9a8-09f5c3ff46d3
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - command:
            - python3
            - train.py
            image: docker.io/kubeflow/jaxjob-simple:latest
            imagePullPolicy: Always
            name: jax
status:
  completionTime: "2024-09-22T20:11:34Z"
  conditions:
  - lastTransitionTime: "2024-09-22T20:07:59Z"
    lastUpdateTime: "2024-09-22T20:07:59Z"
    message: JAXJob jaxjob-simple is created.
    reason: JAXJobCreated
    status: "True"
    type: Created
  - lastTransitionTime: "2024-09-22T20:11:28Z"
    lastUpdateTime: "2024-09-22T20:11:28Z"
    message: JAXJob kubeflow/jaxjob-simple is running.
    reason: JAXJobRunning
    status: "False"
    type: Running
  - lastTransitionTime: "2024-09-22T20:11:34Z"
    lastUpdateTime: "2024-09-22T20:11:34Z"
    message: JAXJob kubeflow/jaxjob-simple successfully completed.
    reason: JAXJobSucceeded
    status: "True"
    type: Succeeded
  replicaStatuses:
    Worker:
      selector: training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
      succeeded: 2
  startTime: "2024-09-22T20:07:59Z"

Feedback

Was this page helpful?


Last modified October 22, 2024: Add documentation for JAXJob (#3877) (463843e)