2023-09-26 20:15:28 +01:00
package models_train
import (
"errors"
"fmt"
2023-09-27 21:20:39 +01:00
"io"
2023-10-19 11:42:38 +01:00
"math"
2023-09-27 21:20:39 +01:00
"os"
"os/exec"
"path"
2024-04-19 15:39:51 +01:00
"runtime/debug"
2023-10-22 23:02:39 +01:00
"sort"
2023-09-27 21:20:39 +01:00
"strconv"
2024-03-09 10:52:08 +00:00
"strings"
2023-09-27 21:20:39 +01:00
"text/template"
2023-09-26 20:15:28 +01:00
2024-04-17 17:46:43 +01:00
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
2024-04-14 14:51:16 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
2024-04-22 00:09:07 +01:00
2024-04-19 15:39:51 +01:00
my_torch "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
modelloader "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/modelloader"
2024-04-22 00:09:07 +01:00
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
2024-04-15 23:04:53 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
2023-09-26 20:15:28 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
2024-04-15 23:04:53 +01:00
2024-04-22 00:09:07 +01:00
"git.andr3h3nriqu3s.com/andr3/gotch"
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
2024-02-02 16:16:26 +00:00
"github.com/charmbracelet/log"
2024-04-15 23:04:53 +01:00
"github.com/goccy/go-json"
2023-09-26 20:15:28 +01:00
)
2023-10-25 14:50:58 +01:00
// EPOCH_PER_RUN is the number of epochs a single training call runs before
// returning; it bounds trainDefinition's epoch loop and is passed to the
// Python training templates.
const (
	EPOCH_PER_RUN = 20
)
2023-10-19 10:44:13 +01:00
// MAX_EPOCH bounds training: trainModel marks a definition as failed once its
// epoch counter exceeds this value.
const (
	MAX_EPOCH = 100
)
2024-03-09 10:52:08 +00:00
// shapeToSize drops the last comma-separated component of a layer shape
// string, e.g. "28,28,1" -> "28,28". A shape without any comma yields "".
func shapeToSize(shape string) string {
	cut := strings.LastIndex(shape, ",")
	if cut < 0 {
		return ""
	}
	return shape[:cut]
}
// getDir returns the process working directory, which callers join storage
// paths onto. It panics if the working directory cannot be determined.
func getDir() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}
	panic(err)
}
2024-04-17 17:46:43 +01:00
func MakeLayer ( db db . Db , def_id string , layer_order int , layer_type LayerType , shape string ) ( err error ) {
2023-09-26 20:15:28 +01:00
_ , err = db . Exec ( "insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)" , def_id , layer_order , layer_type , shape )
return
}
2024-04-17 17:46:43 +01:00
func MakeLayerExpandable ( db db . Db , def_id string , layer_order int , layer_type LayerType , shape string , exp_type int ) ( err error ) {
2024-01-31 21:48:35 +00:00
_ , err = db . Exec ( "insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)" , def_id , layer_order , layer_type , shape , exp_type )
return
}
2024-04-15 23:04:53 +01:00
func setModelClassStatus ( c BasePack , status ModelClassStatus , filter string , args ... any ) ( err error ) {
_ , err = c . GetDb ( ) . Exec ( fmt . Sprintf ( "update model_classes set status=%d where %s" , status , filter ) , args ... )
2024-03-06 23:33:54 +00:00
return
}
2024-04-15 23:04:53 +01:00
func generateCvsExp ( c BasePack , run_path string , model_id string , doPanic bool ) ( count int , err error ) {
db := c . GetDb ( )
2024-03-06 23:33:54 +00:00
2024-04-08 14:17:13 +01:00
var co struct {
Count int ` db:"count(*)" `
2024-03-06 23:33:54 +00:00
}
2024-04-19 15:39:51 +01:00
err = GetDBOnce ( db , & co , "model_classes where model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-03-06 23:33:54 +00:00
return
}
2024-04-08 14:17:13 +01:00
count = co . Count
2024-03-06 23:33:54 +00:00
if count == 0 {
2024-04-19 15:39:51 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINING , "model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TO_TRAIN )
2024-03-06 23:33:54 +00:00
if err != nil {
return
}
if doPanic {
return 0 , errors . New ( "No model classes available" )
}
return generateCvsExp ( c , run_path , model_id , true )
}
2024-04-19 15:39:51 +01:00
data , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;" , model_id , DATA_POINT_MODE_TRAINING , CLASS_STATUS_TRAINING )
2024-03-06 23:33:54 +00:00
if err != nil {
return
}
defer data . Close ( )
f , err := os . Create ( path . Join ( run_path , "train.csv" ) )
if err != nil {
return
}
defer f . Close ( )
f . Write ( [ ] byte ( "Id,Index\n" ) )
for data . Next ( ) {
var id string
var class_order int
var file_path string
if err = data . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( class_order ) + "\n" ) )
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
}
return
}
2024-04-19 15:39:51 +01:00
// trainDefinition runs one training round of EPOCH_PER_RUN epochs for
// definition def of model m using the in-process torch bindings (gotch), then
// saves the weights to savedData/<m.Id>/defs/<def.Id>/model.dat under the
// working directory.
//
// in_model: the network returned by a previous call, or nil to build a new one
// from def's layers (ordered by layer_order).
// classes: must be non-empty; the dataset is bounded by the ClassOrder of its
// first and last entries.
//
// Returns the accuracy, the model that was trained (trainModel stores it and
// passes it back in on the next round), and any error.
//
// NOTE(review): this function is mid-debugging and cannot train in its current
// state. The batch loop feeds an all-ones tensor instead of the real data,
// dumps gradients, and then unconditionally calls panic("fornow") after the
// first batch. The evaluation / def.UpdateAfterEpoch code is commented out, so
// accuracy is never assigned anything other than 0.
func trainDefinition ( c BasePack , m * BaseModel , def * Definition , in_model * my_torch . ContainerModel , classes [ ] * ModelClass ) ( accuracy float64 , model * my_torch . ContainerModel , err error ) {
log := c . GetLogger ( )
2024-04-15 23:04:53 +01:00
db := c . GetDb ( )
2024-04-19 15:39:51 +01:00
log . Warn ( "About to start training definition" )
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
model = in_model
accuracy = 0
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
// First round for this definition: build the network from its stored layers.
if model == nil {
var layers [ ] * Layer
layers , err = def . GetLayers ( db , " order by layer_order asc" )
if err != nil {
2023-09-27 21:20:39 +01:00
return
}
2024-04-19 15:39:51 +01:00
model = my_torch . BuildModel ( layers , 0 , true )
2024-04-22 00:09:07 +01:00
2023-09-27 21:20:39 +01:00
}
2024-04-19 15:39:51 +01:00
// TODO Make the runner provide this
2024-04-22 00:09:07 +01:00
device := gotch . CudaIfAvailable ( )
// device := gotch.CPU
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
result_path := path . Join ( getDir ( ) , "savedData" , m . Id , "defs" , def . Id )
err = os . MkdirAll ( result_path , os . ModePerm )
2023-09-27 21:20:39 +01:00
if err != nil {
return
}
2024-04-19 15:39:51 +01:00
// Train on the selected device; the deferred call moves the model back to the
// CPU on every return path.
model . To ( device )
defer model . To ( gotch . CPU )
// Indexes classes[0] and classes[len-1]: panics on an empty slice. Assumes
// classes is sorted by class_order (trainModel queries them ascending) - TODO confirm.
var ds * modelloader . Dataset
ds , err = modelloader . NewDataset ( db , m , classes [ 0 ] . ClassOrder , classes [ len ( classes ) - 1 ] . ClassOrder )
2023-10-10 12:28:49 +01:00
if err != nil {
return
}
2023-10-02 21:15:31 +01:00
2024-04-22 00:09:07 +01:00
// Adam over every variable in model.Vs; 0.001 is presumably the learning rate.
opt , err := my_nn . DefaultAdamConfig ( ) . Build ( model . Vs , 0.001 )
2023-09-27 21:20:39 +01:00
if err != nil {
return
}
2024-04-19 15:39:51 +01:00
for epoch := 0 ; epoch < EPOCH_PER_RUN ; epoch ++ {
2024-04-22 00:09:07 +01:00
// A fresh training iterator is created every epoch (32 is presumably the batch size).
var trainIter * torch . Iter2
trainIter , err = ds . TrainIter ( 32 )
if err != nil {
return
}
// trainIter.ToDevice(device)
2023-10-19 10:44:13 +01:00
2024-04-19 15:39:51 +01:00
log . Info ( "epoch" , "epoch" , epoch )
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
// NOTE(review): trainLoss is overwritten every batch (not accumulated), so the
// value logged per epoch is the last batch's loss. trainCorrect is never
// incremented because the accuracy code below is commented out.
var trainLoss float64 = 0
var trainCorrect float64 = 0
ok := true
for ok {
var item torch . Iter2Item
var loss * torch . Tensor
item , ok = trainIter . Next ( )
if ! ok {
continue
}
2023-09-27 21:20:39 +01:00
2024-04-22 00:09:07 +01:00
data := item . Data
2024-04-23 00:14:35 +01:00
data , err = data . ToDevice ( device , gotch . Float , true , true , false )
2024-04-22 00:09:07 +01:00
if err != nil {
return
}
2023-09-27 21:20:39 +01:00
2024-04-22 00:09:07 +01:00
// NOTE(review): the SetRequiresGrad/RetainGrad calls on data, ones, pred,
// label and loss appear to exist only to inspect gradient flow while debugging.
data , err = data . SetRequiresGrad ( true , true )
if err != nil {
return
}
err = data . RetainGrad ( false )
if err != nil {
return
}
2023-10-19 10:44:13 +01:00
2024-04-23 00:14:35 +01:00
// NOTE(review): debugging scaffolding - an all-ones tensor shaped like the
// batch is built and fed to the model INSTEAD of the real data (the real
// forward call is commented out below). Training on this learns nothing.
var size [ ] int64
size , err = data . Size ( )
if err != nil {
return
}
var ones * torch . Tensor
ones , err = torch . Ones ( size , gotch . Float , device )
if err != nil {
return
}
ones , err = ones . SetRequiresGrad ( true , true )
if err != nil {
return
}
err = ones . RetainGrad ( false )
if err != nil {
return
}
//pred := model.ForwardT(data, true)
pred := model . ForwardT ( ones , true )
2024-04-22 00:09:07 +01:00
pred , err = pred . SetRequiresGrad ( true , true )
2024-04-19 15:39:51 +01:00
if err != nil {
return
}
2023-09-27 21:20:39 +01:00
2024-04-23 00:14:35 +01:00
err = pred . RetainGrad ( false )
if err != nil {
return
}
2024-04-22 00:09:07 +01:00
label := item . Label
label , err = label . ToDevice ( device , gotch . Float , false , true , false )
if err != nil {
return
}
// NOTE(review): labels are targets; marking them as requiring grad looks like
// gradient debugging - confirm it is intended.
label , err = label . SetRequiresGrad ( true , true )
if err != nil {
return
}
2024-04-23 00:14:35 +01:00
err = label . RetainGrad ( false )
if err != nil {
return
}
2024-04-22 00:09:07 +01:00
// Calculate loss
// Binary cross-entropy on logits with empty weight / pos_weight tensors. The
// reduction argument 2 presumably maps to libtorch's Reduction::Sum - TODO confirm.
2024-04-23 00:14:35 +01:00
loss , err = pred . BinaryCrossEntropyWithLogits ( label , & torch . Tensor { } , & torch . Tensor { } , 2 , false )
2024-04-22 00:09:07 +01:00
if err != nil {
return
}
2024-04-19 15:39:51 +01:00
loss , err = loss . SetRequiresGrad ( true , false )
2024-04-22 00:09:07 +01:00
if err != nil {
return
}
2024-04-23 00:14:35 +01:00
err = loss . RetainGrad ( false )
if err != nil {
return
}
2024-04-19 15:39:51 +01:00
err = opt . ZeroGrad ( )
if err != nil {
return
}
err = loss . Backward ( )
if err != nil {
return
}
2024-04-23 00:14:35 +01:00
// NOTE(review): gradient dump for debugging. gotch's Must* helpers presumably
// panic on error (e.g. a tensor without a defined gradient) - remove before
// this path is used for real training.
log . Info ( "pred grad" , "pred" , pred . MustGrad ( false ) . MustMax ( false ) . Float64Values ( ) )
log . Info ( "pred grad" , "ones" , ones . MustGrad ( false ) . MustMax ( false ) . Float64Values ( ) , "lol" , ones . MustRetainsGrad ( false ) )
log . Info ( "pred grad" , "data" , data . MustGrad ( false ) . MustMax ( false ) . Float64Values ( ) , "lol" , data . MustRetainsGrad ( false ) )
log . Info ( "pred grad" , "outs" , label . MustGrad ( false ) . MustMax ( false ) . Float64Values ( ) )
2024-04-19 15:39:51 +01:00
2024-04-22 00:09:07 +01:00
vars := model . Vs . Variables ( )
for k , v := range vars {
2024-04-23 00:14:35 +01:00
log . Info ( "[grad check]" , "k" , k , "grad" , v . MustGrad ( false ) . MustMax ( false ) . Float64Values ( ) , "lol" , v . MustRetainsGrad ( false ) )
}
2024-04-22 00:09:07 +01:00
2024-04-23 00:14:35 +01:00
// NOTE(review): debug output from the container model.
model . Debug ( )
2024-04-22 00:09:07 +01:00
2024-04-23 00:14:35 +01:00
err = opt . Step ( )
if err != nil {
return
2024-04-22 00:09:07 +01:00
}
2024-04-19 15:39:51 +01:00
trainLoss = loss . Float64Values ( ) [ 0 ]
// Calculate accuracy
// (training-accuracy computation below is currently commented out)
2024-04-22 00:09:07 +01:00
/ * var p_pred , p_labels * torch . Tensor
2024-04-19 15:39:51 +01:00
p_pred , err = pred . Argmax ( [ ] int64 { 1 } , true , false )
if err != nil {
return
}
p_labels , err = item . Label . Argmax ( [ ] int64 { 1 } , true , false )
if err != nil {
return
}
floats := p_pred . Float64Values ( )
floats_labels := p_labels . Float64Values ( )
for i := range floats {
if floats [ i ] == floats_labels [ i ] {
trainCorrect += 1
}
2024-04-22 00:09:07 +01:00
} * /
2024-04-23 00:14:35 +01:00
// NOTE(review): BUG - this unconditional panic aborts the process after the
// first batch, so no epoch ever completes while the iterator yields data. It
// is debugging scaffolding and must be removed (together with the all-ones
// input above) before this path can train.
panic ( "fornow" )
2024-04-19 15:39:51 +01:00
}
2024-04-22 00:09:07 +01:00
//v := []float64{}
2024-04-19 15:39:51 +01:00
log . Info ( "model training epoch done loss" , "loss" , trainLoss , "correct" , trainCorrect , "out" , ds . TrainImagesSize , "accuracy" , trainCorrect / float64 ( ds . TrainImagesSize ) )
// Evaluation on the test split and the per-epoch def.UpdateAfterEpoch call are
// commented out below; as a result `accuracy` is never updated from 0.
/ * correct := int64 ( 0 )
//torch.NoGrad(func() {
ok = true
testIter := ds . TestIter ( 64 )
for ok {
var item torch . Iter2Item
item , ok = testIter . Next ( )
if ! ok {
continue
}
output := model . Forward ( item . Data )
var pred , labels * torch . Tensor
pred , err = output . Argmax ( [ ] int64 { 1 } , true , false )
if err != nil {
return
}
labels , err = item . Label . Argmax ( [ ] int64 { 1 } , true , false )
if err != nil {
return
}
floats := pred . Float64Values ( )
floats_labels := labels . Float64Values ( )
for i := range floats {
if floats [ i ] == floats_labels [ i ] {
correct += 1
}
}
}
accuracy = float64 ( correct ) / float64 ( ds . TestImagesSize )
log . Info ( "Eval accuracy" , "accuracy" , accuracy )
err = def . UpdateAfterEpoch ( db , accuracy * 100 )
if err != nil {
return
} * /
//})
2023-09-27 21:20:39 +01:00
}
2024-04-19 15:39:51 +01:00
// Persist the weights to <result_path>/model.dat.
err = my_torch . SaveModel ( model , path . Join ( result_path , "model.dat" ) )
2023-09-27 21:20:39 +01:00
if err != nil {
return
}
2023-10-10 12:28:49 +01:00
2024-04-19 15:39:51 +01:00
log . Info ( "Model finished training!" , "accuracy" , accuracy )
2023-09-27 21:20:39 +01:00
return
}
2024-04-16 17:48:52 +01:00
func generateCvsExpandExp ( c BasePack , run_path string , model_id string , offset int , doPanic bool ) ( count_re int , err error ) {
l , db := c . GetLogger ( ) , c . GetDb ( )
2024-04-08 14:17:13 +01:00
var co struct {
Count int ` db:"count(*)" `
}
2024-04-19 15:39:51 +01:00
err = GetDBOnce ( db , & co , "model_classes where model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TRAINING )
2024-03-06 23:33:54 +00:00
if err != nil {
2024-04-08 14:17:13 +01:00
return
2024-03-06 23:33:54 +00:00
}
2024-04-16 17:48:52 +01:00
l . Info ( "test here" , "count" , co )
2024-04-08 14:17:13 +01:00
count_re = co . Count
2024-04-08 15:47:31 +01:00
count := co . Count
2024-04-08 14:17:13 +01:00
if count == 0 {
2024-04-19 15:39:51 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINING , "model_id=$1 and status=$2;" , model_id , CLASS_STATUS_TO_TRAIN )
2024-04-08 14:17:13 +01:00
if err != nil {
return
} else if doPanic {
return 0 , errors . New ( "No model classes available" )
}
return generateCvsExpandExp ( c , run_path , model_id , offset , true )
}
2024-04-19 15:39:51 +01:00
data , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;" , model_id , DATA_POINT_MODE_TRAINING , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
return
}
defer data . Close ( )
f , err := os . Create ( path . Join ( run_path , "train.csv" ) )
if err != nil {
return
}
defer f . Close ( )
f . Write ( [ ] byte ( "Id,Index\n" ) )
count = 0
for data . Next ( ) {
var id string
var class_order int
var file_path string
if err = data . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( class_order - offset ) + "\n" ) )
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
count += 1
}
//
// This is to load some extra data so that the model has more things to train on
//
2024-04-19 15:39:51 +01:00
data_other , err := db . Query ( "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;" , model_id , DATA_POINT_MODE_TRAINING , CLASS_STATUS_TRAINED , count * 10 )
2024-04-08 14:17:13 +01:00
if err != nil {
return
}
defer data_other . Close ( )
for data_other . Next ( ) {
var id string
var class_order int
var file_path string
if err = data_other . Scan ( & id , & class_order , & file_path ) ; err != nil {
return
}
if file_path == "id://" {
2024-04-14 16:51:15 +01:00
f . Write ( [ ] byte ( id + "," + strconv . Itoa ( - 2 ) + "\n" ) )
2024-04-08 14:17:13 +01:00
} else {
return count , errors . New ( "TODO generateCvs to file_path " + file_path )
}
}
return
2024-03-06 23:33:54 +00:00
}
2024-04-16 17:48:52 +01:00
func trainDefinitionExpandExp ( c BasePack , model * BaseModel , definition_id string , load_prev bool ) ( accuracy float64 , err error ) {
2024-02-02 16:16:26 +00:00
accuracy = 0
2024-04-16 17:48:52 +01:00
l := c . GetLogger ( )
l . Warn ( "About to retrain model" )
2024-02-02 16:16:26 +00:00
// Get untrained models heads
2024-04-08 14:17:13 +01:00
type ExpHead struct {
Id string
Start int ` db:"range_start" `
End int ` db:"range_end" `
}
// status = 2 (INIT) 3 (TRAINING)
2024-04-16 17:48:52 +01:00
heads , err := GetDbMultitple [ ExpHead ] ( c . GetDb ( ) , "exp_model_head where def_id=$1 and (status = 2 or status = 3)" , definition_id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
2024-04-08 14:17:13 +01:00
} else if len ( heads ) == 0 {
log . Error ( "Failed to get the exp head of the model" )
return
} else if len ( heads ) != 1 {
err = errors . New ( "This training function can only train one model at the time" )
return
2024-02-02 16:16:26 +00:00
}
2024-04-08 14:17:13 +01:00
exp := heads [ 0 ]
2024-04-16 17:48:52 +01:00
l . Info ( "Got exp head" , "head" , exp )
2024-04-08 14:17:13 +01:00
2024-04-19 15:39:51 +01:00
if err = UpdateStatus ( c . GetDb ( ) , "exp_model_head" , exp . Id , DEFINITION_STATUS_TRAINING ) ; err != nil {
2024-04-08 14:17:13 +01:00
return
2024-02-02 16:16:26 +00:00
}
2024-04-16 17:48:52 +01:00
layers , err := c . GetDb ( ) . Query ( "select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
2024-04-08 14:17:13 +01:00
if err != nil {
return
}
defer layers . Close ( )
2024-02-02 16:16:26 +00:00
2024-04-08 14:17:13 +01:00
type layerrow struct {
LayerType int
Shape string
ExpType int
LayerNum int
}
got := [ ] layerrow { }
i := 1
var last * layerrow = nil
got_2 := false
2024-04-08 15:47:31 +01:00
var first * layerrow = nil
2024-04-08 14:17:13 +01:00
for layers . Next ( ) {
var row = layerrow { }
if err = layers . Scan ( & row . LayerType , & row . Shape , & row . ExpType ) ; err != nil {
2024-02-03 12:39:22 +00:00
return
}
2024-04-08 14:17:13 +01:00
2024-04-08 15:47:31 +01:00
// Keep track of the first layer so we can keep the size of the image
if first == nil {
first = & row
}
2024-04-08 14:17:13 +01:00
row . LayerNum = i
row . Shape = shapeToSize ( row . Shape )
if row . ExpType == 2 {
if ! got_2 {
got = append ( got , * last )
got_2 = true
}
got = append ( got , row )
}
last = & row
i += 1
}
got = append ( got , layerrow {
LayerType : LAYER_DENSE ,
Shape : fmt . Sprintf ( "%d" , exp . End - exp . Start + 1 ) ,
ExpType : 2 ,
LayerNum : i ,
} )
2024-04-16 17:48:52 +01:00
l . Info ( "Got layers" , "layers" , got )
2024-04-08 14:17:13 +01:00
// Generate run folder
run_path := path . Join ( "/tmp" , model . Id + "-defs-" + definition_id + "-retrain" )
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
classCount , err := generateCvsExpandExp ( c , run_path , model . Id , exp . Start , false )
if err != nil {
return
}
2024-04-16 17:48:52 +01:00
l . Info ( "Generated cvs" , "classCount" , classCount )
2024-04-08 14:17:13 +01:00
// TODO update the run script
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
2024-04-16 17:48:52 +01:00
l . Info ( "About to run python!" )
2024-04-08 14:17:13 +01:00
tmpl , err := template . New ( "python_model_template_expand.py" ) . ParseFiles ( "views/py/python_model_template_expand.py" )
if err != nil {
return
}
// Copy result around
result_path := path . Join ( "savedData" , model . Id , "defs" , definition_id )
if err = tmpl . Execute ( f , AnyMap {
"Layers" : got ,
"Size" : first . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
"HeadId" : exp . Id ,
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
"LoadPrev" : load_prev ,
"BaseModel" : path . Join ( getDir ( ) , result_path , "base" , "model.keras" ) ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "head" , exp . Id , "model.keras" ) ,
"SaveModelPath" : path . Join ( getDir ( ) , result_path , "head" , exp . Id ) ,
"Depth" : classCount ,
"StartPoint" : 0 ,
2024-04-16 17:48:52 +01:00
"Host" : c . GetHost ( ) ,
2024-04-08 14:17:13 +01:00
} ) ; err != nil {
return
}
// Run the command
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
2024-04-16 17:48:52 +01:00
l . Warn ( "Python failed to run" , "err" , err , "out" , string ( out ) )
2024-04-08 14:17:13 +01:00
return
}
2024-04-16 17:48:52 +01:00
l . Info ( "Python finished running" )
2024-04-08 14:17:13 +01:00
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
return
}
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
2024-02-03 12:39:22 +00:00
return
}
2024-04-15 23:04:53 +01:00
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-16 17:48:52 +01:00
l . Info ( "Model finished training!" , "accuracy" , accuracy )
2024-04-08 14:17:13 +01:00
return
}
2024-04-15 23:04:53 +01:00
func trainDefinitionExp ( c BasePack , model * BaseModel , definition_id string , load_prev bool ) ( accuracy float64 , err error ) {
2024-04-08 14:17:13 +01:00
accuracy = 0
2024-04-15 23:04:53 +01:00
l := c . GetLogger ( )
db := c . GetDb ( )
2024-04-08 14:17:13 +01:00
2024-04-15 23:04:53 +01:00
l . Warn ( "About to start training definition" )
2024-04-08 14:17:13 +01:00
// Get untrained models heads
type ExpHead struct {
Id string
Start int ` db:"range_start" `
End int ` db:"range_end" `
}
// status = 2 (INIT) 3 (TRAINING)
2024-04-15 23:04:53 +01:00
heads , err := GetDbMultitple [ ExpHead ] ( db , "exp_model_head where def_id=$1 and (status = 2 or status = 3)" , definition_id )
2024-04-08 14:17:13 +01:00
if err != nil {
return
} else if len ( heads ) == 0 {
log . Error ( "Failed to get the exp head of the model" )
return
} else if len ( heads ) != 1 {
2024-02-03 12:39:22 +00:00
log . Error ( "This training function can only train one model at the time" )
err = errors . New ( "This training function can only train one model at the time" )
return
}
2024-02-02 16:16:26 +00:00
2024-04-08 14:17:13 +01:00
exp := heads [ 0 ]
2024-04-19 15:39:51 +01:00
if err = UpdateStatus ( db , "exp_model_head" , exp . Id , DEFINITION_STATUS_TRAINING ) ; err != nil {
2024-04-08 14:17:13 +01:00
return
}
2024-02-05 16:42:23 +00:00
2024-04-15 23:04:53 +01:00
layers , err := db . Query ( "select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;" , definition_id )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
defer layers . Close ( )
type layerrow struct {
LayerType int
Shape string
2024-02-03 12:39:22 +00:00
ExpType int
LayerNum int
2024-02-02 16:16:26 +00:00
}
got := [ ] layerrow { }
2024-02-03 12:39:22 +00:00
i := 1
2024-02-02 16:16:26 +00:00
for layers . Next ( ) {
var row = layerrow { }
2024-02-03 12:39:22 +00:00
if err = layers . Scan ( & row . LayerType , & row . Shape , & row . ExpType ) ; err != nil {
2024-02-02 16:16:26 +00:00
return
}
2024-02-03 12:39:22 +00:00
row . LayerNum = i
2024-02-02 16:16:26 +00:00
row . Shape = shapeToSize ( row . Shape )
got = append ( got , row )
2024-02-03 12:39:22 +00:00
i += 1
2024-02-02 16:16:26 +00:00
}
2024-02-03 12:39:22 +00:00
got = append ( got , layerrow {
LayerType : LAYER_DENSE ,
2024-04-08 14:17:13 +01:00
Shape : fmt . Sprintf ( "%d" , exp . End - exp . Start + 1 ) ,
2024-02-03 12:39:22 +00:00
ExpType : 2 ,
LayerNum : i ,
} )
2024-02-02 16:16:26 +00:00
// Generate run folder
2024-04-08 14:17:13 +01:00
run_path := path . Join ( "/tmp" , model . Id + "-defs-" + definition_id )
2024-02-02 16:16:26 +00:00
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
2024-03-06 23:33:54 +00:00
classCount , err := generateCvsExp ( c , run_path , model . Id , false )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
2024-02-03 12:39:22 +00:00
// TODO update the run script
2024-02-02 16:16:26 +00:00
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
2024-02-05 16:42:23 +00:00
tmpl , err := template . New ( "python_model_template.py" ) . ParseFiles ( "views/py/python_model_template.py" )
2024-02-02 16:16:26 +00:00
if err != nil {
return
}
// Copy result around
result_path := path . Join ( "savedData" , model . Id , "defs" , definition_id )
if err = tmpl . Execute ( f , AnyMap {
"Layers" : got ,
"Size" : got [ 0 ] . Shape ,
"DataDir" : path . Join ( getDir ( ) , "savedData" , model . Id , "data" ) ,
2024-04-08 14:17:13 +01:00
"HeadId" : exp . Id ,
2024-02-02 16:16:26 +00:00
"RunPath" : run_path ,
"ColorMode" : model . ImageMode ,
"Model" : model ,
"EPOCH_PER_RUN" : EPOCH_PER_RUN ,
"LoadPrev" : load_prev ,
"LastModelRunPath" : path . Join ( getDir ( ) , result_path , "model.keras" ) ,
"SaveModelPath" : path . Join ( getDir ( ) , result_path ) ,
2024-03-06 23:33:54 +00:00
"Depth" : classCount ,
"StartPoint" : 0 ,
2024-04-15 23:04:53 +01:00
"Host" : c . GetHost ( ) ,
2024-02-02 16:16:26 +00:00
} ) ; err != nil {
return
}
// Run the command
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
2024-04-15 23:04:53 +01:00
l . Debug ( string ( out ) )
2024-02-02 16:16:26 +00:00
return
}
2024-04-15 23:04:53 +01:00
l . Info ( "Python finished running" )
2024-02-02 16:16:26 +00:00
if err = os . MkdirAll ( result_path , os . ModePerm ) ; err != nil {
return
}
accuracy_file , err := os . Open ( path . Join ( run_path , "accuracy.val" ) )
if err != nil {
return
}
defer accuracy_file . Close ( )
accuracy_file_bytes , err := io . ReadAll ( accuracy_file )
if err != nil {
return
}
accuracy , err = strconv . ParseFloat ( string ( accuracy_file_bytes ) , 64 )
if err != nil {
return
}
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-15 23:04:53 +01:00
l . Info ( "Model finished training!" , "accuracy" , accuracy )
2024-02-02 16:16:26 +00:00
return
}
2023-10-19 10:44:13 +01:00
// remove returns lst without the element at index i.
//
// Removing the first or last element re-slices without touching lst's backing
// array. Removing an interior element shifts the tail left in place, so callers
// must use the returned slice (as in `s = remove(s, i)`).
//
// An index outside [0, len(lst)) returns lst unchanged. Previously an index >=
// len returned an empty slice, silently discarding every element, and a
// negative index panicked.
func remove[T any](lst []T, i int) []T {
	n := len(lst)
	switch {
	case i < 0 || i >= n:
		return lst
	case i == n-1:
		return lst[:n-1]
	case i == 0:
		return lst[1:]
	default:
		return append(lst[:i], lst[i+1:]...)
	}
}
2023-10-22 23:02:39 +01:00
// TrainModelRow carries a definition id together with its target accuracy,
// epoch count and accuracy (the field name "acuracy" is kept as-is).
// NOTE(review): not referenced in the visible part of this file - confirm it
// is still used.
type TrainModelRow struct {
	id                     string
	target_accuracy, epoch int
	acuracy                float64
}
// TraingModelRowDefinitions implements sort.Interface, ordering rows by
// ascending acuracy.
type TraingModelRowDefinitions []TrainModelRow

