// prime-div/tensor.go

package main
import (
"fmt"
"log"
tf "github.com/wamuir/graft/tensorflow"
"github.com/wamuir/graft/tensorflow/op"
)
// generatePrimesGPU extends primes with every prime in (last known prime, upperLimit]
// using trial division evaluated in a TensorFlow graph, and returns the extended slice.
//
// For each candidate, the graph computes min(candidate mod p) over all known primes p.
// A minimum of 0 means some known prime divides the candidate. Only candidates with a
// strictly positive minimum are appended.
//
// Assumptions: primes is sorted ascending and contains every prime up to its last
// element (TODO confirm with callers). If primes is empty, candidates start at 2.
//
// Progress updates are offered on progress without blocking. If no receiver is ready,
// or the channel is nil, the update is dropped. Any TensorFlow error ends the program
// through must.
//
// As with append, the returned slice may share primes' backing array.
func generatePrimesGPU(upperLimit int64, primes []int64, progress chan float32) []int64 {
	// Graph: minRest = min over axis 0 of (candidate mod primes).
	scope := op.NewScope()
	candidatePlaceholder := must(Placeholder(scope, tf.Int64, "candidate"))
	primesPlaceholder := must(Placeholder(scope, tf.Int64, "primes"))
	reductionAxis := op.Const(scope, int64(0))
	rests := op.Mod(scope, candidatePlaceholder, primesPlaceholder)
	minRest := op.Min(scope, rests, reductionAxis)
	graph := must(scope.Finalize())
	session := must(tf.NewSession(graph, nil))
	defer session.Close()

	feeds := make(map[tf.Output]*tf.Tensor, 2)
	fetches := []tf.Output{minRest}
	fmt.Printf("Calculating with tensorflow\n\n")

	// Start after the largest known prime. With no known primes, start at 2
	// instead of indexing into an empty slice.
	start := int64(2)
	if n := len(primes); n > 0 {
		start = primes[n-1] + 1
	}

	// Length of primes when feeds[primesPlaceholder] was last built; -1 means never.
	fedLen := -1
	for candidate := start; candidate <= upperLimit; candidate++ {
		select {
		case progress <- calculateProgress(6, upperLimit, 1, candidate):
		default:
		}
		if len(primes) == 0 {
			// There is nothing to divide by yet. This branch is only reached for
			// the first candidate, 2, which is prime.
			primes = append(primes, candidate)
			continue
		}
		// The primes tensor only changes when a prime is appended. Rebuild it
		// then, rather than re-encoding the whole slice for every candidate.
		if fedLen != len(primes) {
			feeds[primesPlaceholder] = must(tf.NewTensor(primes))
			fedLen = len(primes)
		}
		feeds[candidatePlaceholder] = must(tf.NewTensor(candidate))
		result := must(session.Run(feeds, fetches, nil))
		if result[0].Value().(int64) > 0 { // no known prime divides candidate
			primes = append(primes, candidate)
		}
	}
	return primes
}
// Placeholder adds to scope a "Placeholder" operation of type dtype named name,
// and returns its first output. The value is supplied through session feeds at
// run time.
//
// If scope is already in an error state, that error is returned. scope.AddOperation
// returns no error of its own; the op package records any failure on the scope
// (for example, presumably, a duplicate operation name). Placeholder therefore
// re-checks scope.Err() after adding the operation, so a failure is not reported
// as success.
func Placeholder(scope *op.Scope, dtype tf.DataType, name string) (tf.Output, error) {
	if err := scope.Err(); err != nil {
		return tf.Output{}, err
	}
	opspec := tf.OpSpec{
		Type: "Placeholder",
		Name: name,
		Attrs: map[string]any{
			"dtype": dtype,
		},
	}
	operation := scope.AddOperation(opspec)
	if err := scope.Err(); err != nil {
		return tf.Output{}, err
	}
	return operation.Output(0), nil
}
// must unwraps a (value, error) pair. It returns v unchanged when err is nil.
// Otherwise it logs err and ends the process through log.Fatal.
func must[T any](v T, err error) T {
	if err == nil {
		return v
	}
	log.Fatal(err)
	return v // not reached: log.Fatal exits, but the compiler needs a return
}