Posted Fri Dec 17, 2021 •  Reading time: 8 minutes

Iris Classification with gRPC and Machine Learning

Today we’ll move quickly through creating a gRPC service that embeds a machine learning model, all written (of course) in Go.

We’ll split this post into 3 major parts:

  1. Creating a gRPC service
  2. Creating a machine learning model
  3. Using the model in our service

Feel free to just browse the source code if you’ve already read the blog post.

Creating a gRPC Service

In this section we’re going to learn how to define a contract for our service, as well as implement the code backing the service - though we’re not going to return anything useful.

Let’s start!

syntax = "proto3";

package iris_classification.v1;

option go_package = "github.com/trelore/iris-classification/proto/gen/go;irisclassificationpb";

// IrisClassificationService is a service to predict the Iris Classification given input
service IrisClassificationService {
  // Predict the Iris Classification
  rpc Predict(PredictRequest) returns (PredictResponse);
}

// petal length, petal width, sepal length, sepal width
message PredictRequest {
  // length of petal
  double petal_length = 1;
  // width of petal
  double petal_width = 2;
  // length of sepal
  double sepal_length = 3;
  // width of sepal
  double sepal_width = 4;
}

// the prediction response
message PredictResponse {
  // prediction of what classification of iris it is
  string predicition = 1;
}

That was a lot, and I promise we’ll dive into the Go code soon. First, let’s dig into what this contract is.

The key things here are the service and the two messages. The service defines the RPC calls we allow, and the messages define the structure of the data going in and out of the service. The proto file is designed to be both human- and machine-readable; I’d definitely recommend looking at the language guide for further reading.

Next step, let’s generate the code.

As of December 2021, I’d 100% recommend using buf for this. It lets developers easily generate code, lint their proto files, and check pull requests for breaking changes. Check it out if you haven’t already.

Let’s take a look at our buf.gen.yaml file. For this to work, we’ll need two plugins installed: protoc-gen-go and protoc-gen-go-grpc.

version: v1
plugins:
  - name: go
    out: gen/go
    opt: paths=source_relative
  - name: go-grpc
    out: gen/go
    opt:
      - paths=source_relative

To generate our code, we can simply run buf generate. Thanks buf!

Let’s take a look at the code it generated for us (not all of it, just the important bits).

type IrisClassificationServiceServer interface {
	// Predict the Iris Classification
	Predict(context.Context, *PredictRequest) (*PredictResponse, error)
	mustEmbedUnimplementedIrisClassificationServiceServer()
}

As you might expect, it’s defined a server interface with a Predict method, which takes our PredictRequest and returns our PredictResponse (and an error). So all we really need to do is define a struct with a Predict method, and we’re in a really good place.

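Lastly, you’ll notice the mustEmbedUnimplementedIrisClassificationServiceServer method. There was a big discussion on GitHub about why you should embed the Unimplemented server; the short version is that embedding the generated UnimplementedIrisClassificationServiceServer type supplies that method, and it keeps your server compiling (returning Unimplemented for any RPC you haven’t written yet) when new RPCs are added to the service.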

Now that we’ve defined our service and generated the code, let’s add it to a Go application.

Let’s create a struct we can hang methods on, and put it in a package called server. The snippet below defines the struct, a New func to instantiate it, and an ‘implementation’ of the Predict method. For now we’re just going to return Unimplemented while we switch gears to focus on our ML model.

package server

import (
  "context"

  pb "github.com/trelore/iris-classification/proto/gen/go/iris_classification/v1"
  "google.golang.org/grpc/codes"
  "google.golang.org/grpc/status"
)

// New returns a new S
func New() S {
  return S{}
}

// S Implements the IrisClassificationService
type S struct {
  pb.UnimplementedIrisClassificationServiceServer
}

// Predict implements proto
func (s *S) Predict(ctx context.Context, req *pb.PredictRequest) (*pb.PredictResponse, error) {
  return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented")
}

Next up, we need to instantiate our server, create a new gRPC server, register one with the other, and then serve. This chunk of code belongs in the main function of our server.

package main

import (
  "log"
  "net"

  pb "github.com/trelore/iris-classification/proto/gen/go/iris_classification/v1"
  "github.com/trelore/iris-classification/svc/server"
  "google.golang.org/grpc"
)

func main() {
  s := server.New()
  grpcS := grpc.NewServer()
  defer grpcS.GracefulStop()

  // Register our implementation (a pointer, since Predict has a pointer receiver).
  pb.RegisterIrisClassificationServiceServer(grpcS, &s)

  address := ":32400"
  log.Printf("listening to address %s", address)
  listener, err := net.Listen("tcp", address)
  if err != nil {
    log.Fatal(err)
  }
  // Serve blocks until the server stops; surface its error instead of dropping it.
  if err := grpcS.Serve(listener); err != nil {
    log.Fatal(err)
  }
}

If we run our main package now, we should see our service listening on port 32400. This is a good point to introduce evans, a tool for calling gRPC services from the command line (picture it like Postman or curl, but for gRPC!). Running the following command points evans at our own service.

evans -p 32400 proto/idl/iris_classification/v1/service.proto

From here we can explore our gRPC service through commands like show and desc. After a bit of exploring, we can run call Predict, which prompts us for the fields of the PredictRequest message from earlier. Enter some arbitrary data (it doesn’t matter what, so long as it’s the right type), since our handler ignores the input and returns Unimplemented anyway.

Your call should return an Unimplemented status code at this point.

Creating a Machine Learning Model

At this point I offer a disclaimer: ML is not my forte (yet!), so expect a lot of links to other resources.

I tried a couple of different Go libraries for this. My main requirement was a model I could train in some kind of CI step and embed into a Go application. The code also had to be simple to read, as that’s personally one of the main selling points of Go for me.

In the end I settled on gorgonia, which amazingly had a tutorial covering exactly what we’re building: iris classification. We’ll refactor it slightly, but largely it’ll be the same code. Their full code (as well as many other examples) can be found here.

I’ll leave the explanations of various ML concepts to gorgonia (we’ll just copy their code for the purposes of this blog post), but I strongly recommend reading their tutorial.

The only amendment we’ll make is embedding the training data. I’m still playing around with the best practices for this, but I’ve found that in this particular case we can just put our csv file in the same directory as our code. That lets us get away with a couple of lines of code (for the purposes of this post, put them in a new directory separate from our gRPC service):

package datasets

import (
	_ "embed"
)

//go:embed iris.csv
var data []byte

Note the blank (underscore) import of embed here: it’s required in order to use the //go:embed directive for our data. The main benefit is portability, since the file’s contents are compiled into the program itself (this will come into play more when we build a binary for our server).

We’re also going to modify the start of getXYMat() so it looks something like this:

df := dataframe.ReadCSV(bytes.NewReader(data))

Done. That’s our only change to make.
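If you’re curious why bytes.NewReader is all it takes: the embedded []byte just needs wrapping in something that satisfies io.Reader, which is what dataframe.ReadCSV accepts. Here’s a stdlib-only sketch that uses encoding/csv in place of the dataframe package. The inline data variable stands in for the embedded iris.csv, and its header and rows are hypothetical sample data rather than the real file.

```go
package main

import (
	"bytes"
	"encoding/csv"
	"fmt"
	"log"
)

// data stands in for the []byte that //go:embed iris.csv would populate.
// The header and rows are hypothetical sample data.
var data = []byte(`sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,setosa
7.0,3.2,4.7,1.4,versicolor
`)

// readRecords parses CSV from an in-memory byte slice. bytes.NewReader gives
// us an io.Reader over the bytes without copying them or touching the disk.
func readRecords(b []byte) ([][]string, error) {
	return csv.NewReader(bytes.NewReader(b)).ReadAll()
}

func main() {
	records, err := readRecords(data)
	if err != nil {
		log.Fatal(err)
	}
	// records[0] is the header row; each later row is one flower.
	for _, r := range records[1:] {
		fmt.Println(r[4], r[:4])
	}
}
```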

With this in place, we should be able to run the training program, which should spit out a theta.bin file.

Using the Model in Our Service

We’re almost done. The last thing we need to do is change our Unimplemented handler from earlier to use the model. But how? Easy: if we look at the iris-classification code from earlier, there’s a cmd/main.go (here), which is the ‘predict’ script. Again, take a look at the tutorial if you want a more in-depth explanation of what everything does.

One of the key things to note is that one part of the code reads in the model and another part uses it. We’re going to use that logical separation: load the model when instantiating the server, and use it when a request comes in.
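Here’s a stdlib-only sketch of that separation, under a few assumptions: the weights are a plain []float64 rather than a *tensor.Dense, the Model, NewModel and Predict names are hypothetical, and the weights in main are invented purely so the arithmetic is easy to follow. The per-request step mirrors the Predict handler we’ll end up with: a dot product of the four features plus a constant 1.0 bias input with theta, rounded to the nearest class label.

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// Model holds weights that are loaded once, when the server is constructed.
type Model struct {
	theta []float64 // four feature weights followed by a bias weight
}

// NewModel is where the weights would be decoded from theta.bin at startup.
// Here they are passed in directly.
func NewModel(theta []float64) (*Model, error) {
	if len(theta) != 5 {
		return nil, errors.New("expected 4 feature weights plus 1 bias weight")
	}
	return &Model{theta: theta}, nil
}

// Predict runs once per request: a dot product of the features (plus a
// constant 1.0 bias input) with theta, rounded to the nearest class label.
func (m *Model) Predict(sepalLength, sepalWidth, petalLength, petalWidth float64) (string, error) {
	x := []float64{sepalLength, sepalWidth, petalLength, petalWidth, 1.0}
	var y float64
	for i := range x {
		y += x[i] * m.theta[i]
	}
	switch math.Round(y) {
	case 1:
		return "setosa", nil
	case 2:
		return "virginica", nil
	case 3:
		return "versicolor", nil
	default:
		return "", errors.New("unknown iris")
	}
}

func main() {
	// Made-up weights, chosen only so the arithmetic is easy to check by hand.
	m, err := NewModel([]float64{0.1, 0, 0, 0, 0.5})
	if err != nil {
		panic(err)
	}
	// 5.1*0.1 + 0.5 = 1.01, which rounds to 1.
	class, err := m.Predict(5.1, 3.5, 1.4, 0.2)
	fmt.Println(class, err) // setosa <nil>
}
```

Doing the decode in the constructor means a bad model fails once at startup, and each request only pays for the arithmetic.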

Let’s go back to part 1 and modify our New() function in the server package. It should look something like this.

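// S implements the IrisClassificationService
type S struct {
	pb.UnimplementedIrisClassificationServiceServer
	// thetaT holds the trained weights, decoded once at startup.
	thetaT *tensor.Dense
}

// New returns a new S
func New() S {
	var thetaT *tensor.Dense
	err := gob.NewDecoder(bytes.NewReader(data)).Decode(&thetaT)
	if err != nil {
		log.Fatal(err)
	}
	return S{thetaT: thetaT}
}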

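This looks good, but what is data? It’s our model, embedded into the service with the same trick we used for the training data. Since New() refers to data directly, this file belongs in the server package, with theta.bin sitting alongside it:

package server

import (
	_ "embed"
)

// data holds the gob-encoded model weights from theta.bin.
//go:embed theta.bin
var data []byte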

Now to explain why this is so beneficial. What happens when we run go build main.go? With embed, we no longer need to ship a theta.bin in a very specific place next to the binary, because the file is in the binary itself. This does come at the cost of a larger binary, but at 189 bytes (in my case) I think it’s worth it.
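The decode in New() is plain encoding/gob reading from an in-memory reader, so the embedded bytes decode just as the file on disk would. Here’s a stdlib-only sketch of that round trip with one swap: the weights are a []float64 instead of a *tensor.Dense, so it runs without gorgonia. encodeTheta and decodeTheta are hypothetical helpers; the first plays the role of the training step writing theta.bin, and the second mirrors New().

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"log"
)

// encodeTheta plays the role of the training step writing theta.bin.
func encodeTheta(theta []float64) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(theta); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// decodeTheta mirrors New(): decode the weights from an in-memory []byte,
// as we would from the embedded file contents.
func decodeTheta(data []byte) ([]float64, error) {
	var theta []float64
	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&theta); err != nil {
		return nil, err
	}
	return theta, nil
}

func main() {
	data, err := encodeTheta([]float64{0.1, 0.2, 0.3, 0.4, 0.5})
	if err != nil {
		log.Fatal(err)
	}
	theta, err := decodeTheta(data)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(data), theta)
}
```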

The last modification we need to make is to have the Predict handler use s.thetaT. In our case we can replace their getInput function with our *pb.PredictRequest input, and instead of printing the output we can return it wrapped in our *pb.PredictResponse.

// Predict implements proto
func (s *S) Predict(ctx context.Context, req *pb.PredictRequest) (*pb.PredictResponse, error) {
	// Build a fresh graph per request, using the weights loaded in New().
	g := gorgonia.NewGraph()
	theta := gorgonia.NodeFromAny(g, s.thetaT)

	// Features from the request, plus a constant 1.0 bias input.
	values := []float64{
		req.GetSepalLength(),
		req.GetSepalWidth(),
		req.GetPetalLength(),
		req.GetPetalWidth(),
		1.0,
	}
	xT := tensor.New(tensor.WithBacking(values))
	x := gorgonia.NodeFromAny(g, xT)
	// y = x dot theta, the model's continuous prediction.
	y, err := gorgonia.Mul(x, theta)
	if err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	// Execute the graph to compute y.
	machine := gorgonia.NewTapeMachine(g)
	defer machine.Close()

	if err = machine.RunAll(); err != nil {
		return nil, status.Error(codes.Internal, err.Error())
	}
	defer machine.Reset()

	// Round the prediction to the nearest class label.
	var class string
	switch math.Round(y.Value().Data().(float64)) {
	case 1:
		class = "setosa"
	case 2:
		class = "virginica"
	case 3:
		class = "versicolor"
	default:
		return nil, status.Error(codes.Internal, "unknown iris")
	}
	return &pb.PredictResponse{
		Predicition: class,
	}, nil
}

As far as the code goes, we’re now done. We have a script that creates an ML model, and a service that embeds that model to predict iris classifications.

At the end of part 1, we used evans to check that our service returned the Unimplemented status code. If we run our service, start up evans again and pass in some sensible values (try copying a particular flower from the input set), we can see the service now returns sensible results.

iris_classification.v1.IrisClassificationService@127.0.0.1:32400> call Predict
sepal_length (TYPE_DOUBLE) => 5.1
sepal_width (TYPE_DOUBLE) => 3.5
petal_length (TYPE_DOUBLE) => 1.4
petal_width (TYPE_DOUBLE) => 0.2
{
  "predicition": "setosa"
}

Closing Remarks

I hope you learnt something from reading through this. We covered a lot in a short space, with a lot of links out to other resources - with any luck you’ll have about 10 tabs open now just relating to this.

See if you can follow along with the tutorial with a different linear regression problem. Or for bonus points, try another ML algorithm altogether.

Thanks!