// Len reports the number of rows.
func (rows TraingModelRowDefinitions) Len() int {
	return len(rows)
}

// Swap exchanges rows i and j.
func (rows TraingModelRowDefinitions) Swap(i, j int) {
	tmp := rows[i]
	rows[i] = rows[j]
	rows[j] = tmp
}

// Less reports whether row i has strictly lower accuracy than row j.
func (rows TraingModelRowDefinitions) Less(i, j int) bool {
	return rows[i].acuracy < rows[j].acuracy
}
// ToRemoveList is a list of slice indices implementing sort.Interface in
// ascending order. trainModel sorts it with sort.Reverse so indices are
// removed from the back first and earlier indices stay valid.
type ToRemoveList []int

// Len reports the number of indices.
func (idx ToRemoveList) Len() int {
	return len(idx)
}

// Swap exchanges entries i and j.
func (idx ToRemoveList) Swap(i, j int) {
	tmp := idx[i]
	idx[i] = idx[j]
	idx[j] = tmp
}

// Less reports whether entry i is smaller than entry j.
func (idx ToRemoveList) Less(i, j int) bool {
	return idx[i] < idx[j]
}
2024-04-15 23:04:53 +01:00
func trainModel ( c BasePack , model * BaseModel ) ( err error ) {
db := c . GetDb ( )
2024-04-19 15:39:51 +01:00
log := c . GetLogger ( )
2024-03-09 09:41:16 +00:00
2024-04-19 15:39:51 +01:00
fail := func ( err error ) {
log . Error ( "Failed to train Model!" , "err" , err , "stack" , string ( debug . Stack ( ) ) )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
}
2024-04-19 15:39:51 +01:00
defs , err := model . GetDefinitions ( db , "and md.status=$2" , DEFINITION_STATUS_INIT )
if err != nil {
fail ( err )
return
2023-09-27 21:20:39 +01:00
}
2024-04-19 15:39:51 +01:00
var definitions SortByAccuracyDefinitions = defs
2023-09-27 21:20:39 +01:00
if len ( definitions ) == 0 {
2024-04-19 15:39:51 +01:00
fail ( errors . New ( "No definitons defined!" ) )
2023-09-27 21:20:39 +01:00
return
}
2023-10-22 23:02:39 +01:00
finished := false
2024-04-19 15:39:51 +01:00
models := map [ string ] * my_torch . ContainerModel { }
classes , err := model . GetClasses ( db , " and status=$2 order by mc.class_order asc" , CLASS_STATUS_TO_TRAIN )
2023-10-19 10:44:13 +01:00
for {
2024-04-19 15:39:51 +01:00
// Keep track of definitions that did not train fast enough
2023-10-22 23:02:39 +01:00
var toRemove ToRemoveList = [ ] int { }
2024-04-19 15:39:51 +01:00
2023-10-19 10:44:13 +01:00
for i , def := range definitions {
2024-04-19 15:39:51 +01:00
err := def . UpdateStatus ( c , DEFINITION_STATUS_TRAINING )
2023-10-19 10:44:13 +01:00
if err != nil {
2024-04-19 15:39:51 +01:00
log . Error ( "Could not make model into training" , "err" , err )
def . UpdateStatus ( c , DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
2023-10-19 10:44:13 +01:00
continue
}
2023-10-24 22:35:11 +01:00
2024-04-19 15:39:51 +01:00
accuracy , ml_model , err := trainDefinition ( c , model , def , models [ def . Id ] , classes )
if err != nil {
log . Error ( "Failed to train definition!Err:" , "err" , err )
def . UpdateStatus ( c , DEFINITION_STATUS_FAILED_TRAINING )
toRemove = append ( toRemove , i )
continue
}
models [ def . Id ] = ml_model
2023-10-19 10:44:13 +01:00
2024-04-19 15:39:51 +01:00
if accuracy >= float64 ( def . TargetAccuracy ) {
log . Info ( "Found a definition that reaches target_accuracy!" )
_ , err = db . Exec ( "update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4" , accuracy , DEFINITION_STATUS_TRANIED , def . Epoch , def . Id )
2023-10-19 10:44:13 +01:00
if err != nil {
2024-04-19 15:39:51 +01:00
log . Error ( "Failed to train definition!Err:\n" , "err" , err )
2023-10-19 10:44:13 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2024-04-15 23:04:53 +01:00
return err
2023-10-19 10:44:13 +01:00
}
2024-04-19 15:39:51 +01:00
_ , err = db . Exec ( "update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4" , DEFINITION_STATUS_CANCELD_TRAINING , def . Id , model . Id , DEFINITION_STATUS_FAILED_TRAINING )
2023-10-19 10:44:13 +01:00
if err != nil {
2024-04-19 15:39:51 +01:00
log . Error ( "Failed to train definition!Err:\n" , "err" , err )
2023-10-19 10:44:13 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2024-04-15 23:04:53 +01:00
return err
2023-10-19 10:44:13 +01:00
}
2023-10-22 23:02:39 +01:00
finished = true
2023-10-19 10:44:13 +01:00
break
}
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
if def . Epoch > MAX_EPOCH {
fmt . Printf ( "Failed to train definition! Accuracy less %f < %d\n" , accuracy , def . TargetAccuracy )
def . UpdateStatus ( c , DEFINITION_STATUS_FAILED_TRAINING )
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
2023-10-19 10:44:13 +01:00
continue
}
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
_ , err = db . Exec ( "update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4" , accuracy , def . Epoch , DEFINITION_STATUS_PAUSED_TRAINING , def . Id )
2023-10-22 23:02:39 +01:00
if err != nil {
2024-04-19 15:39:51 +01:00
log . Error ( "Failed to train definition!Err:\n" , "err" , err )
2023-10-22 23:02:39 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2024-04-15 23:04:53 +01:00
return err
2023-10-22 23:02:39 +01:00
}
2023-09-27 21:20:39 +01:00
}
2023-10-22 23:02:39 +01:00
if finished {
break
}
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( toRemove ) )
2023-10-22 23:02:39 +01:00
2024-04-19 15:39:51 +01:00
log . Info ( "Round done" , "toRemove" , toRemove )
2023-10-22 23:02:39 +01:00
for _ , n := range toRemove {
2024-04-19 15:39:51 +01:00
// Clean up unsed models
models [ definitions [ n ] . Id ] = nil
2023-10-22 23:02:39 +01:00
definitions = remove ( definitions , n )
}
len_def := len ( definitions )
if len_def == 0 {
2023-10-19 10:44:13 +01:00
break
2023-09-27 21:20:39 +01:00
}
2023-10-22 23:02:39 +01:00
if len_def == 1 {
continue
}
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( definitions ) )
2023-10-22 23:02:39 +01:00
2024-04-19 15:39:51 +01:00
acc := definitions [ 0 ] . Accuracy - 20.0
2023-10-22 23:02:39 +01:00
2024-04-19 15:39:51 +01:00
log . Info ( "Training models, Highest acc" , "acc" , definitions [ 0 ] . Accuracy , "mod_acc" , acc )
2023-10-22 23:02:39 +01:00
toRemove = [ ] int { }
for i , def := range definitions {
2024-04-19 15:39:51 +01:00
if def . Accuracy < acc {
2023-10-22 23:02:39 +01:00
toRemove = append ( toRemove , i )
}
}
2024-04-19 15:39:51 +01:00
log . Info ( "Removing due to accuracy" , "toRemove" , toRemove )
2023-10-22 23:02:39 +01:00
2023-10-25 14:50:58 +01:00
sort . Sort ( sort . Reverse ( toRemove ) )
2023-10-22 23:02:39 +01:00
for _ , n := range toRemove {
2024-04-19 15:39:51 +01:00
log . Warn ( "Removing definition not fast enough learning" , "n" , n )
definitions [ n ] . UpdateStatus ( c , DEFINITION_STATUS_FAILED_TRAINING )
models [ definitions [ n ] . Id ] = nil
2023-10-22 23:02:39 +01:00
definitions = remove ( definitions , n )
}
2023-09-27 21:20:39 +01:00
}
2024-04-19 15:39:51 +01:00
var def Definition
err = GetDBOnce ( c , & def , "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;" , model . Id , DEFINITION_STATUS_TRANIED )
2023-09-27 21:20:39 +01:00
if err != nil {
2024-04-19 15:39:51 +01:00
if err == NotFoundError {
log . Error ( "All definitions failed to train!" )
} else {
log . Error ( "DB: failed to read definition" , "err" , err )
}
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 18:07:04 +01:00
2024-04-19 15:39:51 +01:00
if err = def . UpdateStatus ( c , DEFINITION_STATUS_READY ) ; err != nil {
log . Error ( "Failed to update model definition" , "err" , err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 21:20:39 +01:00
2024-04-19 15:39:51 +01:00
to_delete , err := db . Query ( "select id from model_definition where status != $1 and model_id=$2" , DEFINITION_STATUS_READY , model . Id )
2023-10-10 12:28:49 +01:00
if err != nil {
2024-04-19 15:39:51 +01:00
log . Error ( "Failed to select model_definition to delete" )
log . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
defer to_delete . Close ( )
for to_delete . Next ( ) {
var id string
2024-03-06 23:33:54 +00:00
if err = to_delete . Scan ( & id ) ; err != nil {
2024-04-19 15:39:51 +01:00
log . Error ( "Failed to scan the id of a model_definition to delete" , "err" , err )
2023-10-10 12:28:49 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
return
}
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , id ) )
}
// TODO Check if returning also works here
2024-04-19 15:39:51 +01:00
if _ , err = db . Exec ( "delete from model_definition where status!=$1 and model_id=$2;" , DEFINITION_STATUS_READY , model . Id ) ; err != nil {
log . Error ( "Failed to delete model_definition" )
log . Error ( err )
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
2023-09-27 21:20:39 +01:00
return
2023-10-10 12:28:49 +01:00
}
2023-09-27 18:07:04 +01:00
2023-10-06 12:13:19 +01:00
ModelUpdateStatus ( c , model . Id , READY )
2024-04-15 23:04:53 +01:00
return
2023-09-27 18:07:04 +01:00
}
2024-03-06 23:33:54 +00:00
// TrainModelRowUsable is the subset of a model_definition row that the
// training loops operate on.
type TrainModelRowUsable struct {
	Id             string
	TargetAccuracy int `db:"target_accuracy"`
	Epoch          int
	// NOTE(review): the `db:"0"` tag does not name a real column; this field
	// is presumably not populated from the database and is filled in by the
	// training loop instead — confirm against GetDbMultitple.
	Acuracy float64 `db:"0"`
}

// TrainModelRowUsables is a list of definitions that implements
// sort.Interface, ordering by ascending Acuracy.
type TrainModelRowUsables []*TrainModelRowUsable

// Len reports how many definitions are in the list.
func (rows TrainModelRowUsables) Len() int { return len(rows) }

// Swap exchanges the definitions at positions a and b.
func (rows TrainModelRowUsables) Swap(a, b int) {
	tmp := rows[a]
	rows[a] = rows[b]
	rows[b] = tmp
}

// Less reports whether the definition at a has strictly lower accuracy than
// the one at b.
func (rows TrainModelRowUsables) Less(a, b int) bool {
	lhs, rhs := rows[a].Acuracy, rows[b].Acuracy
	return lhs < rhs
}
2024-04-15 23:04:53 +01:00
func trainModelExp ( c BasePack , model * BaseModel ) ( err error ) {
l := c . GetLogger ( )
db := c . GetDb ( )
2024-01-31 21:48:35 +00:00
2024-03-09 10:52:08 +00:00
var definitions TrainModelRowUsables
2024-01-31 21:48:35 +00:00
2024-04-19 15:39:51 +01:00
definitions , err = GetDbMultitple [ TrainModelRowUsable ] ( db , "model_definition where status=$1 and model_id=$2" , DEFINITION_STATUS_INIT , model . Id )
2024-03-09 10:52:08 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to get definitions" )
2024-03-09 10:52:08 +00:00
return
}
2024-01-31 21:48:35 +00:00
if len ( definitions ) == 0 {
2024-04-15 23:04:53 +01:00
l . Error ( "No Definitions defined!" )
return errors . New ( "No Definitions found" )
2024-01-31 21:48:35 +00:00
}
firstRound := true
finished := false
for {
var toRemove ToRemoveList = [ ] int { }
for i , def := range definitions {
2024-04-19 15:39:51 +01:00
Definition { Id : def . Id } . UpdateStatus ( c , DEFINITION_STATUS_TRAINING )
2024-03-06 23:33:54 +00:00
accuracy , err := trainDefinitionExp ( c , model , def . Id , ! firstRound )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!Err:" , "err" , err )
2024-04-19 15:39:51 +01:00
Definition { Id : def . Id } . UpdateStatus ( c , DEFINITION_STATUS_TRAINING )
2024-01-31 21:48:35 +00:00
toRemove = append ( toRemove , i )
continue
}
2024-03-06 23:33:54 +00:00
def . Epoch += EPOCH_PER_RUN
2024-01-31 21:48:35 +00:00
accuracy = accuracy * 100
2024-03-06 23:33:54 +00:00
def . Acuracy = float64 ( accuracy )
2024-01-31 21:48:35 +00:00
2024-03-06 23:33:54 +00:00
definitions [ i ] . Epoch += EPOCH_PER_RUN
definitions [ i ] . Acuracy = accuracy
2024-01-31 21:48:35 +00:00
2024-03-06 23:33:54 +00:00
if accuracy >= float64 ( def . TargetAccuracy ) {
2024-04-15 23:04:53 +01:00
l . Info ( "Found a definition that reaches target_accuracy!" )
2024-04-19 15:39:51 +01:00
_ , err = db . Exec ( "update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4" , accuracy , DEFINITION_STATUS_TRANIED , def . Epoch , def . Id )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-19 15:39:51 +01:00
_ , err = db . Exec ( "update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4" , DEFINITION_STATUS_CANCELD_TRAINING , def . Id , model . Id , DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
_ , err = db . Exec ( "update exp_model_head set status=$1 where def_id=$2;" , MODEL_HEAD_STATUS_READY , def . Id )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-04-08 14:17:13 +01:00
}
2024-01-31 21:48:35 +00:00
finished = true
break
}
2024-03-06 23:33:54 +00:00
if def . Epoch > MAX_EPOCH {
fmt . Printf ( "Failed to train definition! Accuracy less %f < %d\n" , accuracy , def . TargetAccuracy )
2024-04-19 15:39:51 +01:00
Definition { Id : def . Id } . UpdateStatus ( c , DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
toRemove = append ( toRemove , i )
continue
}
2024-04-19 15:39:51 +01:00
_ , err = db . Exec ( "update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4" , accuracy , def . Epoch , DEFINITION_STATUS_PAUSED_TRAINING , def . Id )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to train definition!" )
return err
2024-01-31 21:48:35 +00:00
}
}
firstRound = false
if finished {
break
}
sort . Sort ( sort . Reverse ( toRemove ) )
2024-04-15 23:04:53 +01:00
l . Info ( "Round done" , "toRemove" , toRemove )
2024-01-31 21:48:35 +00:00
for _ , n := range toRemove {
definitions = remove ( definitions , n )
}
len_def := len ( definitions )
if len_def == 0 {
break
2024-03-09 09:41:16 +00:00
} else if len_def == 1 {
2024-01-31 21:48:35 +00:00
continue
}
sort . Sort ( sort . Reverse ( definitions ) )
2024-03-06 23:33:54 +00:00
acc := definitions [ 0 ] . Acuracy - 20.0
2024-01-31 21:48:35 +00:00
2024-04-15 23:04:53 +01:00
l . Info ( "Training models, Highest acc" , "acc" , definitions [ 0 ] . Acuracy , "mod_acc" , acc )
2024-01-31 21:48:35 +00:00
toRemove = [ ] int { }
for i , def := range definitions {
2024-03-06 23:33:54 +00:00
if def . Acuracy < acc {
2024-01-31 21:48:35 +00:00
toRemove = append ( toRemove , i )
}
}
2024-04-15 23:04:53 +01:00
l . Info ( "Removing due to accuracy" , "toRemove" , toRemove )
2024-01-31 21:48:35 +00:00
sort . Sort ( sort . Reverse ( toRemove ) )
for _ , n := range toRemove {
2024-04-15 23:04:53 +01:00
l . Warn ( "Removing definition not fast enough learning" , "n" , n )
2024-04-19 15:39:51 +01:00
Definition { Id : definitions [ n ] . Id } . UpdateStatus ( c , DEFINITION_STATUS_FAILED_TRAINING )
2024-01-31 21:48:35 +00:00
definitions = remove ( definitions , n )
}
}
2024-03-09 10:52:08 +00:00
var dat JustId
2024-04-19 15:39:51 +01:00
err = GetDBOnce ( db , & dat , "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;" , model . Id , DEFINITION_STATUS_TRANIED )
2024-03-09 10:52:08 +00:00
if err == NotFoundError {
2024-04-08 14:17:13 +01:00
// Set the class status to trained
2024-04-19 15:39:51 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TO_TRAIN , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train! And Failed to set class status" )
return err
2024-04-08 14:17:13 +01:00
}
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train!" )
return err
2024-03-09 10:52:08 +00:00
} else if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "All definitions failed to train!" )
return err
2024-03-09 10:52:08 +00:00
}
2024-01-31 21:48:35 +00:00
2024-04-19 15:39:51 +01:00
if _ , err = db . Exec ( "update model_definition set status=$1 where id=$2;" , DEFINITION_STATUS_READY , dat . Id ) ; err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to update model definition" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-19 15:39:51 +01:00
to_delete , err := GetDbMultitple [ JustId ] ( db , "model_definition where status!=$1 and model_id=$2" , DEFINITION_STATUS_READY , model . Id )
2024-03-09 10:52:08 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to select model_definition to delete" )
return err
2024-03-09 10:52:08 +00:00
}
2024-01-31 21:48:35 +00:00
2024-03-09 10:52:08 +00:00
for _ , d := range to_delete {
2024-03-09 09:41:16 +00:00
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , d . Id ) )
2024-03-09 10:52:08 +00:00
}
2024-01-31 21:48:35 +00:00
// TODO Check if returning also works here
2024-04-19 15:39:51 +01:00
if _ , err = db . Exec ( "delete from model_definition where status!=$1 and model_id=$2;" , DEFINITION_STATUS_READY , model . Id ) ; err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to delete model_definition" )
return err
2024-03-02 12:45:49 +00:00
}
if err = splitModel ( c , model ) ; err != nil {
2024-04-19 15:39:51 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TO_TRAIN , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to split the model! And Failed to set class status" )
return err
2024-04-08 14:17:13 +01:00
}
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to split the model" )
return err
2024-01-31 21:48:35 +00:00
}
2024-04-08 14:17:13 +01:00
// Set the class status to trained
2024-04-19 15:39:51 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINED , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
l . Error ( "Failed to set class status" )
return err
2024-04-08 14:17:13 +01:00
}
2024-03-02 12:45:49 +00:00
// There should only be one def availabale
def := JustId { }
2024-04-15 23:04:53 +01:00
if err = GetDBOnce ( db , & def , "model_definition where model_id=$1" , model . Id ) ; err != nil {
2024-02-12 14:30:43 +00:00
return
2024-03-02 12:45:49 +00:00
}
2024-02-14 15:11:45 +00:00
2024-03-02 12:45:49 +00:00
// Remove the base model
2024-04-15 23:04:53 +01:00
l . Warn ( "Removing base model for" , "model" , model . Id , "def" , def . Id )
2024-02-14 15:11:45 +00:00
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , def . Id , "model" ) )
os . RemoveAll ( path . Join ( "savedData" , model . Id , "defs" , def . Id , "model.keras" ) )
2024-02-12 14:30:43 +00:00
2024-01-31 21:48:35 +00:00
ModelUpdateStatus ( c , model . Id , READY )
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
func splitModel ( c BasePack , model * BaseModel ) ( err error ) {
db := c . GetDb ( )
l := c . GetLogger ( )
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
def := JustId { }
2024-04-15 23:04:53 +01:00
if err = GetDBOnce ( db , & def , "model_definition where model_id=$1" , model . Id ) ; err != nil {
2024-03-02 12:45:49 +00:00
return
}
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
head := JustId { }
2024-04-15 23:04:53 +01:00
if err = GetDBOnce ( db , & head , "exp_model_head where def_id=$1" , def . Id ) ; err != nil {
2024-03-02 12:45:49 +00:00
return
}
2024-02-12 14:30:43 +00:00
// Generate run folder
2024-04-08 14:17:13 +01:00
run_path := path . Join ( "/tmp" , model . Id + "-defs-" + def . Id + "-split" )
2024-02-12 14:30:43 +00:00
err = os . MkdirAll ( run_path , os . ModePerm )
if err != nil {
return
}
// Create python script
f , err := os . Create ( path . Join ( run_path , "run.py" ) )
if err != nil {
return
}
defer f . Close ( )
tmpl , err := template . New ( "python_split_model_template.py" ) . ParseFiles ( "views/py/python_split_model_template.py" )
if err != nil {
return
}
// Copy result around
result_path := path . Join ( getDir ( ) , "savedData" , model . Id , "defs" , def . Id )
2024-03-02 12:45:49 +00:00
// TODO maybe move this to a select count(*)
// Get only fixed lawers
2024-04-15 23:04:53 +01:00
layers , err := db . Query ( "select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;" , def . Id , 1 )
2024-02-12 14:30:43 +00:00
if err != nil {
return
}
defer layers . Close ( )
type layerrow struct {
2024-03-02 12:45:49 +00:00
ExpType int
2024-02-12 14:30:43 +00:00
}
2024-03-02 12:45:49 +00:00
count := - 1
2024-02-12 14:30:43 +00:00
for layers . Next ( ) {
2024-03-02 12:45:49 +00:00
count += 1
2024-02-12 14:30:43 +00:00
}
2024-03-02 12:45:49 +00:00
if count == - 1 {
err = errors . New ( "Can not get layers" )
return
}
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
log . Warn ( "Spliting model" , "def" , def . Id , "head" , head . Id , "count" , count )
2024-02-12 14:30:43 +00:00
2024-03-02 12:45:49 +00:00
basePath := path . Join ( result_path , "base" )
headPath := path . Join ( result_path , "head" , head . Id )
2024-02-12 14:30:43 +00:00
if err = os . MkdirAll ( basePath , os . ModePerm ) ; err != nil {
return
}
if err = os . MkdirAll ( headPath , os . ModePerm ) ; err != nil {
return
}
if err = tmpl . Execute ( f , AnyMap {
2024-03-02 12:45:49 +00:00
"SplitLen" : count ,
"ModelPath" : path . Join ( result_path , "model.keras" ) ,
"BaseModelPath" : basePath ,
"HeadModelPath" : headPath ,
2024-02-12 14:30:43 +00:00
} ) ; err != nil {
return
}
out , err := exec . Command ( "bash" , "-c" , fmt . Sprintf ( "cd %s && python run.py" , run_path ) ) . CombinedOutput ( )
if err != nil {
2024-04-15 23:04:53 +01:00
l . Debug ( string ( out ) )
2024-02-12 14:30:43 +00:00
return
}
2024-04-08 14:17:13 +01:00
os . RemoveAll ( run_path )
2024-04-15 23:04:53 +01:00
l . Info ( "Python finished running" )
2024-03-02 12:45:49 +00:00
return
2024-02-12 14:30:43 +00:00
}
2023-10-19 11:42:38 +01:00
// This generates a definition
2024-04-15 23:04:53 +01:00
func generateDefinition ( c BasePack , model * BaseModel , target_accuracy int , number_of_classes int , complexity int ) ( err error ) {
failed := func ( ) {
2023-10-19 11:42:38 +01:00
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2024-04-15 23:04:53 +01:00
db := c . GetDb ( )
l := c . GetLogger ( )
2024-04-19 15:39:51 +01:00
def , err := MakeDefenition ( db , model . Id , target_accuracy )
2023-10-21 00:26:52 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2023-10-22 23:02:39 +01:00
order := 1
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
// Note the shape of the first layer defines the import size
2024-04-19 15:39:51 +01:00
//_, err = def.MakeLayer(db, order, LAYER_INPUT, ShapeToString(model.Width, model.Height, model.ImageMode))
_ , err = def . MakeLayer ( db , order , LAYER_INPUT , ShapeToString ( 3 , model . Width , model . Height ) )
if err != nil {
failed ( )
return
2023-10-21 00:26:52 +01:00
}
2024-04-19 15:39:51 +01:00
order ++
2023-10-19 11:42:38 +01:00
2023-10-21 00:26:52 +01:00
if complexity == 0 {
2024-04-19 15:39:51 +01:00
_ , err = def . MakeLayer ( db , order , LAYER_FLATTEN , "" )
2023-10-19 11:42:38 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-19 11:42:38 +01:00
}
2023-10-22 23:02:39 +01:00
order ++
2023-10-19 11:42:38 +01:00
2023-10-21 12:01:10 +01:00
loop := int ( math . Log2 ( float64 ( number_of_classes ) ) )
2023-10-21 00:26:52 +01:00
for i := 0 ; i < loop ; i ++ {
2024-04-19 15:39:51 +01:00
_ , err = def . MakeLayer ( db , order , LAYER_DENSE , ShapeToString ( number_of_classes * ( loop - i ) ) )
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
2024-04-15 23:04:53 +01:00
return
2023-10-21 00:26:52 +01:00
}
}
2024-01-31 21:48:35 +00:00
} else if complexity == 1 || complexity == 2 {
2024-04-19 15:39:51 +01:00
loop := max ( 1 , int ( ( math . Log ( float64 ( model . Width ) ) / math . Log ( float64 ( 10 ) ) ) ) )
2023-10-21 00:26:52 +01:00
for i := 0 ; i < loop ; i ++ {
2024-04-19 15:39:51 +01:00
_ , err = def . MakeLayer ( db , order , LAYER_SIMPLE_BLOCK , "" )
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-21 00:26:52 +01:00
}
}
2024-04-19 15:39:51 +01:00
_ , err = def . MakeLayer ( db , order , LAYER_FLATTEN , "" )
2023-10-21 00:26:52 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-21 00:26:52 +01:00
}
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
loop = int ( ( math . Log ( float64 ( number_of_classes ) ) / math . Log ( float64 ( 10 ) ) ) / 2 )
if loop == 0 {
loop = 1
}
2023-10-21 00:26:52 +01:00
for i := 0 ; i < loop ; i ++ {
2024-04-19 15:39:51 +01:00
_ , err = def . MakeLayer ( db , order , LAYER_DENSE , ShapeToString ( number_of_classes * ( loop - i ) ) )
2023-10-22 23:02:39 +01:00
order ++
2023-10-21 00:26:52 +01:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-22 23:02:39 +01:00
}
}
2023-10-21 00:26:52 +01:00
} else {
2024-04-19 15:39:51 +01:00
l . Error ( "Unkown complexity" , "complexity" , complexity )
2024-04-15 23:04:53 +01:00
failed ( )
return
2023-10-21 00:26:52 +01:00
}
2023-10-19 11:42:38 +01:00
2024-04-19 15:39:51 +01:00
return def . UpdateStatus ( db , DEFINITION_STATUS_INIT )
2023-10-19 11:42:38 +01:00
}
2024-04-15 23:04:53 +01:00
func generateDefinitions ( c BasePack , model * BaseModel , target_accuracy int , number_of_models int ) ( err error ) {
2024-04-19 15:39:51 +01:00
cls , err := model . GetClasses ( c , "" )
2023-10-19 11:42:38 +01:00
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
2024-04-15 23:04:53 +01:00
return
2023-10-19 11:42:38 +01:00
}
2023-10-22 23:02:39 +01:00
cls_len := len ( cls )
2023-10-19 11:42:38 +01:00
2023-10-22 23:02:39 +01:00
if number_of_models == 1 {
if model . Width < 100 && model . Height < 100 && cls_len < 30 {
2024-04-19 15:39:51 +01:00
err = generateDefinition ( c , model , target_accuracy , cls_len , 0 )
2023-10-22 23:02:39 +01:00
} else if model . Width > 100 && model . Height > 100 {
2024-04-19 15:39:51 +01:00
err = generateDefinition ( c , model , target_accuracy , cls_len , 2 )
2023-10-22 23:02:39 +01:00
} else {
2024-04-19 15:39:51 +01:00
err = generateDefinition ( c , model , target_accuracy , cls_len , 1 )
2023-10-22 23:02:39 +01:00
}
2024-04-19 15:39:51 +01:00
if err != nil {
return
2023-10-22 23:02:39 +01:00
}
} else {
for i := 0 ; i < number_of_models ; i ++ {
2024-04-19 15:39:51 +01:00
err = generateDefinition ( c , model , target_accuracy , cls_len , min ( i , 2 ) )
if err != nil {
return
}
2023-10-22 23:02:39 +01:00
}
}
2023-10-21 00:26:52 +01:00
return nil
2023-10-19 11:42:38 +01:00
}
2024-04-19 15:39:51 +01:00
func ExpModelHeadUpdateStatus ( db db . Db , id string , status DefinitionStatus ) ( err error ) {
2024-01-31 21:48:35 +00:00
_ , err = db . Exec ( "update model_definition set status = $1 where id = $2" , status , id )
return
}
// This generates a definition
2024-04-15 23:04:53 +01:00
func generateExpandableDefinition ( c BasePack , model * BaseModel , target_accuracy int , number_of_classes int , complexity int ) ( err error ) {
l := c . GetLogger ( )
db := c . GetDb ( )
l . Info ( "Generating expandable new definition for model" , "id" , model . Id , "complexity" , complexity )
2024-02-08 18:20:58 +00:00
2024-04-15 23:04:53 +01:00
failed := func ( ) {
2024-01-31 21:48:35 +00:00
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
}
if complexity == 0 {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-01-31 21:48:35 +00:00
}
2024-04-19 15:39:51 +01:00
def , err := MakeDefenition ( c . GetDb ( ) , model . Id , target_accuracy )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-01-31 21:48:35 +00:00
}
order := 1
2024-02-02 16:16:26 +00:00
width := model . Width
height := model . Height
2024-01-31 21:48:35 +00:00
// Note the shape of the first layer defines the import size
if complexity == 2 {
// Note the shape for now is no used
width := int ( math . Pow ( 2 , math . Floor ( math . Log ( float64 ( model . Width ) ) / math . Log ( 2.0 ) ) ) )
height := int ( math . Pow ( 2 , math . Floor ( math . Log ( float64 ( model . Height ) ) / math . Log ( 2.0 ) ) ) )
2024-04-15 23:04:53 +01:00
l . Warn ( "Complexity 2 creating model with smaller size" , "width" , width , "height" , height )
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
2024-04-19 15:39:51 +01:00
err = MakeLayerExpandable ( c . GetDb ( ) , def . Id , order , LAYER_INPUT , fmt . Sprintf ( "%d,%d,1" , width , height ) , 1 )
2024-01-31 21:48:35 +00:00
order ++
2024-02-02 16:16:26 +00:00
// handle the errors inside the pervious if block
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Create the blocks
loop := int ( ( math . Log ( float64 ( model . Width ) ) / math . Log ( float64 ( 10 ) ) ) )
2024-02-08 18:20:58 +00:00
2024-04-18 15:01:36 +01:00
/ * if model . Width < 50 && model . Height < 50 {
loop = 0
} * /
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
log . Info ( "Size of the simple block" , "loop" , loop )
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
//loop = max(loop, 3)
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
for i := 0 ; i < loop ; i ++ {
2024-04-19 15:39:51 +01:00
err = MakeLayerExpandable ( db , def . Id , order , LAYER_SIMPLE_BLOCK , "" , 1 )
2024-02-02 16:16:26 +00:00
order ++
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
}
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Flatten the blocks into dense
2024-04-19 15:39:51 +01:00
err = MakeLayerExpandable ( db , def . Id , order , LAYER_FLATTEN , "" , 1 )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
order ++
2024-01-31 21:48:35 +00:00
2024-02-02 16:16:26 +00:00
// Flatten the blocks into dense
2024-04-19 15:39:51 +01:00
err = MakeLayerExpandable ( db , def . Id , order , LAYER_DENSE , fmt . Sprintf ( "%d,1" , number_of_classes * 2 ) , 1 )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
order ++
loop = int ( ( math . Log ( float64 ( number_of_classes ) ) / math . Log ( float64 ( 10 ) ) ) / 2 )
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
log . Info ( "Size of the dense layers" , "loop" , loop )
2024-02-08 18:20:58 +00:00
2024-03-02 12:45:49 +00:00
// loop = max(loop, 3)
2024-02-02 16:16:26 +00:00
for i := 0 ; i < loop ; i ++ {
2024-04-19 15:39:51 +01:00
err = MakeLayer ( db , def . Id , order , LAYER_DENSE , fmt . Sprintf ( "%d,1" , number_of_classes * ( loop - i ) ) )
2024-02-02 16:16:26 +00:00
order ++
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
}
2024-03-02 12:45:49 +00:00
2024-04-15 23:04:53 +01:00
var newHead = struct {
2024-04-19 15:39:51 +01:00
DefId string ` db:"def_id" `
RangeStart int ` db:"range_start" `
RangeEnd int ` db:"range_end" `
Status DefinitionStatus ` db:"status" `
2024-04-15 23:04:53 +01:00
} {
2024-04-19 15:39:51 +01:00
def . Id , 0 , number_of_classes - 1 , DEFINITION_STATUS_INIT ,
2024-04-15 23:04:53 +01:00
}
_ , err = InsertReturnId ( c . GetDb ( ) , & newHead , "exp_model_head" , "id" )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-02-02 16:16:26 +00:00
}
2024-01-31 21:48:35 +00:00
2024-04-19 15:39:51 +01:00
err = def . UpdateStatus ( c , DEFINITION_STATUS_INIT )
2024-01-31 21:48:35 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
failed ( )
return
2024-01-31 21:48:35 +00:00
}
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
2024-03-02 12:45:49 +00:00
// TODO make this json friendy
2024-04-15 23:04:53 +01:00
func generateExpandableDefinitions ( c BasePack , model * BaseModel , target_accuracy int , number_of_models int ) ( err error ) {
2024-04-19 15:39:51 +01:00
cls , err := model . GetClasses ( c , "" )
2024-01-31 21:48:35 +00:00
if err != nil {
ModelUpdateStatus ( c , model . Id , FAILED_PREPARING_TRAINING )
2024-04-15 23:04:53 +01:00
return
2024-01-31 21:48:35 +00:00
}
cls_len := len ( cls )
if number_of_models == 1 {
if model . Width > 100 && model . Height > 100 {
generateExpandableDefinition ( c , model , target_accuracy , cls_len , 2 )
} else {
2024-04-18 15:01:36 +01:00
generateExpandableDefinition ( c , model , target_accuracy , cls_len , 2 )
2024-01-31 21:48:35 +00:00
}
} else if number_of_models == 3 {
for i := 0 ; i < number_of_models ; i ++ {
generateExpandableDefinition ( c , model , target_accuracy , cls_len , i )
}
} else {
// TODO handle incrisea the complexity
for i := 0 ; i < number_of_models ; i ++ {
2024-04-18 15:01:36 +01:00
generateExpandableDefinition ( c , model , target_accuracy , cls_len , 2 )
2024-01-31 21:48:35 +00:00
}
}
return nil
}
2024-04-16 17:48:52 +01:00
func ResetClasses ( c BasePack , model * BaseModel ) {
2024-04-19 15:39:51 +01:00
_ , err := c . GetDb ( ) . Exec ( "update model_classes set status=$1 where status=$2 and model_id=$3" , CLASS_STATUS_TO_TRAIN , CLASS_STATUS_TRAINING , model . Id )
2024-04-08 14:17:13 +01:00
if err != nil {
2024-04-16 17:48:52 +01:00
c . GetLogger ( ) . Error ( "Error while reseting the classes" , "error" , err )
2024-04-08 14:17:13 +01:00
}
}
func trainExpandable ( c * Context , model * BaseModel ) {
var err error = nil
failed := func ( msg string ) {
c . Logger . Error ( msg , "err" , err )
ModelUpdateStatus ( c , model . Id , FAILED_TRAINING )
ResetClasses ( c , model )
}
var definitions TrainModelRowUsables
2024-04-19 15:39:51 +01:00
definitions , err = GetDbMultitple [ TrainModelRowUsable ] ( c , "model_definition where status=$1 and model_id=$2" , DEFINITION_STATUS_READY , model . Id )
2024-04-08 14:17:13 +01:00
if err != nil {
failed ( "Failed to get definitions" )
return
}
if len ( definitions ) != 1 {
failed ( "There should only be one definition available!" )
return
}
firstRound := true
def := definitions [ 0 ]
epoch := 0
for {
acc , err := trainDefinitionExp ( c , model , def . Id , ! firstRound )
if err != nil {
failed ( "Failed to train definition!" )
return
}
epoch += EPOCH_PER_RUN
if float64 ( acc * 100 ) >= float64 ( def . Acuracy ) {
c . Logger . Info ( "Found a definition that reaches target_accuracy!" )
_ , err = c . Db . Exec ( "update exp_model_head set status=$1 where def_id=$2 and status=$3;" , MODEL_HEAD_STATUS_READY , def . Id , MODEL_HEAD_STATUS_TRAINING )
if err != nil {
failed ( "Failed to train definition!" )
return
}
break
} else if def . Epoch > MAX_EPOCH {
failed ( fmt . Sprintf ( "Failed to train definition! Accuracy less %f < %d\n" , acc * 100 , def . TargetAccuracy ) )
return
}
}
// Set the class status to trained
2024-04-19 15:39:51 +01:00
err = setModelClassStatus ( c , CLASS_STATUS_TRAINED , "model_id=$1 and status=$2;" , model . Id , CLASS_STATUS_TRAINING )
2024-04-08 14:17:13 +01:00
if err != nil {
failed ( "Failed to set class status" )
return
}
ModelUpdateStatus ( c , model . Id , READY )
}
2024-04-15 23:04:53 +01:00
func RunTaskTrain ( b BasePack , task Task ) ( err error ) {
l := b . GetLogger ( )
2024-03-02 12:45:49 +00:00
2024-04-17 14:56:57 +01:00
model , err := GetBaseModel ( b . GetDb ( ) , * task . ModelId )
2024-04-15 23:04:53 +01:00
if err != nil {
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed to get model information" )
l . Error ( "Failed to get model information" , "err" , err )
return err
2024-04-19 15:39:51 +01:00
} else if model . Status != TRAINING {
2024-04-15 23:04:53 +01:00
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Model not in the correct status for training" )
return errors . New ( "Model not in the right status" )
}
task . UpdateStatusLog ( b , TASK_RUNNING , "Training model" )
2023-09-26 20:15:28 +01:00
2024-04-15 23:04:53 +01:00
var dat struct {
NumberOfModels int
Accuracy int
}
err = json . Unmarshal ( [ ] byte ( task . ExtraTaskInfo ) , & dat )
if err != nil {
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed to get model extra information" )
}
if model . ModelType == 2 {
2024-04-19 15:39:51 +01:00
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "TODO expandable models" )
ModelUpdateStatus ( b , model . Id , FAILED_TRAINING )
panic ( "todo" )
2024-04-15 23:04:53 +01:00
full_error := generateExpandableDefinitions ( b , model , dat . Accuracy , dat . NumberOfModels )
if full_error != nil {
l . Error ( "Failed to generate defintions" , "err" , full_error )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed generate model" )
return errors . New ( "Failed to generate definitions" )
}
} else {
full_error := generateDefinitions ( b , model , dat . Accuracy , dat . NumberOfModels )
if full_error != nil {
2024-04-19 15:39:51 +01:00
l . Error ( "Failed to generate defintions" , "err" , full_error )
2024-04-15 23:04:53 +01:00
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed generate model" )
return errors . New ( "Failed to generate definitions" )
2024-03-09 10:52:08 +00:00
}
2024-04-15 23:04:53 +01:00
}
2024-01-31 21:48:35 +00:00
2024-04-15 23:04:53 +01:00
if model . ModelType == 2 {
err = trainModelExp ( b , model )
} else {
err = trainModel ( b , model )
}
if err != nil {
l . Error ( "Failed to train model" , "err" , err )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Failed generate model" )
ModelUpdateStatus ( b , model . Id , FAILED_TRAINING )
return
}
task . UpdateStatusLog ( b , TASK_DONE , "Model finished training" )
return
}
2024-04-16 17:48:52 +01:00
func RunTaskRetrain ( b BasePack , task Task ) ( err error ) {
2024-04-19 15:39:51 +01:00
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "TODO retrain with torch" )
panic ( "TODO" )
2024-04-17 14:56:57 +01:00
model , err := GetBaseModel ( b . GetDb ( ) , * task . ModelId )
2024-04-16 17:48:52 +01:00
if err != nil {
return err
} else if model . Status != READY_RETRAIN {
return errors . New ( "Model in invalid status for re-training" )
}
l := b . GetLogger ( )
db := b . GetDb ( )
failed := func ( ) {
ResetClasses ( b , model )
ModelUpdateStatus ( b , model . Id , READY_RETRAIN_FAILED )
task . UpdateStatusLog ( b , TASK_FAILED_RUNNING , "Model failed retraining" )
l . Error ( "Failed to retrain" , "err" , err )
}
task . UpdateStatusLog ( b , TASK_RUNNING , "Model retraining" )
2024-04-18 15:01:36 +01:00
var defData struct {
Id string ` db:"md.id" `
TargetAcuuracy float64 ` db:"md.target_accuracy" `
}
err = GetDBOnce ( db , & defData , "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;" , task . ModelId )
2024-04-16 17:48:52 +01:00
if err != nil {
failed ( )
return
}
2024-04-18 15:01:36 +01:00
var acc float64 = 0
var epocs = 0
// TODO make max epochs come from db
for acc * 100 < defData . TargetAcuuracy && epocs < 20 {
// This is something I have to check
acc , err = trainDefinitionExpandExp ( b , model , defData . Id , epocs > 0 )
if err != nil {
failed ( )
return
}
l . Info ( "Retrained model" , "accuracy" , acc , "target" , defData . TargetAcuuracy )
epocs += 1
}
if acc * 100 < defData . TargetAcuuracy {
l . Error ( "Model never achived targetd accuracy" , "acc" , acc * 100 , "target" , defData . TargetAcuuracy )
2024-04-16 17:48:52 +01:00
failed ( )
return
}
// TODO check accuracy
err = UpdateStatus ( db , "models" , model . Id , READY )
if err != nil {
failed ( )
return
}
l . Info ( "Model updaded" )
2024-04-19 15:39:51 +01:00
_ , err = db . Exec ( "update model_classes set status=$1 where status=$2 and model_id=$3" , CLASS_STATUS_TRAINED , CLASS_STATUS_TRAINING , model . Id )
2024-04-16 17:48:52 +01:00
if err != nil {
l . Error ( "Error while updating the classes" , "error" , err )
failed ( )
return
}
task . UpdateStatusLog ( b , TASK_DONE , "Model finished retraining" )
return
}
2024-04-15 23:04:53 +01:00
func handleTrain ( handle * Handle ) {
type TrainReq struct {
Id string ` json:"id" validate:"required" `
ModelType string ` json:"model_type" `
NumberOfModels int ` json:"number_of_models" `
Accuracy int ` json:"accuracy" `
}
PostAuthJson ( handle , "/models/train" , User_Normal , func ( c * Context , dat * TrainReq ) * Error {
2024-03-09 10:52:08 +00:00
modelTypeId := 1
if dat . ModelType == "expandable" {
modelTypeId = 2
} else if dat . ModelType != "simple" {
return c . JsonBadRequest ( "Invalid model type!" )
2023-09-27 21:20:39 +01:00
}
2024-03-09 10:52:08 +00:00
model , err := GetBaseModel ( c . Db , dat . Id )
2024-04-19 15:39:51 +01:00
if err == NotFoundError {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Model not found" )
2023-09-26 20:15:28 +01:00
} else if err != nil {
2024-04-15 23:04:53 +01:00
return c . E500M ( "Failed to get model information" , err )
2024-04-17 14:56:57 +01:00
} else if model . CanTrain == 0 {
return c . JsonBadRequest ( "Model can not be trained!" )
2023-09-26 20:15:28 +01:00
}
if model . Status != CONFIRM_PRE_TRAINING {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Model in invalid status for training" )
2023-09-26 20:15:28 +01:00
}
2024-04-15 23:04:53 +01:00
_ , err = c . Db . Exec ( "update models set status = $1, model_type = $2 where id = $3" , TRAINING , modelTypeId , model . Id )
if err != nil {
return c . E500M ( "Failed to update model_status" , err )
2024-02-02 16:16:26 +00:00
}
2023-09-26 20:15:28 +01:00
2024-04-15 23:04:53 +01:00
text , err := json . Marshal ( struct {
NumberOfModels int
Accuracy int
} {
NumberOfModels : dat . NumberOfModels ,
Accuracy : dat . Accuracy ,
} )
if err != nil {
return c . E500M ( "Failed create data" , err )
2024-02-02 16:16:26 +00:00
}
2023-09-27 18:07:04 +01:00
2024-04-15 23:04:53 +01:00
type CreateNewTask struct {
UserId string ` db:"user_id" `
ModelId string ` db:"model_id" `
TaskType TaskType ` db:"task_type" `
Status int ` db:"status" `
ExtraTaskInfo string ` db:"extra_task_info" `
}
newTask := CreateNewTask {
UserId : c . User . Id ,
ModelId : model . Id ,
TaskType : TASK_TYPE_TRAINING ,
Status : 1 ,
ExtraTaskInfo : string ( text ) ,
}
id , err := InsertReturnId ( c , & newTask , "tasks" , "id" )
2024-02-02 16:16:26 +00:00
if err != nil {
2024-04-15 23:04:53 +01:00
return c . E500M ( "Failed to create task" , err )
2024-02-02 16:16:26 +00:00
}
2023-10-12 12:08:12 +01:00
2024-04-15 23:04:53 +01:00
return c . SendJSON ( id )
2024-03-09 10:52:08 +00:00
} )
2023-10-12 12:08:12 +01:00
2024-04-16 16:02:57 +01:00
PostAuthJson ( handle , "/model/train/retrain" , User_Normal , func ( c * Context , dat * JustId ) * Error {
model , err := GetBaseModel ( c . Db , dat . Id )
2024-04-19 15:39:51 +01:00
if err == NotFoundError {
2024-04-16 16:02:57 +01:00
return c . JsonBadRequest ( "Model not found" )
} else if err != nil {
return c . E500M ( "Faield to get model" , err )
} else if model . Status != READY && model . Status != READY_RETRAIN_FAILED && model . Status != READY_ALTERATION_FAILED {
return c . JsonBadRequest ( "Model in invalid status for re-training" )
2024-04-17 14:56:57 +01:00
} else if model . CanTrain == 0 {
return c . JsonBadRequest ( "Model can not be trained!" )
2024-04-16 16:02:57 +01:00
}
c . Logger . Info ( "Expanding definitions for models" , "id" , model . Id )
classesUpdated := false
failed := func ( ) * Error {
if classesUpdated {
ResetClasses ( c , model )
}
ModelUpdateStatus ( c , model . Id , READY_RETRAIN_FAILED )
return c . E500M ( "Failed to retrain model" , err )
}
var def struct {
Id string
TargetAccuracy int ` db:"target_accuracy" `
}
err = GetDBOnce ( c , & def , "model_definition where model_id=$1;" , model . Id )
if err != nil {
return failed ( )
}
type C struct {
Id string
ClassOrder int ` db:"class_order" `
}
err = c . StartTx ( )
if err != nil {
return failed ( )
}
classes , err := GetDbMultitple [ C ] (
c ,
"model_classes where model_id=$1 and status=$2 order by class_order asc" ,
model . Id ,
2024-04-19 15:39:51 +01:00
CLASS_STATUS_TO_TRAIN ,
2024-04-16 16:02:57 +01:00
)
if err != nil {
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
if len ( classes ) == 0 {
c . Logger . Error ( "No classes are available!" )
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
//Update the classes
{
2024-04-19 15:39:51 +01:00
_ , err = c . Exec ( "update model_classes set status=$1 where status=$2 and model_id=$3" , CLASS_STATUS_TRAINING , CLASS_STATUS_TO_TRAIN , model . Id )
2024-04-16 16:02:57 +01:00
if err != nil {
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
err = c . CommitTx ( )
if err != nil {
_err := c . RollbackTx ( )
if _err != nil {
c . Logger . Error ( "Two errors happended rollback failed" , "err" , _err )
}
return failed ( )
}
classesUpdated = true
}
var newHead = struct {
2024-04-19 15:39:51 +01:00
DefId string ` db:"def_id" `
RangeStart int ` db:"range_start" `
RangeEnd int ` db:"range_end" `
Status DefinitionStatus ` db:"status" `
2024-04-16 16:02:57 +01:00
} {
2024-04-19 15:39:51 +01:00
def . Id , classes [ 0 ] . ClassOrder , classes [ len ( classes ) - 1 ] . ClassOrder , DEFINITION_STATUS_INIT ,
2024-04-16 16:02:57 +01:00
}
_ , err = InsertReturnId ( c . GetDb ( ) , & newHead , "exp_model_head" , "id" )
if err != nil {
return failed ( )
}
_ , err = c . Db . Exec ( "update models set status=$1 where id=$2;" , READY_RETRAIN , model . Id )
if err != nil {
2024-04-16 17:48:52 +01:00
return c . E500M ( "Failed to update model status" , err )
}
newTask := struct {
UserId string ` db:"user_id" `
ModelId string ` db:"model_id" `
TaskType TaskType ` db:"task_type" `
Status int ` db:"status" `
} {
UserId : c . User . Id ,
ModelId : model . Id ,
TaskType : TASK_TYPE_RETRAINING ,
Status : 1 ,
}
id , err := InsertReturnId ( c , & newTask , "tasks" , "id" )
if err != nil {
return c . E500M ( "Failed to create task" , err )
2024-04-16 16:02:57 +01:00
}
2024-04-16 17:48:52 +01:00
return c . SendJSON ( JustId { Id : id } )
2024-04-16 16:02:57 +01:00
} )
2024-04-08 14:17:13 +01:00
2024-03-09 10:52:08 +00:00
handle . Get ( "/model/epoch/update" , func ( c * Context ) * Error {
f := c . R . URL . Query ( )
2023-10-12 12:08:12 +01:00
2023-10-22 23:02:39 +01:00
accuracy := 0.0
2023-10-21 00:26:52 +01:00
2023-10-22 23:02:39 +01:00
if ! CheckId ( f , "model_id" ) || ! CheckId ( f , "definition" ) || CheckEmpty ( f , "epoch" ) || ! CheckFloat64 ( f , "accuracy" , & accuracy ) {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Invalid: model_id or definition or epoch or accuracy" )
2023-10-19 10:44:13 +01:00
}
2023-10-12 12:08:12 +01:00
2023-10-22 23:02:39 +01:00
accuracy = accuracy * 100
2023-10-21 00:26:52 +01:00
2023-10-12 12:08:12 +01:00
model_id := f . Get ( "model_id" )
def_id := f . Get ( "definition" )
2023-10-19 10:44:13 +01:00
epoch , err := strconv . Atoi ( f . Get ( "epoch" ) )
if err != nil {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Epoch is not a number" )
2023-10-19 10:44:13 +01:00
}
rows , err := c . Db . Query ( "select md.status from model_definition as md where md.model_id=$1 and md.id=$2" , model_id , def_id )
if err != nil {
return c . Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model definition" )
return c . Error500 ( nil )
}
var status int
err = rows . Scan ( & status )
if err != nil {
return c . Error500 ( err )
}
if status != 3 {
c . Logger . Warn ( "Definition not on status 3(training)" , "status" , status )
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Definition not on status 3(training)" )
2023-10-19 10:44:13 +01:00
}
2024-04-08 17:45:32 +01:00
c . Logger . Debug ( "Updated model_definition!" , "model" , model_id , "progress" , epoch , "accuracy" , accuracy )
2023-10-21 00:26:52 +01:00
_ , err = c . Db . Exec ( "update model_definition set epoch_progress=$1, accuracy=$2 where id=$3" , epoch , accuracy , def_id )
2023-10-19 10:44:13 +01:00
if err != nil {
return c . Error500 ( err )
}
2024-04-08 17:45:32 +01:00
2024-04-14 14:51:16 +01:00
c . ShowMessage = false
2023-10-12 12:08:12 +01:00
return nil
} )
2024-02-05 16:42:23 +00:00
2024-03-09 10:52:08 +00:00
handle . Get ( "/model/head/epoch/update" , func ( c * Context ) * Error {
f := c . R . URL . Query ( )
2024-02-05 16:42:23 +00:00
accuracy := 0.0
if ! CheckId ( f , "head_id" ) || CheckEmpty ( f , "epoch" ) || ! CheckFloat64 ( f , "accuracy" , & accuracy ) {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Invalid: model_id or definition or epoch or accuracy" )
2024-02-05 16:42:23 +00:00
}
accuracy = accuracy * 100
head_id := f . Get ( "head_id" )
epoch , err := strconv . Atoi ( f . Get ( "epoch" ) )
if err != nil {
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Epoch is not a number" )
2024-02-05 16:42:23 +00:00
}
rows , err := c . Db . Query ( "select hd.status from exp_model_head as hd where hd.id=$1;" , head_id )
if err != nil {
return c . Error500 ( err )
}
defer rows . Close ( )
if ! rows . Next ( ) {
c . Logger . Error ( "Could not get status of model head" )
return c . Error500 ( nil )
}
var status int
err = rows . Scan ( & status )
if err != nil {
return c . Error500 ( err )
}
if status != 3 {
c . Logger . Warn ( "Head not on status 3(training)" , "status" , status )
2024-03-09 10:52:08 +00:00
return c . JsonBadRequest ( "Head not on status 3(training)" )
2024-02-05 16:42:23 +00:00
}
2024-04-08 17:45:32 +01:00
c . Logger . Debug ( "Updated model_head!" , "head" , head_id , "progress" , epoch , "accuracy" , accuracy )
2024-02-05 16:42:23 +00:00
_ , err = c . Db . Exec ( "update exp_model_head set epoch_progress=$1, accuracy=$2 where id=$3" , epoch , accuracy , head_id )
if err != nil {
return c . Error500 ( err )
}
2024-04-08 17:45:32 +01:00
2024-04-14 14:51:16 +01:00
c . ShowMessage = false
2024-02-05 16:42:23 +00:00
return nil
} )
2023-09-26 20:15:28 +01:00
}