2024-10-03 00:32:59 +02:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"log"
|
|
|
|
|
|
|
|
tf "github.com/wamuir/graft/tensorflow"
|
|
|
|
"github.com/wamuir/graft/tensorflow/op"
|
|
|
|
)
|
|
|
|
|
|
|
|
// generatePrimesGPU extends a list of known primes with every prime up to and
// including upperLimit. Each candidate is tested by trial division inside a
// TensorFlow graph; which device runs the graph is left to TensorFlow.
//
// primes is assumed to be sorted ascending and to contain every prime up to
// its last element (e.g. []int64{2, 3, 5}). Candidates are only divided by
// entries of that list. An empty list is seeded with 2 when upperLimit >= 2,
// and is returned unchanged otherwise. New primes are appended, so the
// returned slice may share its backing array with the argument.
//
// On every iteration calculateProgress is evaluated and a non-blocking send on
// progress is attempted. The value is dropped when no receiver is ready, and a
// nil channel never receives anything.
//
// Any TensorFlow failure terminates the program via must.
func generatePrimesGPU(upperLimit int64, primes []int64, progress chan float32) []int64 {
	// Graph: minRemainder = min over axis 0 of (candidate mod divisors).
	// A zero minimum means at least one divisor divides the candidate.
	scope := op.NewScope()
	candidatePlaceholder := must(Placeholder(scope, tf.Int64, "candidate"))
	primesPlaceholder := must(Placeholder(scope, tf.Int64, "primes"))
	reductionAxis := op.Const(scope, int64(0))
	remainders := op.Mod(scope, candidatePlaceholder, primesPlaceholder)
	minRemainder := op.Min(scope, remainders, reductionAxis) // avoid shadowing builtin min

	graph := must(scope.Finalize())
	session := must(tf.NewSession(graph, nil))
	defer session.Close()

	feeds := make(map[tf.Output]*tf.Tensor, 2)
	fetches := []tf.Output{minRemainder}

	fmt.Printf("Calculating with tensorflow\n\n")

	// Indexing the last element below would panic on an empty list.
	if len(primes) == 0 {
		if upperLimit < 2 {
			return primes
		}
		primes = append(primes, 2)
	}

	// Only odd candidates can be prime here. Start at the first odd number
	// after the last known prime: 3 when the list ends in 2, otherwise last+2.
	start := primes[len(primes)-1] + 1
	if start%2 == 0 {
		start++
	}

	// divisorCount is the length of the prefix of primes whose elements are
	// <= sqrt(i). It never shrinks as i grows, so the divisor tensor is only
	// rebuilt when the prefix gets longer. fedCount tracks what is currently
	// in feeds.
	divisorCount, fedCount := 0, -1
	for i := start; i <= upperLimit; i += 2 {
		select {
		case progress <- calculateProgress(6, upperLimit, 1, i):
		default:
		}

		// p <= i/p is the overflow-safe form of p*p <= i for positive p.
		for divisorCount < len(primes) && primes[divisorCount] <= i/primes[divisorCount] {
			divisorCount++
		}

		if divisorCount == 0 {
			// No known prime is <= sqrt(i), so i has no non-trivial divisor.
			primes = append(primes, i)
			continue
		}

		if divisorCount != fedCount {
			feeds[primesPlaceholder] = must(tf.NewTensor(primes[:divisorCount]))
			fedCount = divisorCount
		}
		feeds[candidatePlaceholder] = must(tf.NewTensor(i))

		result := must(session.Run(feeds, fetches, nil))
		if result[0].Value().(int64) > 0 { // smallest remainder > 0: no divisor found
			primes = append(primes, i)
		}
	}

	return primes
}
|
|
|
|
|
|
|
|
func Placeholder(scope *op.Scope, dtype tf.DataType, name string) (tf.Output, error) {
|
|
|
|
err := scope.Err()
|
|
|
|
if err != nil {
|
|
|
|
return tf.Output{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
opspec := tf.OpSpec{
|
|
|
|
Type: "Placeholder",
|
|
|
|
Name: name,
|
|
|
|
Attrs: map[string]interface{}{
|
|
|
|
"dtype": dtype,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
return scope.AddOperation(opspec).Output(0), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// must unwraps a (value, error) pair. When err is nil it returns value
// unchanged. Otherwise it logs err through log.Fatal, which terminates the
// process.
func must[T any](value T, err error) T {
	if err == nil {
		return value
	}
	log.Fatal(err)
	return value // unreachable: log.Fatal exits; needed to satisfy the compiler
}
|