package main import ( "fmt" "log" tf "github.com/wamuir/graft/tensorflow" "github.com/wamuir/graft/tensorflow/op" ) func generatePrimesGPU(upperLimit int64, primes []int64, progress chan float32) []int64 { // setting up tensorflow graph scope := op.NewScope() candidatePlaceholder := must(Placeholder(scope, tf.Int64, "candidate")) primesPlaceholder := must(Placeholder(scope, tf.Int64, "primes")) minDimension := op.Const(scope, int64(0)) mod := op.Mod(scope, candidatePlaceholder, primesPlaceholder) min := op.Min(scope, mod, minDimension) graph := must(scope.Finalize()) session := must(tf.NewSession(graph, nil)) defer session.Close() feeds := make(map[tf.Output]*tf.Tensor) fetches := []tf.Output{min} fmt.Printf("Calculating with tensorflow\n\n") continueFrom := primes[len(primes) - 1] + 2 for i := continueFrom; i <= upperLimit; i += 2 { select { case progress <- calculateProgress(continueFrom, upperLimit, i): default: } feeds[candidatePlaceholder] = must(tf.NewTensor(i)) feeds[primesPlaceholder] = must(tf.NewTensor(primes)) result := must(session.Run(feeds, fetches, nil)) if result[0].Value().(int64) > 0 { // min of all rests. if > 0, no divisor was found primes = append(primes, i) } } return primes } func Placeholder(scope *op.Scope, dtype tf.DataType, name string) (tf.Output, error) { err := scope.Err() if err != nil { return tf.Output{}, err } opspec := tf.OpSpec{ Type: "Placeholder", Name: name, Attrs: map[string]interface{}{ "dtype": dtype, }, } return scope.AddOperation(opspec).Output(0), nil } func must[T any](obj T, err error) T { if err != nil { log.Fatal(err) } return obj }