80 lines
1.4 KiB
Go
80 lines
1.4 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
|
)
|
|
|
|
// Network dimensions. These are typed int64 because they are passed straight
// to gotch size/shape arguments (nn.NewLinear in newModel, ts.MustRandn in main).
const (
	ImageDimNN    int64 = 784 // input features per sample (784 = 28*28; presumably a flattened image — TODO confirm)
	HiddenNodesNN int64 = 128 // output width of the single linear layer
	LabelNN       int64 = 10  // not referenced in this file; presumably the number of classes — TODO confirm
)

// Run settings.
const (
	BatchSize int64 = 3000 // rows in each generated input batch
	epochsNN        = 200  // not referenced in this file; main uses its own local iteration count
	LrNN            = 1e-3 // not referenced in this file; presumably a learning rate — TODO confirm
)
|
|
|
|
type model struct {
|
|
fc *nn.Linear
|
|
act nn.Func
|
|
}
|
|
|
|
func newModel(vs *nn.VarStore) *model {
|
|
fc := nn.NewLinear(vs.Root(), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
|
|
act := nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
|
return xs.MustRelu(false)
|
|
})
|
|
|
|
return &model{
|
|
fc: fc,
|
|
act: act,
|
|
}
|
|
}
|
|
|
|
func (m *model) Forward(x *ts.Tensor) *ts.Tensor {
|
|
fc := m.fc.Forward(x)
|
|
act := m.act.Forward(fc)
|
|
|
|
return act
|
|
}
|
|
|
|
func newData() []float32 {
|
|
n := int(BatchSize * ImageDimNN)
|
|
data := make([]float32, n)
|
|
for i := 0; i < n; i++ {
|
|
data[i] = rand.Float32()
|
|
}
|
|
|
|
return data
|
|
}
|
|
|
|
func main() {
|
|
epochs := 4000
|
|
|
|
// device := gotch.CPU
|
|
device := gotch.CudaIfAvailable()
|
|
vs := nn.NewVarStore(device)
|
|
m := newModel(vs)
|
|
|
|
for i := 0; i < epochs; i++ {
|
|
// input := ts.MustOfSlice(newData()).MustView([]int64{BatchSize, ImageDimNN}, true).MustTo(device, true)
|
|
input := ts.MustRandn([]int64{BatchSize, ImageDimNN}, gotch.Float, device)
|
|
|
|
ts.NoGrad(func() {
|
|
_ = m.Forward(input)
|
|
})
|
|
|
|
if i%10 == 0 {
|
|
fmt.Printf("=================== Epoch %03d completed========================\n", i)
|
|
}
|
|
}
|
|
|
|
ts.CleanUp()
|
|
}
|