feat: closes #24

This commit is contained in:
Andre Henriques 2023-10-03 19:02:02 +01:00
parent f68f0e2444
commit 84b9c40a72
4 changed files with 26 additions and 38 deletions

View File

@ -110,7 +110,7 @@ func handleRun(handle *Handle) {
root := tg.NewRoot() root := tg.NewRoot()
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3) tf_img := ReadPNG(root, path.Join(run_path, "img.png"), int64(model.ImageMode))
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
inputImage, err:= tf.NewTensor(exec_results[0].Value()) inputImage, err:= tf.NewTensor(exec_results[0].Value())
@ -130,8 +130,6 @@ func handleRun(handle *Handle) {
vi := 0 vi := 0
var predictions = results[0].Value().([][]float32)[0] var predictions = results[0].Value().([][]float32)[0]
fmt.Println(predictions)
for i, v := range predictions { for i, v := range predictions {
if v > vmax { if v > vmax {
vi = i vi = i
@ -155,7 +153,6 @@ func handleRun(handle *Handle) {
var name string var name string
if err = rows.Scan(&name); err != nil { return nil } if err = rows.Scan(&name); err != nil { return nil }
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
"Model": model, "Model": model,
"Result": name, "Result": name,

View File

@ -19,27 +19,12 @@ import (
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = "" id = ""
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy) rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
if err != nil { if err != nil { return }
return defer rows.Close()
} if !rows.Next() { return id, errors.New("Something wrong!") }
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return id, errors.New("Something wrong!")
}
err = rows.Scan(&id) err = rows.Scan(&id)
if err != nil { return
return
}
return
} }
type ModelDefinitionStatus int type ModelDefinitionStatus int
@ -103,7 +88,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
return return
} }
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) { func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (accuracy float64, err error) {
accuracy = 0 accuracy = 0
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
@ -128,14 +113,14 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
} }
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model_id, "defs", definition_id) run_path := path.Join("/tmp", model.Id, "defs", definition_id)
err = os.MkdirAll(run_path, os.ModePerm) err = os.MkdirAll(run_path, os.ModePerm)
if err != nil { if err != nil {
return return
} }
_, err = generateCvs(handle, run_path, model_id) _, err = generateCvs(handle, run_path, model.Id)
if err != nil { return } if err != nil { return }
// Create python script // Create python script
@ -153,8 +138,9 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
if err = tmpl.Execute(f, AnyMap{ if err = tmpl.Execute(f, AnyMap{
"Layers": got, "Layers": got,
"Size": got[0].Shape, "Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model_id, "data"), "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"RunPath": run_path, "RunPath": run_path,
"ColorMode": model.ImageMode,
}); err != nil { }); err != nil {
return return
} }
@ -167,7 +153,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
} }
// Copy result around // Copy result around
result_path := path.Join("savedData", model_id, "defs", definition_id) result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = os.MkdirAll(result_path, os.ModePerm); err != nil { if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return return
@ -235,7 +221,7 @@ func trainModel(handle *Handle, model *BaseModel) {
} }
for _, def := range definitions { for _, def := range definitions {
accuracy, err := trainDefinition(handle, model.Id, def.id) accuracy, err := trainDefinition(handle, model, def.id)
if err != nil { if err != nil {
fmt.Printf("Failed to train definition!Err:\n") fmt.Printf("Failed to train definition!Err:\n")
fmt.Println(err) fmt.Println(err)

View File

@ -10,8 +10,9 @@ type BaseModel struct {
Status int Status int
Id string Id string
Width int ImageMode int
Height int Width int
Height int
} }
const ( const (
@ -30,7 +31,7 @@ const (
var ModelNotFoundError = errors.New("Model not found error") var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id) rows, err := db.Query("select name, status, id, width, height, color_mode from models where id=$1;", id)
if err != nil { if err != nil {
return return
} }
@ -41,10 +42,17 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
} }
base = &BaseModel{} base = &BaseModel{}
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height) var colorMode string
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch colorMode {
case "greyscale":
base.ImageMode = 1
default:
panic("unkown color mode")
}
return return
} }

View File

@ -32,10 +32,7 @@ def pathToLabel(path):
#return tf.strings.as_string([path]) #return tf.strings.as_string([path])
def decode_image(img): def decode_image(img):
# channels were reduced to 1 since image is grayscale img = tf.io.decode_png(img, channels={{.ColorMode}})
# TODO chnage channel number based if grayscale
img = tf.io.decode_png(img, channels=3)
return tf.image.resize(img, image_size) return tf.image.resize(img, image_size)
def process_path(path): def process_path(path):