diff --git a/example/mnist/README.md b/example/mnist/README.md
index 4502bef..5da7688 100644
--- a/example/mnist/README.md
+++ b/example/mnist/README.md
@@ -31,6 +31,236 @@
 - Accuracy should be about **99.3%**.
 
+## Benchmark against Python
+
+- Train batch size: 256
+- Test batch size: 1000
+- Adam optimizer, learning rate = 3*1e-4
+- Epochs: 30
+
+```python
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import time
+
+
+class Net(nn.Module):
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 5, 1)
+        self.conv2 = nn.Conv2d(32, 64, 5, 1)
+        self.fc1 = nn.Linear(1024, 1024)
+        self.dropout = nn.Dropout(0.5)
+        self.fc2 = nn.Linear(1024, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.max_pool2d(x, 2)
+        x = self.conv2(x)
+        x = F.max_pool2d(x, 2)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(
+                output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1,
+                keepdim=True)  # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--batch-size',
+                        type=int,
+                        default=256,
+                        metavar='N',
+                        help='input batch size for training (default: 256)')
+    parser.add_argument('--test-batch-size',
+                        type=int,
+                        default=1000,
+                        metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs',
+                        type=int,
+                        default=14,
+                        metavar='N',
+                        help='number of epochs to train (default: 14)')
+    parser.add_argument('--lr',
+                        type=float,
+                        default=1e-4,
+                        metavar='LR',
+                        help='learning rate (default: 1e-4)')
+    parser.add_argument('--no-cuda',
+                        action='store_true',
+                        default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed',
+                        type=int,
+                        default=1,
+                        metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--save-model',
+                        action='store_true',
+                        default=False,
+                        help='For Saving the current Model')
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    torch.manual_seed(args.seed)
+
+    if use_cuda:
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    train_kwargs = {'batch_size': args.batch_size}
+    test_kwargs = {'batch_size': args.test_batch_size}
+    if use_cuda:
+        cuda_kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True}
+        train_kwargs.update(cuda_kwargs)
+        test_kwargs.update(cuda_kwargs)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        # transforms.Normalize((0.1307, ), (0.3081, )),
+    ])
+    dataset1 = datasets.MNIST('../data',
+                              train=True,
+                              download=True,
+                              transform=transform)
+    dataset2 = datasets.MNIST('../data', train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    model = Net().to(device)
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+
+    start = time.time()
+
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(model, device, test_loader)
+
+    end = time.time()
+
+    print("taken time: {:.2f}mins".format((end - start) / 60.0))
+
+    if args.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+
+
+if __name__ == '__main__':
+    main()
+```
+
+```bash
+Test set: Average loss: 0.1101, Accuracy: 9666/10000 (96.66%)
+Test set: Average loss: 0.0697, Accuracy: 9779/10000 (97.79%)
+Test set: Average loss: 0.0442, Accuracy: 9856/10000 (98.56%)
+Test set: Average loss: 0.0384, Accuracy: 9873/10000 (98.73%)
+Test set: Average loss: 0.0358, Accuracy: 9875/10000 (98.75%)
+Test set: Average loss: 0.0323, Accuracy: 9898/10000 (98.98%)
+Test set: Average loss: 0.0290, Accuracy: 9906/10000 (99.06%)
+Test set: Average loss: 0.0272, Accuracy: 9910/10000 (99.10%)
+Test set: Average loss: 0.0280, Accuracy: 9913/10000 (99.13%)
+Test set: Average loss: 0.0295, Accuracy: 9908/10000 (99.08%)
+Test set: Average loss: 0.0251, Accuracy: 9919/10000 (99.19%)
+Test set: Average loss: 0.0246, Accuracy: 9924/10000 (99.24%)
+Test set: Average loss: 0.0258, Accuracy: 9921/10000 (99.21%)
+Test set: Average loss: 0.0296, Accuracy: 9911/10000 (99.11%)
+Test set: Average loss: 0.0271, Accuracy: 9912/10000 (99.12%)
+Test set: Average loss: 0.0251, Accuracy: 9918/10000 (99.18%)
+Test set: Average loss: 0.0276, Accuracy: 9916/10000 (99.16%)
+Test set: Average loss: 0.0291, Accuracy: 9912/10000 (99.12%)
+Test set: Average loss: 0.0291, Accuracy: 9920/10000 (99.20%)
+Test set: Average loss: 0.0333, Accuracy: 9904/10000 (99.04%)
+Test set: Average loss: 0.0268, Accuracy: 9919/10000 (99.19%)
+Test set: Average loss: 0.0265, Accuracy: 9931/10000 (99.31%)
+Test set: Average loss: 0.0316, Accuracy: 9918/10000 (99.18%)
+Test set: Average loss: 0.0299, Accuracy: 9917/10000 (99.17%)
+Test set: Average loss: 0.0303, Accuracy: 9923/10000 (99.23%)
+Test set: Average loss: 0.0327, Accuracy: 9914/10000 (99.14%)
+Test set: Average loss: 0.0314, Accuracy: 9918/10000 (99.18%)
+Test set: Average loss: 0.0316, Accuracy: 9920/10000 (99.20%)
+Test set: Average loss: 0.0346, Accuracy: 9916/10000 (99.16%)
+Test set: Average loss: 0.0308, Accuracy: 9923/10000 (99.23%)
+taken time: 5.63mins
+```
+
+Gotch CNN performance
+
+```bash
+testImages: [10000 784]
+testLabels: [10000]
+Epoch: 0    Loss: 0.16    Test accuracy: 96.53%
+Epoch: 1    Loss: 0.08    Test accuracy: 97.27%
+Epoch: 2    Loss: 0.14    Test accuracy: 97.28%
+Epoch: 3    Loss: 0.08    Test accuracy: 97.64%
+Epoch: 4    Loss: 0.07    Test accuracy: 98.44%
+Epoch: 5    Loss: 0.05    Test accuracy: 98.59%
+Epoch: 6    Loss: 0.06    Test accuracy: 98.67%
+Epoch: 7    Loss: 0.07    Test accuracy: 98.80%
+Epoch: 8    Loss: 0.11    Test accuracy: 98.01%
+Epoch: 9    Loss: 0.07    Test accuracy: 98.81%
+Epoch: 10   Loss: 0.05    Test accuracy: 98.76%
+Epoch: 11   Loss: 0.04    Test accuracy: 98.78%
+Epoch: 12   Loss: 0.02    Test accuracy: 98.81%
+Epoch: 13   Loss: 0.05    Test accuracy: 98.78%
+Epoch: 14   Loss: 0.05    Test accuracy: 98.74%
+Epoch: 15   Loss: 0.06    Test accuracy: 98.86%
+Epoch: 16   Loss: 0.07    Test accuracy: 98.95%
+Epoch: 17   Loss: 0.03    Test accuracy: 98.93%
+Epoch: 18   Loss: 0.04    Test accuracy: 98.99%
+Epoch: 19   Loss: 0.05    Test accuracy: 99.05%
+Epoch: 20   Loss: 0.06    Test accuracy: 99.11%
+Epoch: 21   Loss: 0.03    Test accuracy: 98.78%
+Epoch: 22   Loss: 0.05    Test accuracy: 98.88%
+Epoch: 23   Loss: 0.02    Test accuracy: 99.04%
+Epoch: 24   Loss: 0.04    Test accuracy: 99.08%
+Epoch: 25   Loss: 0.03    Test accuracy: 98.96%
+Epoch: 26   Loss: 0.07    Test accuracy: 98.78%
+Epoch: 27   Loss: 0.05    Test accuracy: 98.81%
+Epoch: 28   Loss: 0.03    Test accuracy: 98.79%
+Epoch: 29   Loss: 0.07    Test accuracy: 98.82%
+Best test accuracy: 99.11%
+Taken time: 2.81 mins
+```
diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go
index d812c84..d30af35 100644
--- a/example/mnist/cnn.go
+++ b/example/mnist/cnn.go
@@ -14,11 +14,11 @@ import (
 const (
     MnistDirCNN string = "../../data/mnist"
 
-    epochsCNN = 100
+    epochsCNN = 30
     batchCNN  = 256
     batchSize = 256
 
-    LrCNN = 1e-4
+    LrCNN = 3 * 1e-4
 )
 
 type Net struct {
@@ -84,6 +84,7 @@ func runCNN1() {
     net := newNet(vs.Root())
 
     opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
+    // opt, err := nn.DefaultSGDConfig().Build(vs, LrCNN)
     if err != nil {
         log.Fatal(err)
     }
@@ -132,7 +133,7 @@ func runCNN1() {
         }
 
         ts.NoGrad(func() {
-            testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
+            testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
             fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
             if testAccuracy > bestAccuracy {
                 bestAccuracy = testAccuracy
diff --git a/nn/conv.go b/nn/conv.go
index a930314..baeef66 100644
--- a/nn/conv.go
+++ b/nn/conv.go
@@ -4,6 +4,7 @@ package nn
 import (
     "fmt"
+    "math"
     "reflect"
 
     "github.com/sugarme/gotch/ts"
 )
@@ -76,14 +77,15 @@ func WithBsInit1D(val Init) Conv1DConfigOpt {
 
 // DefaultConvConfig create a default 1D ConvConfig
 func DefaultConv1DConfig() *Conv1DConfig {
+    negSlope := math.Sqrt(5)
     return &Conv1DConfig{
         Stride:   []int64{1},
         Padding:  []int64{0},
         Dilation: []int64{1},
         Groups:   1,
         Bias:     true,
-        WsInit:   NewKaimingUniformInit(),
-        BsInit:   NewConstInit(float64(0.0)),
+        WsInit:   NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
+        BsInit:   nil,
     }
 }
@@ -165,14 +167,15 @@ func WithBsInit2D(val Init) Conv2DConfigOpt {
 
 // DefaultConvConfig2D creates a default 2D ConvConfig
 func DefaultConv2DConfig() *Conv2DConfig {
+    negSlope := math.Sqrt(5)
     return &Conv2DConfig{
         Stride:   []int64{1, 1},
         Padding:  []int64{0, 0},
         Dilation: []int64{1, 1},
         Groups:   1,
         Bias:     true,
-        WsInit:   NewKaimingUniformInit(),
-        BsInit:   NewConstInit(float64(0.0)),
+        WsInit:   NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
+        BsInit:   nil,
     }
 }
@@ -254,14 +257,15 @@ func WithBsInit3D(val Init) Conv3DConfigOpt {
 
 // DefaultConvConfig3D creates a default 3D ConvConfig
 func DefaultConv3DConfig() *Conv3DConfig {
+    negSlope := math.Sqrt(5)
     return &Conv3DConfig{
         Stride:   []int64{1, 1, 1},
         Padding:  []int64{0, 0, 0},
         Dilation: []int64{1, 1, 1},
         Groups:   1,
         Bias:     true,
-        WsInit:   NewKaimingUniformInit(),
-        BsInit:   NewConstInit(float64(0.0)),
+        WsInit:   NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
+        BsInit:   nil,
     }
 }
@@ -288,12 +292,27 @@ func NewConv1D(vs *Path, inDim, outDim, k int64, cfg *Conv1DConfig) *Conv1D {
         ws *ts.Tensor
         bs *ts.Tensor = ts.NewTensor()
     )
-    if cfg.Bias {
-        bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
-    }
     weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
     weightSize = append(weightSize, k)
     ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
+    if cfg.Bias {
+        switch {
+        case cfg.BsInit == nil:
+            fanIn, _, err := CalculateFans(weightSize)
+            if err != nil {
+                err := fmt.Errorf("NewConv1D() initiate bias failed: %v", err)
+                panic(err)
+            }
+            bound := 0.0
+            if fanIn > 0 {
+                bound = 1 / math.Sqrt(float64(fanIn))
+            }
+            bsInit := NewUniformInit(-bound, bound)
+            bs = vs.MustNewVar("bias", []int64{outDim}, bsInit)
+        case cfg.BsInit != nil:
+            bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
+        }
+    }
 
     return &Conv1D{
         Ws: ws,
@@ -315,13 +334,29 @@ func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg *Conv2DConfig) *Conv2
         ws *ts.Tensor
         bs *ts.Tensor = ts.NewTensor()
     )
-    if cfg.Bias {
-        bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
-    }
     weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
     weightSize = append(weightSize, k, k)
     ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
+    if cfg.Bias {
+        switch {
+        case cfg.BsInit == nil:
+            fanIn, _, err := CalculateFans(weightSize)
+            if err != nil {
+                err := fmt.Errorf("NewConv2D() initiate bias failed: %v", err)
+                panic(err)
+            }
+            bound := 0.0
+            if fanIn > 0 {
+                bound = 1 / math.Sqrt(float64(fanIn))
+            }
+            bsInit := NewUniformInit(-bound, bound)
+            bs = vs.MustNewVar("bias", []int64{outDim}, bsInit)
+        case cfg.BsInit != nil:
+            bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
+        }
+    }
+
     return &Conv2D{
         Ws: ws,
         Bs: bs,
@@ -342,13 +377,29 @@ func NewConv3D(vs *Path, inDim, outDim, k int64, cfg *Conv3DConfig) *Conv3D {
         ws *ts.Tensor
         bs *ts.Tensor = ts.NewTensor()
     )
-    if cfg.Bias {
-        bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
-    }
     weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
     weightSize = append(weightSize, k, k, k)
     ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
+    if cfg.Bias {
+        switch {
+        case cfg.BsInit == nil:
+            fanIn, _, err := CalculateFans(weightSize)
+            if err != nil {
+                err := fmt.Errorf("NewConv3D() initiate bias failed: %v", err)
+                panic(err)
+            }
+            bound := 0.0
+            if fanIn > 0 {
+                bound = 1 / math.Sqrt(float64(fanIn))
+            }
+            bsInit := NewUniformInit(-bound, bound)
+            bs = vs.MustNewVar("bias", []int64{outDim}, bsInit)
+        case cfg.BsInit != nil:
+            bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
+        }
+    }
+
     return &Conv3D{
         Ws: ws,
         Bs: bs,
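The conv changes above stop zero-initialising the bias: when `cfg.BsInit` is nil, the bias bound is now derived from the weight's fan-in. A minimal standalone sketch (not part of the patch) of what that means for the first conv layer of the MNIST example, `Conv2d(1, 32, 5)`, using only the exported helper added in `nn/init.go` below:

```go
package main

import (
	"fmt"
	"math"

	"github.com/sugarme/gotch/nn"
)

func main() {
	// conv1 in the MNIST example is Conv2d(1, 32, 5): weight shape [32, 1, 5, 5].
	// CalculateFans (added in nn/init.go below) gives fanIn = 1*5*5 = 25, fanOut = 32*5*5 = 800.
	fanIn, fanOut, err := nn.CalculateFans([]int64{32, 1, 5, 5})
	if err != nil {
		panic(err)
	}

	// With cfg.BsInit == nil, NewConv2D now draws the bias from
	// U(-1/sqrt(fanIn), 1/sqrt(fanIn)) instead of initialising it to zero.
	bound := 1 / math.Sqrt(float64(fanIn))
	fmt.Println(fanIn, fanOut, bound) // 25 800 0.2
}
```

So with the default config the bias of that layer is drawn from U(-0.2, 0.2), which mirrors what PyTorch's `_ConvNd.reset_parameters` does.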
diff --git a/nn/init.go b/nn/init.go
index 8ea06cf..91197f7 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -1,8 +1,10 @@
 package nn
 
 import (
+    "fmt"
     "log"
     "math"
+    "strings"
 
     "github.com/sugarme/gotch"
     "github.com/sugarme/gotch/ts"
@@ -120,24 +122,88 @@ func (u uniformInit) Set(tensor *ts.Tensor) {
 
 // kaiminguniformInit :
 // ====================
-
-type kaimingUniformInit struct{}
-
-func NewKaimingUniformInit() kaimingUniformInit {
-    return kaimingUniformInit{}
+type KaimingOptions struct {
+    NegativeSlope float64
+    Mode          string
+    NonLinearity  string
 }
 
-func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
-    var fanIn int64
-    if len(dims) == 0 {
-        log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims)
-    } else if len(dims) == 1 {
-        fanIn = factorial(dims[0])
-    } else {
-        fanIn = product(dims[1:])
+type KaimingOption func(*KaimingOptions)
+
+func DefaultKaimingOptions() *KaimingOptions {
+    return &KaimingOptions{
+        NegativeSlope: 0.01,
+        Mode:          "fanIn",
+        NonLinearity:  "leaky_relu",
+    }
+}
+
+func WithKaimingMode(v string) KaimingOption {
+    if v != "fanIn" && v != "fanOut" {
+        panic("Mode must be either 'fanIn' or 'fanOut'.")
+    }
+    return func(opt *KaimingOptions) {
+        opt.Mode = v
+    }
+}
+
+func WithKaimingNonLinearity(v string) KaimingOption {
+    return func(opt *KaimingOptions) {
+        opt.NonLinearity = v
+    }
+}
+
+func WithKaimingNegativeSlope(v float64) KaimingOption {
+    return func(opt *KaimingOptions) {
+        opt.NegativeSlope = v
+    }
+}
+
+func NewKaimingOptions(opts ...KaimingOption) *KaimingOptions {
+    options := DefaultKaimingOptions()
+    for _, opt := range opts {
+        opt(options)
     }
-    bound := math.Sqrt(1.0 / float64(fanIn))
+    return options
+}
+
+type kaimingUniformInit struct {
+    NegativeSlope float64
+    Mode          string
+    NonLinearity  string
+}
+
+func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
+    o := DefaultKaimingOptions()
+    for _, opt := range opts {
+        opt(o)
+    }
+
+    return &kaimingUniformInit{
+        NegativeSlope: o.NegativeSlope,
+        Mode:          o.Mode,
+        NonLinearity:  o.NonLinearity,
+    }
+}
+
+func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+    fanIn, _, err := CalculateFans(dims)
+    if err != nil {
+        panic(err)
+    }
+
+    gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
+    if err != nil {
+        err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
+        panic(err)
+    }
+
+    std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
+
+    // Calculate uniform bounds from standard deviation
+    bound := math.Sqrt(3.0) * std
+
     kind := gotch.Float
     retVal = ts.MustZeros(dims, kind, device)
     retVal.Uniform_(-bound, bound)
@@ -172,16 +238,22 @@ func (k kaimingUniformInit) Set(tensor *ts.Tensor) {
         log.Fatalf("uniformInit - Set method call error: %v\n", err)
     }
 
-    var fanIn int64
-    if len(dims) == 0 {
-        log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length >= 1", tensor.MustSize())
-    } else if len(dims) == 1 {
-        fanIn = factorial(dims[0])
-    } else {
-        fanIn = product(dims[1:])
+    fanIn, _, err := CalculateFans(dims)
+    if err != nil {
+        panic(err)
     }
 
-    bound := math.Sqrt(1.0 / float64(fanIn))
+    gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
+    if err != nil {
+        err = fmt.Errorf("kaimingUniformInit.Set() failed: %v\n", err)
+        panic(err)
+    }
+
+    std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
+
+    // Calculate uniform bounds from standard deviation
+    bound := math.Sqrt(3.0) * std
+
     tensor.Uniform_(-bound, bound)
 }
 
@@ -202,3 +274,76 @@ func (gl glorotNInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.
 func (gl glorotNInit) Set(tensor *ts.Tensor) {
     // TODO: implement
 }
+
+// KaimingUniform:
+// ===============
+// Base on Pytorch:
+// https://github.com/pytorch/pytorch/blob/98f40af7e3133e042454efab668a842c4d01176e/torch/nn/init.py#L284
+func calculateFan(shape []int64) (fan map[string]int64, err error) {
+    if len(shape) < 2 {
+        err = fmt.Errorf("calculateFan() failed: fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
+        return
+    }
+
+    fan = make(map[string]int64)
+
+    numInputFmap := shape[1]
+    numOutputFmap := shape[0]
+    var receptiveFieldSize int64 = 1
+    if len(shape) > 2 {
+        // calculate product
+        for _, s := range shape[2:] {
+            receptiveFieldSize *= int64(s)
+        }
+    }
+
+    fan["fanIn"] = numInputFmap * receptiveFieldSize
+    fan["fanOut"] = numOutputFmap * receptiveFieldSize
+
+    return fan, nil
+}
+
+// CalculateFans calculates fan-in and fan-out based on tensor shape.
+func CalculateFans(shape []int64) (fanIn, fanOut int64, err error) {
+    fan, err := calculateFan(shape)
+    return fan["fanIn"], fan["fanOut"], err
+}
+
+// Return the recommended gain value for the given nonlinearity function.
+// Default fn should be `leaky_relu`
+func calculateGain(fn string, paramOpt ...float64) (float64, error) {
+    linearFns := []string{"linear", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"}
+
+    negativeSlope := 0.01
+    if len(paramOpt) > 0 {
+        negativeSlope = paramOpt[0]
+    }
+
+    fn = strings.ToLower(fn)
+    if contains(linearFns, fn) || fn == "sigmoid" {
+        return 1, nil
+    }
+
+    switch fn {
+    case "tanh":
+        return 5.0 / 3.0, nil
+    case "relu":
+        return math.Sqrt(2.0), nil
+    case "leaky_relu": // default fn
+        return math.Sqrt(2.0 / (1 + math.Pow(negativeSlope, 2))), nil
+    case "selu":
+        return 3.0 / 4, nil // Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
+    default:
+        err := fmt.Errorf("calculateGain() failed: unsupported non-linearity function %q\n", fn)
+        return -1, err
+    }
+}
+
+func contains(items []string, item string) bool {
+    for _, i := range items {
+        if item == i {
+            return true
+        }
+    }
+    return false
+}
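A note on the `math.Sqrt(5)` negative slope used by the new default conv and linear configs: for `leaky_relu`, `calculateGain` returns sqrt(2 / (1 + a^2)), so a = sqrt(5) gives gain = 1/sqrt(3), and the uniform bound sqrt(3) * gain / sqrt(fanIn) collapses to 1/sqrt(fanIn), the same bound PyTorch's default `kaiming_uniform_(w, a=math.sqrt(5))` produces for conv and linear weights. A small self-contained check of that arithmetic (plain Go, mirroring the formulas above; the [64, 32, 5, 5] shape is the example's second conv weight):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	a := math.Sqrt(5)                             // WithKaimingNegativeSlope(negSlope) in the default configs
	gain := math.Sqrt(2.0 / (1 + math.Pow(a, 2))) // calculateGain("leaky_relu", a) = 1/sqrt(3)

	fanIn := 32.0 * 5 * 5 // fan-in of conv2's weight [64, 32, 5, 5]
	std := gain / math.Sqrt(fanIn)
	bound := math.Sqrt(3.0) * std

	// Both values print as 0.035355: the sqrt(5) slope reproduces the
	// classic 1/sqrt(fanIn) weight bound.
	fmt.Printf("bound=%.6f  1/sqrt(fanIn)=%.6f\n", bound, 1/math.Sqrt(fanIn))
}
```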
diff --git a/nn/linear.go b/nn/linear.go
index 28452b6..9551536 100644
--- a/nn/linear.go
+++ b/nn/linear.go
@@ -3,6 +3,7 @@ package nn
 // linear is a fully-connected layer
 
 import (
+    "fmt"
     "math"
 
     "github.com/sugarme/gotch"
@@ -19,8 +20,9 @@ type LinearConfig struct {
 // DefaultLinearConfig creates default LinearConfig with
 // weights initiated using KaimingUniform and Bias is set to true
 func DefaultLinearConfig() *LinearConfig {
+    negSlope := math.Sqrt(5)
     return &LinearConfig{
-        WsInit: NewKaimingUniformInit(),
+        WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
         BsInit: nil,
         Bias:   true,
     }
@@ -38,7 +40,6 @@ type Linear struct {
 // outDim - output dimension (y) [output features - columns]
 // NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
 func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
-
     var bs *ts.Tensor
     // bs has size of output dimension
     switch c.Bias {
@@ -47,7 +48,16 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
     case true:
         switch {
         case c.BsInit == nil:
-            bound := 1.0 / math.Sqrt(float64(inDim))
+            shape := []int64{inDim, outDim}
+            fanIn, _, err := CalculateFans(shape)
+            if err != nil {
+                err := fmt.Errorf("NewLinear() initiate bias failed: %v", err)
+                panic(err)
+            }
+            bound := 0.0
+            if fanIn > 0 {
+                bound = 1 / math.Sqrt(float64(fanIn))
+            }
             bsInit := NewUniformInit(-bound, bound)
             bs = vs.MustNewVar("bias", []int64{outDim}, bsInit)
         case c.BsInit != nil:
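Beyond restoring PyTorch-style defaults, the functional options make the initialisation configurable per layer. A hypothetical usage sketch: every identifier except the `VarStore` setup is introduced by this patch, while `nn.NewVarStore` and `gotch.CPU` follow the usual gotch pattern from example/mnist/cnn.go and are assumptions here, not part of the change.

```go
package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)

	// Default config: Kaiming-uniform weights with negative slope sqrt(5)
	// and, since BsInit == nil, a fan-in based uniform bias.
	conv := nn.NewConv2D(vs.Root(), 1, 32, 5, nn.DefaultConv2DConfig())

	// Overriding the weight init, e.g. for a plain ReLU non-linearity.
	cfg := nn.DefaultLinearConfig()
	cfg.WsInit = nn.NewKaimingUniformInit(
		nn.WithKaimingNonLinearity("relu"),
		nn.WithKaimingMode("fanIn"),
	)
	fc := nn.NewLinear(vs.Root(), 1024, 10, cfg)

	_, _ = conv, fc
}
```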