diff --git a/go.mod b/go.mod index 69b4248..17c0e72 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module datalore/prime-div go 1.23.0 + +require github.com/wamuir/graft v0.9.0 + +require google.golang.org/protobuf v1.34.2 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..23633ba --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/wamuir/graft v0.9.0 h1:5DbPtr3MfWRq9bFHivbbvNic8h8jtcKK12Rxk0644iY= +github.com/wamuir/graft v0.9.0/go.mod h1:k6NJX3fCM/xzh5NtHky9USdgHTcz2vAvHp4c23I6UK4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= diff --git a/main.go b/main.go index cac7e11..b3f7828 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "time" ) -func calculateProgress(start, end, step, index int) float32 { +func calculateProgress(start, end, step, index int64) float32 { steps := (end - start) / step current := index - start if current == 0 { @@ -32,7 +32,7 @@ func printProgress(progress <-chan float32, message string) { } } -func checkIsDivisibleByPrime(number, offset, stride int, primes *[]int, resultChannel chan bool, ctx context.Context) { +func checkIsDivisibleByPrime(number int64, offset, stride int, primes *[]int64, resultChannel chan bool, ctx context.Context) { for i := offset; i < len(*primes); i += stride { select { case <-ctx.Done(): @@ -50,28 +50,45 @@ func checkIsDivisibleByPrime(number, offset, stride int, primes *[]int, resultCh resultChannel <- false } -func generatePrimes(upperLimit int, loadList bool, numRoutines int) []int { - var primes []int +func generatePrimes(upperLimit int64, loadList bool, numRoutines int, useTensors bool) []int64 { + var primes []int64 bootTime := time.Now() if loadList { primes = loadPrimes() } else { - primes = []int{2, 3, 5} + primes = []int64{2, 3, 5} } - if numRoutines == 0 { - numRoutines = runtime.NumCPU() - } - - fmt.Printf("Startup time: %v\nCalculating with %d routines\n\n", time.Now().Sub(bootTime), numRoutines) + fmt.Printf("Startup time: %v\n", time.Now().Sub(bootTime)) progress := make(chan float32) go printProgress(progress, "Generating primes") startTime := time.Now() + if useTensors { + primes = generatePrimesGPU(upperLimit, primes, progress) + } else { + primes = generatePrimesCPU(upperLimit, primes, numRoutines, progress) + } + close(progress) + + endTime := time.Now() + + fmt.Printf("Prime generation took %v\nLargest prime found: %d\nTotal prime number count: %d\n", endTime.Sub(startTime), primes[len(primes) - 1], len(primes)) + + return primes +} + +func generatePrimesCPU(upperLimit int64, primes []int64, numRoutines int, progress chan float32) []int64 { + if numRoutines == 0 { + numRoutines = runtime.NumCPU() + } + + fmt.Printf("Calculating with %d routines\n\n", numRoutines) + for i := (primes[len(primes) - 1] + 1); i <= upperLimit; i++ { select { case progress <- calculateProgress(6, upperLimit, 1, i): @@ -99,30 +116,24 @@ func generatePrimes(upperLimit int, loadList bool, numRoutines int) []int { primes = append(primes, i) } } - close(progress) - - endTime := time.Now() - - fmt.Printf("Prime generation took %v\nLargest prime found: %d\nTotal prime number count: %d\n", endTime.Sub(startTime), primes[len(primes) - 1], len(primes)) - return primes } -func calculatePrimeParts(number int, primes []int) []int { +func calculatePrimeParts(number int64, primes []int64) []int64 { // don't calculate if number is a prime itself if primes[len(primes)-1] == number { - return []int{number} + return []int64{number} } progress := make(chan float32) go printProgress(progress, "Calculating") - var primeParts []int + var primeParts []int64 - for i := len(primes) - 1; i >= 0; i-- { + for i := int64(len(primes) - 1); i >= 0; i-- { select { - case progress <- 1.0 - calculateProgress(0, len(primes) - 1, 1, i): + case progress <- 1.0 - calculateProgress(0, int64(len(primes) - 1), 1, i): default: } flooredDiv := number / primes[i] @@ -143,13 +154,17 @@ func calculatePrimeParts(number int, primes []int) []int { } func main() { - var primeList bool - var dontLoad bool - var numRoutines int + var ( + primeList bool + dontLoad bool + numRoutines int + useTensors bool + ) flag.BoolVar(&primeList, "p", false, "Only calculate and print prime list") flag.BoolVar(&dontLoad, "d", false, "Don't load precalculated primes, calculate from 0") flag.IntVar(&numRoutines, "r", 0, "How many routines to use for calculation. 0 = number of available CPU cores") + flag.BoolVar(&useTensors, "t", false, "Use tensorflow") flag.Parse() numStr := flag.Arg(0) @@ -159,20 +174,20 @@ func main() { return } - number, err := strconv.Atoi(numStr) + number, err := strconv.ParseInt(numStr, 10, 64) if err != nil { log.Fatal(err) } if primeList { - onlyGenerate(number, dontLoad, numRoutines) + onlyGenerate(number, dontLoad, numRoutines, useTensors) } else { - calculate(number, dontLoad, numRoutines) + calculate(number, dontLoad, numRoutines, useTensors) } } -func onlyGenerate(number int, dontLoad bool, numRoutines int) { - primes := generatePrimes(number, !dontLoad, numRoutines) +func onlyGenerate(number int64, dontLoad bool, numRoutines int, useTensors bool) { + primes := generatePrimes(number, !dontLoad, numRoutines, useTensors) file, err := os.Create("prime.txt") if err != nil { @@ -189,12 +204,12 @@ func onlyGenerate(number int, dontLoad bool, numRoutines int) { file.Close() } -func calculate(number int, dontLoad bool, numRoutines int) { - primes := generatePrimes(number, !dontLoad, numRoutines) +func calculate(number int64, dontLoad bool, numRoutines int, useTensors bool) { + primes := generatePrimes(number, !dontLoad, numRoutines, useTensors) primeParts := calculatePrimeParts(number, primes) - sum := 0 + var sum int64 for i, prime := range primeParts { sum += prime @@ -209,18 +224,18 @@ func calculate(number int, dontLoad bool, numRoutines int) { println(sum) } -func loadPrimes() []int { +func loadPrimes() []int64 { file, err := os.Open("prime.txt") if err != nil { log.Fatal(err) } - var primes []int + var primes []int64 scanner := bufio.NewScanner(file) for scanner.Scan() { - nextPrime, err := strconv.Atoi(scanner.Text()) + nextPrime, err := strconv.ParseInt(scanner.Text(), 10, 64) if err != nil { log.Fatal(err) } diff --git a/tensor.go b/tensor.go new file mode 100644 index 0000000..4948101 --- /dev/null +++ b/tensor.go @@ -0,0 +1,69 @@ +package main + +import ( + "fmt" + "log" + + tf "github.com/wamuir/graft/tensorflow" + "github.com/wamuir/graft/tensorflow/op" +) + +func generatePrimesGPU(upperLimit int64, primes []int64, progress chan float32) []int64 { + // setting up tensorflow graph + scope := op.NewScope() + candidatePlaceholder := must(Placeholder(scope, tf.Int64, "candidate")) + primesPlaceholder := must(Placeholder(scope, tf.Int64, "primes")) + minDimension := op.Const(scope, int64(0)) + mod := op.Mod(scope, candidatePlaceholder, primesPlaceholder) + min := op.Min(scope, mod, minDimension) + + graph := must(scope.Finalize()) + session := must(tf.NewSession(graph, nil)) + defer session.Close() + + feeds := make(map[tf.Output]*tf.Tensor) + fetches := []tf.Output{min} + + fmt.Printf("Calculating with tensorflow\n\n") + + for i := (primes[len(primes) - 1] + 1); i <= upperLimit; i++ { + select { + case progress <- calculateProgress(6, upperLimit, 1, i): + default: + } + + feeds[candidatePlaceholder] = must(tf.NewTensor(i)) + feeds[primesPlaceholder] = must(tf.NewTensor(primes)) + + result := must(session.Run(feeds, fetches, nil)) + + if result[0].Value().(int64) > 0 { // min of all rests. if > 0, no divisor was found + primes = append(primes, i) + } + } + return primes +} + +func Placeholder(scope *op.Scope, dtype tf.DataType, name string) (tf.Output, error) { + err := scope.Err() + if err != nil { + return tf.Output{}, err + } + + opspec := tf.OpSpec{ + Type: "Placeholder", + Name: name, + Attrs: map[string]interface{}{ + "dtype": dtype, + }, + } + + return scope.AddOperation(opspec).Output(0), nil +} + +func must[T any](obj T, err error) T { + if err != nil { + log.Fatal(err) + } + return obj +}