Mittwoch, 7. August 2019

TensorFlow mit Java



"Geht nicht!", wird der geneigte Leser in anbetracht der Überschrift nun vielleicht denken. Und auch wenn es nicht so komfortabel ist wie mit Python, es geht doch: TensorFlow von Java aus aufrufen.

Wenn Sie zu den Lesern gehören sollten, die nun fragen, warum man so etwas tun sollte, können Sie sich das Weiterlesen vermutlich sparen. Ich finde es sehr praktisch, dass man den Einstieg in Googles Machine Learning auch als Java-Programmierer finden kann, ohne sich zusätzlich in eine neue Programmiersprache einarbeiten zu müssen. Wer hier allerdings tiefer einsteigen und bspw. eigene Modelle trainieren möchte, wir über kurz oder lang um Python vermutlich nicht herum kommen, da die Unterstützung für Java leider eher dürftig ist.

Das Einbinden von TensorFlow in ein Java-Progamm ist dank folgender Maven-Dependency jedenfalls hinreichend einfach:

   <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow</artifactId>
      <version>1.13.1</version>

  </dependency>

Alternativ lässt sich natürlich auch das entsprechende JAR-File herunterladen und direkt in den Classpath legen. Die aktuelleste Version ist übrigens 1.14.0, diese läuft auf meinem Mac allerdings nicht, so dass ich auf 1.13.1 zurückgegriffen habe, was dann auch problemlos funktioniert.

Das Kopieren des reichlich nichtssagenden Einstiegsbeispiels von der TensorFlow-Webseite spare ich mir an dieser Stelle, wer damit testen möchte, kann es leicht von dort kopieren. Bevor ich gleich ein etwas nachvollziehbareres Beispiel zum Besten gebe, möchte ich an dieser Stelle aber doch ein Worte zu den Hintergründen verlieren.

Wie die allermeisten Leser sicher bereits gehört haben, ist TensorFlow ein von Google entwickeltes und als Open-Source verbreitetes Framework für Machine Learning, das sich auf Grund seiner vielseitigen Einsetzbarkeit (sowohl was die Einsatzmöglichkeiten, als auch was die Plattformen angeht) einer großen Beliebtheit erfreut.

Von der Grundidee her, verwendet TensorFlow (eine Art Datenfluss-)Graphen, um Daten (also die Tensoren) zu verarbeiten. Ein Tensor ist übrigens nichts weiter als ein Array von Daten mit einer quasi beliebigen Anzahl von Dimensionen. Im einfachsten Fall mit null Dimensionen, womit es sich  um nichts weiter als einen Skalar (also eine Zahl) handelt. Ein eindimensionales Array ist eine Liste von Daten, ein zweidimensionales eine Tabelle usw...

Wir können uns die Graphen in Tensorflow ein wenig wie Aktivitätsdiagramme vorstellen oder auch wie die in modernen Streaming-Frameworks verwendeten Graphen zur Beschreibung der gewünschten Datenverarbeitung. Genau zu diesem Zweck werden Sie auch in TensorFlow verwendet. Entsprechend muss in einem Java-Programm auch zunächst ein Graph als Objekt angelegt werden:

  Graph g = new Graph();

Diesen Graphen können wir nun Schritt für Schritt zusammenbauen, wie das geht, wollen wir uns zunächst am Beispiel einer einfachen Addition betrachten. Dazu legen wir mit Hilfe der OperationBuilder-Methode des Graphen Operanden für zwei Skalare an. Diese werden durch den Aufruf der zugehörigen Build-Methode im Graphen verfügbar gemacht:

  Operation x = g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.DOUBLE).build();     
    
  Operation y = g.opBuilder("Placeholder", "y").setAttr("dtype", DataType.DOUBLE).build();

Um diese beiden einzelnen Knoten zu einem echten Graphen zu verbinden, benötigen wir noch unsere Additionsoperation, die die beiden Placeholder (bzw. die darin enthaltenen Daten) zusammenführt:

  Operation z = g.opBuilder("Add", "sum").addInput(x.output(0)).addInput(y.output(0)).build();

