more work on the expadable models
This commit is contained in:
parent
d08a0a2a4c
commit
70b4141223
5
auto_reload.sh
Executable file
5
auto_reload.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/fish
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
go run .
|
@ -17,6 +17,7 @@ import (
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
const EPOCH_PER_RUN = 20
|
||||
@ -198,6 +199,151 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
return
|
||||
}
|
||||
|
||||
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
accuracy = 0
|
||||
|
||||
c.Logger.Warn("About to start training definition")
|
||||
|
||||
// Get untrained models heads
|
||||
|
||||
// Status = 2 (INIT)
|
||||
rows, err := c.Db.Query("select id, range_start, range_end exp_model_head where def_id=$1 and status = 2", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type ExpHead struct {
|
||||
id string
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
exp := ExpHead{}
|
||||
|
||||
if rows.Next() {
|
||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Error("Failed to get the exp head of the model")
|
||||
err = errors.New("Failed to get the exp head of the model")
|
||||
return
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
log.Error("This training function can only train one model at the time")
|
||||
err = errors.New("This training function can only train one model at the time")
|
||||
return
|
||||
}
|
||||
|
||||
layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer layers.Close()
|
||||
|
||||
type layerrow struct {
|
||||
LayerType int
|
||||
Shape string
|
||||
}
|
||||
|
||||
got := []layerrow{}
|
||||
|
||||
for layers.Next() {
|
||||
var row = layerrow{}
|
||||
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
||||
return
|
||||
}
|
||||
row.Shape = shapeToSize(row.Shape)
|
||||
got = append(got, row)
|
||||
}
|
||||
|
||||
got = append(got, layerrow{
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d", exp.end - exp.start),
|
||||
})
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
|
||||
|
||||
err = os.MkdirAll(run_path, os.ModePerm)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer os.RemoveAll(run_path)
|
||||
|
||||
_, err = generateCvs(c, run_path, model.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO update the run script
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Copy result around
|
||||
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
||||
|
||||
if err = tmpl.Execute(f, AnyMap{
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||
"RunPath": run_path,
|
||||
"ColorMode": model.ImageMode,
|
||||
"Model": model,
|
||||
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
||||
"DefId": definition_id,
|
||||
"LoadPrev": load_prev,
|
||||
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||
"SaveModelPath": path.Join(getDir(), result_path),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Run the command
|
||||
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
|
||||
if err != nil {
|
||||
c.Logger.Debug(string(out))
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Python finished running")
|
||||
|
||||
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer accuracy_file.Close()
|
||||
|
||||
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
return
|
||||
}
|
||||
|
||||
func remove[T interface{}](lst []T, i int) []T {
|
||||
lng := len(lst)
|
||||
if i >= lng {
|
||||
@ -474,7 +620,7 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
var toRemove ToRemoveList = []int{}
|
||||
for i, def := range definitions {
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
accuracy, err := trainDefinition(c, model, def.id, !firstRound)
|
||||
accuracy, err := trainDefinitionExp(c, model, def.id, !firstRound)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to train definition!Err:", "err", err)
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
@ -1019,8 +1165,11 @@ func handleTrain(handle *Handle) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if model_type_id == 2 {
|
||||
go trainModelExp(c, model)
|
||||
} else {
|
||||
go trainModel(c, model)
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update models set status = $1, model_type = $2 where id = $3", TRAINING, model_type_id, model.Id)
|
||||
if err != nil {
|
||||
|
@ -636,7 +636,7 @@ func NewHandler(db *sql.DB) *Handle {
|
||||
}
|
||||
|
||||
func (x Handle) Startup() {
|
||||
fmt.Printf("Starting up!\n")
|
||||
log.Info("Starting up!\n")
|
||||
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
|
Loading…
Reference in New Issue
Block a user