Bis hierher ist noch keinerlei Berechnung passiert, wir haben nur eine leere Hülse für den Graphen angelegt. Um diese mit Leben zu füllen und die eigentliche Berechnung auszuführen, benötigen wir noch eine Session, die wieder ganz normal als Java-Objekt instanziiert wird und dabei den auszuführenden Graphen als Parameter übergeben bekommt:

  Session s = new Session(g); 

Im Anschluss definieren wir als Rückgabe wie folgt einen parametrisierten Tensorund lassen seinen Double-Wert auf der Konsole ausgeben:

  Tensor<Double> t = s.runner().fetch("sum")
      .feed("x", Tensor.<Double>create(4.0, Double.class))
      .feed("y", Tensor.<Double>create(6.0, Double.class))
      .run().get(0).expect(Double.class);


  System.out.println(t.doubleValue());

  t.close();
  s.close(); 

Wie unschwer am Beispiel-Code zu erkennen ist, wird sum als Ergebnis nach Ausführen der Additionsoperation abgegriffen, vor der Ausführung müssen natürlich noch Werte für die beiden Platzhalter x und y an den Session-Runner gefüttert werden. Abschließend sollten sowohl der angelegte Tensor als auch die Session wieder geschlossen werden.

Matrix mit Skalar multiplizieren

Um auch noch ein anspruchsvolleres Beispiel zu demonstrieren, möchte ich im Folgenden erklären, wie sich mit TensorFlow eine Matrix mit einem Skalar multiplizieren lässt. Dazu legen wir zunächst  eine einfache Matrix als Tensor an und legen diesen als Konstante im Graphen ab:

  Tensor<Integer> t2 = Tensor.create(new int[][]{ {5,2},{3,4} }, Integer.class);            
  Operation a = g.opBuilder("Const", "u").setAttr("dtype", DataType.INT32).setAttr("value", t2).build();


Entsprechend können wir mit einem Skalar verfahren:
         
  Operation b = g.opBuilder("Const", "v").setAttr("dtype", DataType.INT32).setAttr("value", Tensor.<Integer>create(4, Integer.class)).build();

Abschließend müssen wir noch die beiden Knoten noch mit einer Multiplikation verbinden, diese über den Session-Runner ausführen lassen und den Ergebnistensor, der natürlich wieder eine Matrix ist, in ein Java-Array umwandeln lassen:           

  g.opBuilder("Mul", "prod").addInput(b.output(0)).addInput(a.output(0)).build();
           
  Tensor<Integer> t3 = s.runner().fetch("prod").run().get(0).expect(Integer.class);
            

  System.out.println(t3.toString());
  int[][] res = t3.copyTo(new int[2][2]);
  System.out.println(res[0][0] + " " + res[0][1]);
  System.out.println(res[1][0] + " " + res[1][1]);


Das war "quick and clean" eine erste Einführung in TensorFlow für/mit Java, weitere Beispiele werden bei nächster Gelegenheit sicherlich folgen.

Beispielcode

Der Vollständigkeit halber sei abschließend noch das vollständige Additionsprogramm mit allen zugehörigen Imports gepostet:

import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class TfAdd {

    public static void main(String[] args) {

        try {
            Graph g = new Graph();

            Operation x = g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.DOUBLE).build();       
            Operation y = g.opBuilder("Placeholder", "y").setAttr("dtype", DataType.DOUBLE).build();
          
            Operation z = g.opBuilder("Add", "sum").addInput(x.output(0)).addInput(y.output(0)).build();

            Session s = new Session(g);

            Tensor<Double> t = s.runner().fetch("sum")
                    .feed("x", Tensor.<Double>create(4.0, Double.class))
                    .feed("y", Tensor.<Double>create(6.0, Double.class))
                    .run().get(0).expect(Double.class);


            System.out.println(t.doubleValue());
            
            t.close();
            s.close();
          
        } catch(Exception e) {
            e.printStackTrace();
        }
    }



Auch die Matrixmultiplikation lässt sich ohne Probleme in diesen Rahmen einbauen. Und als kleine Übungsaufgaben bieten sich natürlich Matrix-Vektor- oder auch Matrix-Matrix-Multiplikationen an.